path: root/tensorflow
author     Francois Chollet <fchollet@google.com>    2018-08-24 16:13:18 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>    2018-08-24 16:19:29 -0700
commit     f6c3c9733ed39f14ee3c32bc51ec62315b48ad31 (patch)
tree       5549e540017ec2cf493efb7df71d059d774b3217 /tensorflow
parent     6d7261ef22835dc51fb157bdb1db349fd26d8f86 (diff)
Upgrade Keras applications and Keras preprocessing.
PiperOrigin-RevId: 210174523
Diffstat (limited to 'tensorflow')
-rw-r--r--  tensorflow/docs_src/install/install_sources.md | 6
-rw-r--r--  tensorflow/docs_src/install/install_sources_windows.md | 4
-rw-r--r--  tensorflow/python/keras/applications/__init__.py | 51
-rw-r--r--  tensorflow/python/keras/applications/applications_test.py | 8
-rw-r--r--  tensorflow/python/keras/applications/densenet.py | 47
-rw-r--r--  tensorflow/python/keras/applications/imagenet_utils.py | 33
-rw-r--r--  tensorflow/python/keras/applications/inception_resnet_v2.py | 26
-rw-r--r--  tensorflow/python/keras/applications/inception_v3.py | 25
-rw-r--r--  tensorflow/python/keras/applications/mobilenet.py | 25
-rw-r--r--  tensorflow/python/keras/applications/mobilenet_v2.py | 24
-rw-r--r--  tensorflow/python/keras/applications/nasnet.py | 35
-rw-r--r--  tensorflow/python/keras/applications/resnet50.py | 24
-rw-r--r--  tensorflow/python/keras/applications/vgg16.py | 24
-rw-r--r--  tensorflow/python/keras/applications/vgg19.py | 24
-rw-r--r--  tensorflow/python/keras/applications/xception.py | 25
-rw-r--r--  tensorflow/python/keras/preprocessing/__init__.py | 2
-rw-r--r--  tensorflow/python/keras/preprocessing/image.py | 492
-rw-r--r--  tensorflow/python/keras/preprocessing/sequence.py | 63
-rw-r--r--  tensorflow/python/tools/api/generator/api_init_files.bzl | 1
-rw-r--r--  tensorflow/python/tools/api/generator/api_init_files_v1.bzl | 1
-rw-r--r--  tensorflow/tools/ci_build/Dockerfile.cmake | 4
-rwxr-xr-x  tensorflow/tools/ci_build/install/install_pip_packages.sh | 8
-rwxr-xr-x  tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh | 4
-rwxr-xr-x  tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh | 4
-rw-r--r--  tensorflow/tools/docker/Dockerfile | 4
-rw-r--r--  tensorflow/tools/docker/Dockerfile.devel | 4
-rw-r--r--  tensorflow/tools/docker/Dockerfile.devel-gpu | 4
-rw-r--r--  tensorflow/tools/docker/Dockerfile.devel-gpu-cuda9-cudnn7 | 4
-rwxr-xr-x  tensorflow/tools/docker/Dockerfile.devel-mkl | 4
-rwxr-xr-x  tensorflow/tools/docker/Dockerfile.devel-mkl-horovod | 4
-rw-r--r--  tensorflow/tools/docker/Dockerfile.gpu | 4
-rwxr-xr-x  tensorflow/tools/docker/Dockerfile.mkl | 4
-rwxr-xr-x  tensorflow/tools/docker/Dockerfile.mkl-horovod | 4
-rw-r--r--  tensorflow/tools/pip_package/setup.py | 4
34 files changed, 851 insertions(+), 149 deletions(-)
diff --git a/tensorflow/docs_src/install/install_sources.md b/tensorflow/docs_src/install/install_sources.md
index e8e13142e9..44ea18fa7b 100644
--- a/tensorflow/docs_src/install/install_sources.md
+++ b/tensorflow/docs_src/install/install_sources.md
@@ -180,9 +180,9 @@ If you follow these instructions, you will not need to disable SIP.
After installing pip, invoke the following commands:
-<pre> $ <b>sudo pip install six numpy wheel mock h5py</b>
- $ <b>sudo pip install keras_applications==1.0.4 --no-deps</b>
- $ <b>sudo pip install keras_preprocessing==1.0.2 --no-deps</b>
+<pre> $ <b>pip install six numpy wheel mock h5py</b>
+ $ <b>pip install keras_applications==1.0.5 --no-deps</b>
+ $ <b>pip install keras_preprocessing==1.0.3 --no-deps</b>
</pre>
Note: These are just the minimum requirements to _build_ tensorflow. Installing
diff --git a/tensorflow/docs_src/install/install_sources_windows.md b/tensorflow/docs_src/install/install_sources_windows.md
index a1da122317..40dce106d6 100644
--- a/tensorflow/docs_src/install/install_sources_windows.md
+++ b/tensorflow/docs_src/install/install_sources_windows.md
@@ -94,8 +94,8 @@ Assume you already have `pip3` in `%PATH%`, issue the following command:
<pre>
C:\> <b>pip3 install six numpy wheel</b>
-C:\> <b>pip3 install keras_applications==1.0.4 --no-deps</b>
-C:\> <b>pip3 install keras_preprocessing==1.0.2 --no-deps</b>
+C:\> <b>pip3 install keras_applications==1.0.5 --no-deps</b>
+C:\> <b>pip3 install keras_preprocessing==1.0.3 --no-deps</b>
</pre>
<a name="InstallCUDA"></a>
diff --git a/tensorflow/python/keras/applications/__init__.py b/tensorflow/python/keras/applications/__init__.py
index cd9462d6b5..a8b6d55e41 100644
--- a/tensorflow/python/keras/applications/__init__.py
+++ b/tensorflow/python/keras/applications/__init__.py
@@ -14,6 +14,7 @@
# ==============================================================================
"""Keras Applications are canned architectures with pre-trained weights."""
# pylint: disable=g-import-not-at-top
+# pylint: disable=g-bad-import-order
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -25,13 +26,49 @@ from tensorflow.python.keras import engine
from tensorflow.python.keras import layers
from tensorflow.python.keras import models
from tensorflow.python.keras import utils
+from tensorflow.python.util import tf_inspect
+
+# `get_submodules_from_kwargs` has been introduced in 1.0.5, but we would
+# like to be able to handle prior versions. Note that prior to 1.0.5,
+# `keras_applications` did not expose a `__version__` attribute.
+if not hasattr(keras_applications, 'get_submodules_from_kwargs'):
+
+ if 'engine' in tf_inspect.getfullargspec(
+ keras_applications.set_keras_submodules)[0]:
+ keras_applications.set_keras_submodules(
+ backend=backend,
+ layers=layers,
+ models=models,
+ utils=utils,
+ engine=engine)
+ else:
+ keras_applications.set_keras_submodules(
+ backend=backend,
+ layers=layers,
+ models=models,
+ utils=utils)
+
+
+def keras_modules_injection(base_fun):
+ """Decorator injecting tf.keras replacements for Keras modules.
+
+ Arguments:
+ base_fun: Application function to decorate (e.g. `MobileNet`).
+
+ Returns:
+ Decorated function that injects keyword argument for the tf.keras
+ modules required by the Applications.
+ """
+
+ def wrapper(*args, **kwargs):
+ if hasattr(keras_applications, 'get_submodules_from_kwargs'):
+ kwargs['backend'] = backend
+ kwargs['layers'] = layers
+ kwargs['models'] = models
+ kwargs['utils'] = utils
+ return base_fun(*args, **kwargs)
+ return wrapper
-keras_applications.set_keras_submodules(
- backend=backend,
- engine=engine,
- layers=layers,
- models=models,
- utils=utils)
from tensorflow.python.keras.applications.densenet import DenseNet121
from tensorflow.python.keras.applications.densenet import DenseNet169
@@ -39,7 +76,7 @@ from tensorflow.python.keras.applications.densenet import DenseNet201
from tensorflow.python.keras.applications.inception_resnet_v2 import InceptionResNetV2
from tensorflow.python.keras.applications.inception_v3 import InceptionV3
from tensorflow.python.keras.applications.mobilenet import MobileNet
-# TODO(fchollet): enable MobileNetV2 in next version.
+from tensorflow.python.keras.applications.mobilenet_v2 import MobileNetV2
from tensorflow.python.keras.applications.nasnet import NASNetLarge
from tensorflow.python.keras.applications.nasnet import NASNetMobile
from tensorflow.python.keras.applications.resnet50 import ResNet50
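The hunk above replaces the unconditional `set_keras_submodules` call with a version check plus the `keras_modules_injection` decorator. A standalone sketch of the same compatibility logic, written against the standard-library `inspect` rather than `tf_inspect` (assumes `keras_applications` and the tf.keras submodules shown above are importable):

```python
import inspect

import keras_applications

from tensorflow.python.keras import backend, engine, layers, models, utils

# keras_applications < 1.0.5 is configured once via set_keras_submodules;
# some of those releases also expect an `engine` argument, so the signature
# is inspected before the call (mirroring the tf_inspect check in the diff).
if not hasattr(keras_applications, 'get_submodules_from_kwargs'):
  submodules = dict(backend=backend, layers=layers, models=models, utils=utils)
  argspec = inspect.getfullargspec(keras_applications.set_keras_submodules)
  if 'engine' in argspec.args:
    submodules['engine'] = engine
  keras_applications.set_keras_submodules(**submodules)
```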
diff --git a/tensorflow/python/keras/applications/applications_test.py b/tensorflow/python/keras/applications/applications_test.py
index ef3198a937..b15ca5990a 100644
--- a/tensorflow/python/keras/applications/applications_test.py
+++ b/tensorflow/python/keras/applications/applications_test.py
@@ -32,7 +32,8 @@ MODEL_LIST = [
(applications.InceptionV3, 2048),
(applications.InceptionResNetV2, 1536),
(applications.MobileNet, 1024),
- # TODO(fchollet): enable MobileNetV2 in next version.
+ # TODO(fchollet): enable MobileNetV2 tests when a new TensorFlow test image
+ # is released with keras_applications upgraded to 1.0.5 or above.
(applications.DenseNet121, 1024),
(applications.DenseNet169, 1664),
(applications.DenseNet201, 1920),
@@ -44,11 +45,6 @@ MODEL_LIST = [
class ApplicationsTest(test.TestCase, parameterized.TestCase):
@parameterized.parameters(*MODEL_LIST)
- def test_classification_model(self, model_fn, _):
- model = model_fn(classes=1000, weights=None)
- self.assertEqual(model.output_shape[-1], 1000)
-
- @parameterized.parameters(*MODEL_LIST)
def test_feature_extration_model(self, model_fn, output_dim):
model = model_fn(include_top=False, weights=None)
self.assertEqual(model.output_shape, (None, None, None, output_dim))
diff --git a/tensorflow/python/keras/applications/densenet.py b/tensorflow/python/keras/applications/densenet.py
index fbdcc66d2d..172848bbdb 100644
--- a/tensorflow/python/keras/applications/densenet.py
+++ b/tensorflow/python/keras/applications/densenet.py
@@ -20,18 +20,39 @@ from __future__ import division
from __future__ import print_function
from keras_applications import densenet
+
+from tensorflow.python.keras.applications import keras_modules_injection
from tensorflow.python.util.tf_export import tf_export
-DenseNet121 = densenet.DenseNet121
-DenseNet169 = densenet.DenseNet169
-DenseNet201 = densenet.DenseNet201
-decode_predictions = densenet.decode_predictions
-preprocess_input = densenet.preprocess_input
-
-tf_export('keras.applications.densenet.DenseNet121',
- 'keras.applications.DenseNet121')(DenseNet121)
-tf_export('keras.applications.densenet.DenseNet169',
- 'keras.applications.DenseNet169')(DenseNet169)
-tf_export('keras.applications.densenet.DenseNet201',
- 'keras.applications.DenseNet201')(DenseNet201)
-tf_export('keras.applications.densenet.preprocess_input')(preprocess_input)
+
+@tf_export('keras.applications.densenet.DenseNet121',
+ 'keras.applications.DenseNet121')
+@keras_modules_injection
+def DenseNet121(*args, **kwargs):
+ return densenet.DenseNet121(*args, **kwargs)
+
+
+@tf_export('keras.applications.densenet.DenseNet169',
+ 'keras.applications.DenseNet169')
+@keras_modules_injection
+def DenseNet169(*args, **kwargs):
+ return densenet.DenseNet169(*args, **kwargs)
+
+
+@tf_export('keras.applications.densenet.DenseNet201',
+ 'keras.applications.DenseNet201')
+@keras_modules_injection
+def DenseNet201(*args, **kwargs):
+ return densenet.DenseNet201(*args, **kwargs)
+
+
+@tf_export('keras.applications.densenet.decode_predictions')
+@keras_modules_injection
+def decode_predictions(*args, **kwargs):
+ return densenet.decode_predictions(*args, **kwargs)
+
+
+@tf_export('keras.applications.densenet.preprocess_input')
+@keras_modules_injection
+def preprocess_input(*args, **kwargs):
+ return densenet.preprocess_input(*args, **kwargs)
diff --git a/tensorflow/python/keras/applications/imagenet_utils.py b/tensorflow/python/keras/applications/imagenet_utils.py
index 70f8f6fb32..c25b5c2bdd 100644
--- a/tensorflow/python/keras/applications/imagenet_utils.py
+++ b/tensorflow/python/keras/applications/imagenet_utils.py
@@ -19,27 +19,18 @@ from __future__ import division
from __future__ import print_function
from keras_applications import imagenet_utils
+
+from tensorflow.python.keras.applications import keras_modules_injection
from tensorflow.python.util.tf_export import tf_export
-decode_predictions = imagenet_utils.decode_predictions
-preprocess_input = imagenet_utils.preprocess_input
-tf_export(
- 'keras.applications.imagenet_utils.decode_predictions',
- 'keras.applications.densenet.decode_predictions',
- 'keras.applications.inception_resnet_v2.decode_predictions',
- 'keras.applications.inception_v3.decode_predictions',
- 'keras.applications.mobilenet.decode_predictions',
- 'keras.applications.mobilenet_v2.decode_predictions',
- 'keras.applications.nasnet.decode_predictions',
- 'keras.applications.resnet50.decode_predictions',
- 'keras.applications.vgg16.decode_predictions',
- 'keras.applications.vgg19.decode_predictions',
- 'keras.applications.xception.decode_predictions',
-)(decode_predictions)
-tf_export(
- 'keras.applications.imagenet_utils.preprocess_input',
- 'keras.applications.resnet50.preprocess_input',
- 'keras.applications.vgg16.preprocess_input',
- 'keras.applications.vgg19.preprocess_input',
-)(preprocess_input)
+@tf_export('keras.applications.imagenet_utils.decode_predictions')
+@keras_modules_injection
+def decode_predictions(*args, **kwargs):
+ return imagenet_utils.decode_predictions(*args, **kwargs)
+
+
+@tf_export('keras.applications.imagenet_utils.preprocess_input')
+@keras_modules_injection
+def preprocess_input(*args, **kwargs):
+ return imagenet_utils.preprocess_input(*args, **kwargs)
diff --git a/tensorflow/python/keras/applications/inception_resnet_v2.py b/tensorflow/python/keras/applications/inception_resnet_v2.py
index 63debb4e0d..0b9ef371fa 100644
--- a/tensorflow/python/keras/applications/inception_resnet_v2.py
+++ b/tensorflow/python/keras/applications/inception_resnet_v2.py
@@ -20,13 +20,25 @@ from __future__ import division
from __future__ import print_function
from keras_applications import inception_resnet_v2
+
+from tensorflow.python.keras.applications import keras_modules_injection
from tensorflow.python.util.tf_export import tf_export
-InceptionResNetV2 = inception_resnet_v2.InceptionResNetV2
-decode_predictions = inception_resnet_v2.decode_predictions
-preprocess_input = inception_resnet_v2.preprocess_input
-tf_export('keras.applications.inception_resnet_v2.InceptionResNetV2',
- 'keras.applications.InceptionResNetV2')(InceptionResNetV2)
-tf_export(
- 'keras.applications.inception_resnet_v2.preprocess_input')(preprocess_input)
+@tf_export('keras.applications.inception_resnet_v2.InceptionResNetV2',
+ 'keras.applications.InceptionResNetV2')
+@keras_modules_injection
+def InceptionResNetV2(*args, **kwargs):
+ return inception_resnet_v2.InceptionResNetV2(*args, **kwargs)
+
+
+@tf_export('keras.applications.inception_resnet_v2.decode_predictions')
+@keras_modules_injection
+def decode_predictions(*args, **kwargs):
+ return inception_resnet_v2.decode_predictions(*args, **kwargs)
+
+
+@tf_export('keras.applications.inception_resnet_v2.preprocess_input')
+@keras_modules_injection
+def preprocess_input(*args, **kwargs):
+ return inception_resnet_v2.preprocess_input(*args, **kwargs)
diff --git a/tensorflow/python/keras/applications/inception_v3.py b/tensorflow/python/keras/applications/inception_v3.py
index 87534086c8..ab76826e17 100644
--- a/tensorflow/python/keras/applications/inception_v3.py
+++ b/tensorflow/python/keras/applications/inception_v3.py
@@ -20,12 +20,25 @@ from __future__ import division
from __future__ import print_function
from keras_applications import inception_v3
+
+from tensorflow.python.keras.applications import keras_modules_injection
from tensorflow.python.util.tf_export import tf_export
-InceptionV3 = inception_v3.InceptionV3
-decode_predictions = inception_v3.decode_predictions
-preprocess_input = inception_v3.preprocess_input
-tf_export('keras.applications.inception_v3.InceptionV3',
- 'keras.applications.InceptionV3')(InceptionV3)
-tf_export('keras.applications.inception_v3.preprocess_input')(preprocess_input)
+@tf_export('keras.applications.inception_v3.InceptionV3',
+ 'keras.applications.InceptionV3')
+@keras_modules_injection
+def InceptionV3(*args, **kwargs):
+ return inception_v3.InceptionV3(*args, **kwargs)
+
+
+@tf_export('keras.applications.inception_v3.decode_predictions')
+@keras_modules_injection
+def decode_predictions(*args, **kwargs):
+ return inception_v3.decode_predictions(*args, **kwargs)
+
+
+@tf_export('keras.applications.inception_v3.preprocess_input')
+@keras_modules_injection
+def preprocess_input(*args, **kwargs):
+ return inception_v3.preprocess_input(*args, **kwargs)
diff --git a/tensorflow/python/keras/applications/mobilenet.py b/tensorflow/python/keras/applications/mobilenet.py
index 3528f027b3..1f71a5ae99 100644
--- a/tensorflow/python/keras/applications/mobilenet.py
+++ b/tensorflow/python/keras/applications/mobilenet.py
@@ -20,12 +20,25 @@ from __future__ import division
from __future__ import print_function
from keras_applications import mobilenet
+
+from tensorflow.python.keras.applications import keras_modules_injection
from tensorflow.python.util.tf_export import tf_export
-MobileNet = mobilenet.MobileNet
-decode_predictions = mobilenet.decode_predictions
-preprocess_input = mobilenet.preprocess_input
-tf_export('keras.applications.mobilenet.MobileNet',
- 'keras.applications.MobileNet')(MobileNet)
-tf_export('keras.applications.mobilenet.preprocess_input')(preprocess_input)
+@tf_export('keras.applications.mobilenet.MobileNet',
+ 'keras.applications.MobileNet')
+@keras_modules_injection
+def MobileNet(*args, **kwargs):
+ return mobilenet.MobileNet(*args, **kwargs)
+
+
+@tf_export('keras.applications.mobilenet.decode_predictions')
+@keras_modules_injection
+def decode_predictions(*args, **kwargs):
+ return mobilenet.decode_predictions(*args, **kwargs)
+
+
+@tf_export('keras.applications.mobilenet.preprocess_input')
+@keras_modules_injection
+def preprocess_input(*args, **kwargs):
+ return mobilenet.preprocess_input(*args, **kwargs)
diff --git a/tensorflow/python/keras/applications/mobilenet_v2.py b/tensorflow/python/keras/applications/mobilenet_v2.py
index 9194c3ee14..52ac5959ad 100644
--- a/tensorflow/python/keras/applications/mobilenet_v2.py
+++ b/tensorflow/python/keras/applications/mobilenet_v2.py
@@ -19,4 +19,26 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-# TODO(fchollet): export MobileNetV2 as part of the public API in next version.
+from keras_applications import mobilenet_v2
+
+from tensorflow.python.keras.applications import keras_modules_injection
+from tensorflow.python.util.tf_export import tf_export
+
+
+@tf_export('keras.applications.mobilenet_v2.MobileNetV2',
+ 'keras.applications.MobileNetV2')
+@keras_modules_injection
+def MobileNetV2(*args, **kwargs):
+ return mobilenet_v2.MobileNetV2(*args, **kwargs)
+
+
+@tf_export('keras.applications.mobilenet_v2.decode_predictions')
+@keras_modules_injection
+def decode_predictions(*args, **kwargs):
+ return mobilenet_v2.decode_predictions(*args, **kwargs)
+
+
+@tf_export('keras.applications.mobilenet_v2.preprocess_input')
+@keras_modules_injection
+def preprocess_input(*args, **kwargs):
+ return mobilenet_v2.preprocess_input(*args, **kwargs)
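With this module added and the corresponding import re-enabled in `applications/__init__.py`, MobileNetV2 becomes reachable from the applications namespace. A hedged usage sketch (assumes a TensorFlow build that contains this commit and `keras_applications>=1.0.5`; `weights=None` keeps the example offline):

```python
from tensorflow.python.keras import applications

# Build the architecture only; no ImageNet weights are downloaded.
model = applications.MobileNetV2(
    input_shape=(224, 224, 3), include_top=False, weights=None)
print(model.output_shape)
```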
diff --git a/tensorflow/python/keras/applications/nasnet.py b/tensorflow/python/keras/applications/nasnet.py
index 26ff5db53f..44fc329d57 100644
--- a/tensorflow/python/keras/applications/nasnet.py
+++ b/tensorflow/python/keras/applications/nasnet.py
@@ -20,15 +20,32 @@ from __future__ import division
from __future__ import print_function
from keras_applications import nasnet
+
+from tensorflow.python.keras.applications import keras_modules_injection
from tensorflow.python.util.tf_export import tf_export
-NASNetMobile = nasnet.NASNetMobile
-NASNetLarge = nasnet.NASNetLarge
-decode_predictions = nasnet.decode_predictions
-preprocess_input = nasnet.preprocess_input
-tf_export('keras.applications.nasnet.NASNetMobile',
- 'keras.applications.NASNetMobile')(NASNetMobile)
-tf_export('keras.applications.nasnet.NASNetLarge',
- 'keras.applications.NASNetLarge')(NASNetLarge)
-tf_export('keras.applications.nasnet.preprocess_input')(preprocess_input)
+@tf_export('keras.applications.nasnet.NASNetMobile',
+ 'keras.applications.NASNetMobile')
+@keras_modules_injection
+def NASNetMobile(*args, **kwargs):
+ return nasnet.NASNetMobile(*args, **kwargs)
+
+
+@tf_export('keras.applications.nasnet.NASNetLarge',
+ 'keras.applications.NASNetLarge')
+@keras_modules_injection
+def NASNetLarge(*args, **kwargs):
+ return nasnet.NASNetLarge(*args, **kwargs)
+
+
+@tf_export('keras.applications.nasnet.decode_predictions')
+@keras_modules_injection
+def decode_predictions(*args, **kwargs):
+ return nasnet.decode_predictions(*args, **kwargs)
+
+
+@tf_export('keras.applications.nasnet.preprocess_input')
+@keras_modules_injection
+def preprocess_input(*args, **kwargs):
+ return nasnet.preprocess_input(*args, **kwargs)
diff --git a/tensorflow/python/keras/applications/resnet50.py b/tensorflow/python/keras/applications/resnet50.py
index 4d804a3c44..80d3f9044f 100644
--- a/tensorflow/python/keras/applications/resnet50.py
+++ b/tensorflow/python/keras/applications/resnet50.py
@@ -20,11 +20,25 @@ from __future__ import division
from __future__ import print_function
from keras_applications import resnet50
+
+from tensorflow.python.keras.applications import keras_modules_injection
from tensorflow.python.util.tf_export import tf_export
-ResNet50 = resnet50.ResNet50
-decode_predictions = resnet50.decode_predictions
-preprocess_input = resnet50.preprocess_input
-tf_export('keras.applications.resnet50.ResNet50',
- 'keras.applications.ResNet50')(ResNet50)
+@tf_export('keras.applications.resnet50.ResNet50',
+ 'keras.applications.ResNet50')
+@keras_modules_injection
+def ResNet50(*args, **kwargs):
+ return resnet50.ResNet50(*args, **kwargs)
+
+
+@tf_export('keras.applications.resnet50.decode_predictions')
+@keras_modules_injection
+def decode_predictions(*args, **kwargs):
+ return resnet50.decode_predictions(*args, **kwargs)
+
+
+@tf_export('keras.applications.resnet50.preprocess_input')
+@keras_modules_injection
+def preprocess_input(*args, **kwargs):
+ return resnet50.preprocess_input(*args, **kwargs)
diff --git a/tensorflow/python/keras/applications/vgg16.py b/tensorflow/python/keras/applications/vgg16.py
index c420d9b81e..8557d26931 100644
--- a/tensorflow/python/keras/applications/vgg16.py
+++ b/tensorflow/python/keras/applications/vgg16.py
@@ -20,11 +20,25 @@ from __future__ import division
from __future__ import print_function
from keras_applications import vgg16
+
+from tensorflow.python.keras.applications import keras_modules_injection
from tensorflow.python.util.tf_export import tf_export
-VGG16 = vgg16.VGG16
-decode_predictions = vgg16.decode_predictions
-preprocess_input = vgg16.preprocess_input
-tf_export('keras.applications.vgg16.VGG16',
- 'keras.applications.VGG16')(VGG16)
+@tf_export('keras.applications.vgg16.VGG16',
+ 'keras.applications.VGG16')
+@keras_modules_injection
+def VGG16(*args, **kwargs):
+ return vgg16.VGG16(*args, **kwargs)
+
+
+@tf_export('keras.applications.vgg16.decode_predictions')
+@keras_modules_injection
+def decode_predictions(*args, **kwargs):
+ return vgg16.decode_predictions(*args, **kwargs)
+
+
+@tf_export('keras.applications.vgg16.preprocess_input')
+@keras_modules_injection
+def preprocess_input(*args, **kwargs):
+ return vgg16.preprocess_input(*args, **kwargs)
diff --git a/tensorflow/python/keras/applications/vgg19.py b/tensorflow/python/keras/applications/vgg19.py
index 73d3d1d1c3..8fc04413a0 100644
--- a/tensorflow/python/keras/applications/vgg19.py
+++ b/tensorflow/python/keras/applications/vgg19.py
@@ -20,11 +20,25 @@ from __future__ import division
from __future__ import print_function
from keras_applications import vgg19
+
+from tensorflow.python.keras.applications import keras_modules_injection
from tensorflow.python.util.tf_export import tf_export
-VGG19 = vgg19.VGG19
-decode_predictions = vgg19.decode_predictions
-preprocess_input = vgg19.preprocess_input
-tf_export('keras.applications.vgg19.VGG19',
- 'keras.applications.VGG19')(VGG19)
+@tf_export('keras.applications.vgg19.VGG19',
+ 'keras.applications.VGG19')
+@keras_modules_injection
+def VGG19(*args, **kwargs):
+ return vgg19.VGG19(*args, **kwargs)
+
+
+@tf_export('keras.applications.vgg19.decode_predictions')
+@keras_modules_injection
+def decode_predictions(*args, **kwargs):
+ return vgg19.decode_predictions(*args, **kwargs)
+
+
+@tf_export('keras.applications.vgg19.preprocess_input')
+@keras_modules_injection
+def preprocess_input(*args, **kwargs):
+ return vgg19.preprocess_input(*args, **kwargs)
diff --git a/tensorflow/python/keras/applications/xception.py b/tensorflow/python/keras/applications/xception.py
index 5b221ac8e0..960e6dec69 100644
--- a/tensorflow/python/keras/applications/xception.py
+++ b/tensorflow/python/keras/applications/xception.py
@@ -20,12 +20,25 @@ from __future__ import division
from __future__ import print_function
from keras_applications import xception
+
+from tensorflow.python.keras.applications import keras_modules_injection
from tensorflow.python.util.tf_export import tf_export
-Xception = xception.Xception
-decode_predictions = xception.decode_predictions
-preprocess_input = xception.preprocess_input
-tf_export('keras.applications.xception.Xception',
- 'keras.applications.Xception')(Xception)
-tf_export('keras.applications.xception.preprocess_input')(preprocess_input)
+@tf_export('keras.applications.xception.Xception',
+ 'keras.applications.Xception')
+@keras_modules_injection
+def Xception(*args, **kwargs):
+ return xception.Xception(*args, **kwargs)
+
+
+@tf_export('keras.applications.xception.decode_predictions')
+@keras_modules_injection
+def decode_predictions(*args, **kwargs):
+ return xception.decode_predictions(*args, **kwargs)
+
+
+@tf_export('keras.applications.xception.preprocess_input')
+@keras_modules_injection
+def preprocess_input(*args, **kwargs):
+ return xception.preprocess_input(*args, **kwargs)
diff --git a/tensorflow/python/keras/preprocessing/__init__.py b/tensorflow/python/keras/preprocessing/__init__.py
index 2f08f88600..0860eed3cf 100644
--- a/tensorflow/python/keras/preprocessing/__init__.py
+++ b/tensorflow/python/keras/preprocessing/__init__.py
@@ -23,6 +23,8 @@ import keras_preprocessing
from tensorflow.python.keras import backend
from tensorflow.python.keras import utils
+# This exists for compatibility with prior versions of keras_preprocessing.
+# TODO(fchollet): remove in the future.
keras_preprocessing.set_keras_submodules(backend=backend, utils=utils)
from tensorflow.python.keras.preprocessing import image
diff --git a/tensorflow/python/keras/preprocessing/image.py b/tensorflow/python/keras/preprocessing/image.py
index ba227385ef..e33993950d 100644
--- a/tensorflow/python/keras/preprocessing/image.py
+++ b/tensorflow/python/keras/preprocessing/image.py
@@ -27,6 +27,9 @@ try:
except ImportError:
pass
+from tensorflow.python.keras import backend
+from tensorflow.python.keras import utils
+from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export
random_rotation = image.random_rotation
@@ -38,14 +41,482 @@ random_channel_shift = image.random_channel_shift
apply_brightness_shift = image.apply_brightness_shift
random_brightness = image.random_brightness
apply_affine_transform = image.apply_affine_transform
-array_to_img = image.array_to_img
-img_to_array = image.img_to_array
-save_img = image.save_img
load_img = image.load_img
-ImageDataGenerator = image.ImageDataGenerator
-Iterator = image.Iterator
-NumpyArrayIterator = image.NumpyArrayIterator
-DirectoryIterator = image.DirectoryIterator
+
+
+@tf_export('keras.preprocessing.image.array_to_img')
+def array_to_img(x, data_format=None, scale=True, dtype=None):
+ """Converts a 3D Numpy array to a PIL Image instance.
+
+ Arguments:
+ x: Input Numpy array.
+ data_format: Image data format.
+ either "channels_first" or "channels_last".
+ scale: Whether to rescale image values
+ to be within `[0, 255]`.
+ dtype: Dtype to use.
+
+ Returns:
+ A PIL Image instance.
+
+ Raises:
+ ImportError: if PIL is not available.
+ ValueError: if invalid `x` or `data_format` is passed.
+ """
+
+ if data_format is None:
+ data_format = backend.image_data_format()
+ kwargs = {}
+ if 'dtype' in tf_inspect.getfullargspec(image.array_to_img)[0]:
+ if dtype is None:
+ dtype = backend.floatx()
+ kwargs['dtype'] = dtype
+ return image.array_to_img(x, data_format=data_format, scale=scale, **kwargs)
+
+
+@tf_export('keras.preprocessing.image.img_to_array')
+def img_to_array(img, data_format=None, dtype=None):
+ """Converts a PIL Image instance to a Numpy array.
+
+ Arguments:
+ img: PIL Image instance.
+ data_format: Image data format,
+ either "channels_first" or "channels_last".
+ dtype: Dtype to use for the returned array.
+
+ Returns:
+ A 3D Numpy array.
+
+ Raises:
+ ValueError: if invalid `img` or `data_format` is passed.
+ """
+
+ if data_format is None:
+ data_format = backend.image_data_format()
+ kwargs = {}
+ if 'dtype' in tf_inspect.getfullargspec(image.img_to_array)[0]:
+ if dtype is None:
+ dtype = backend.floatx()
+ kwargs['dtype'] = dtype
+ return image.img_to_array(img, data_format=data_format, **kwargs)
+
+
+@tf_export('keras.preprocessing.image.save_img')
+def save_img(path,
+ x,
+ data_format=None,
+ file_format=None,
+ scale=True,
+ **kwargs):
+ """Saves an image stored as a Numpy array to a path or file object.
+
+ Arguments:
+ path: Path or file object.
+ x: Numpy array.
+ data_format: Image data format,
+ either "channels_first" or "channels_last".
+ file_format: Optional file format override. If omitted, the
+ format to use is determined from the filename extension.
+ If a file object was used instead of a filename, this
+ parameter should always be used.
+ scale: Whether to rescale image values to be within `[0, 255]`.
+ **kwargs: Additional keyword arguments passed to `PIL.Image.save()`.
+ """
+ if data_format is None:
+ data_format = backend.image_data_format()
+ image.save_img(path,
+ x,
+ data_format=data_format,
+ file_format=file_format,
+ scale=scale, **kwargs)
+
+
+@tf_export('keras.preprocessing.image.Iterator')
+class Iterator(image.Iterator, utils.Sequence):
+ pass
+
+
+@tf_export('keras.preprocessing.image.DirectoryIterator')
+class DirectoryIterator(image.DirectoryIterator, Iterator):
+ """Iterator capable of reading images from a directory on disk.
+
+ Arguments:
+ directory: Path to the directory to read images from.
+ Each subdirectory in this directory will be
+ considered to contain images from one class,
+ or alternatively you could specify class subdirectories
+ via the `classes` argument.
+ image_data_generator: Instance of `ImageDataGenerator`
+ to use for random transformations and normalization.
+ target_size: tuple of integers, dimensions to resize input images to.
+ color_mode: One of `"rgb"`, `"rgba"`, `"grayscale"`.
+ Color mode to read images.
+ classes: Optional list of strings, names of subdirectories
+ containing images from each class (e.g. `["dogs", "cats"]`).
+ It will be computed automatically if not set.
+ class_mode: Mode for yielding the targets:
+ `"binary"`: binary targets (if there are only two classes),
+ `"categorical"`: categorical targets,
+ `"sparse"`: integer targets,
+ `"input"`: targets are images identical to input images (mainly
+ used to work with autoencoders),
+ `None`: no targets get yielded (only input images are yielded).
+ batch_size: Integer, size of a batch.
+ shuffle: Boolean, whether to shuffle the data between epochs.
+ seed: Random seed for data shuffling.
+ data_format: String, one of `channels_first`, `channels_last`.
+ save_to_dir: Optional directory where to save the pictures
+ being yielded, in a viewable format. This is useful
+ for visualizing the random transformations being
+ applied, for debugging purposes.
+ save_prefix: String prefix to use for saving sample
+ images (if `save_to_dir` is set).
+ save_format: Format to use for saving sample images
+ (if `save_to_dir` is set).
+ subset: Subset of data (`"training"` or `"validation"`) if
+ validation_split is set in ImageDataGenerator.
+ interpolation: Interpolation method used to resample the image if the
+ target size is different from that of the loaded image.
+ Supported methods are "nearest", "bilinear", and "bicubic".
+ If PIL version 1.1.3 or newer is installed, "lanczos" is also
+ supported. If PIL version 3.4.0 or newer is installed, "box" and
+ "hamming" are also supported. By default, "nearest" is used.
+ dtype: Dtype to use for generated arrays.
+ """
+
+ def __init__(self, directory, image_data_generator,
+ target_size=(256, 256),
+ color_mode='rgb',
+ classes=None,
+ class_mode='categorical',
+ batch_size=32,
+ shuffle=True,
+ seed=None,
+ data_format=None,
+ save_to_dir=None,
+ save_prefix='',
+ save_format='png',
+ follow_links=False,
+ subset=None,
+ interpolation='nearest',
+ dtype=None):
+ if data_format is None:
+ data_format = backend.image_data_format()
+ kwargs = {}
+ if 'dtype' in tf_inspect.getfullargspec(
+ image.ImageDataGenerator.__init__)[0]:
+ if dtype is None:
+ dtype = backend.floatx()
+ kwargs['dtype'] = dtype
+ super(DirectoryIterator, self).__init__(
+ directory, image_data_generator,
+ target_size=target_size,
+ color_mode=color_mode,
+ classes=classes,
+ class_mode=class_mode,
+ batch_size=batch_size,
+ shuffle=shuffle,
+ seed=seed,
+ data_format=data_format,
+ save_to_dir=save_to_dir,
+ save_prefix=save_prefix,
+ save_format=save_format,
+ follow_links=follow_links,
+ subset=subset,
+ interpolation=interpolation,
+ **kwargs)
+
+
+@tf_export('keras.preprocessing.image.NumpyArrayIterator')
+class NumpyArrayIterator(image.NumpyArrayIterator, Iterator):
+ """Iterator yielding data from a Numpy array.
+
+ Arguments:
+ x: Numpy array of input data or tuple.
+ If tuple, the second element is either
+ another numpy array or a list of numpy arrays,
+ each of which gets passed
+ through as an output without any modifications.
+ y: Numpy array of targets data.
+ image_data_generator: Instance of `ImageDataGenerator`
+ to use for random transformations and normalization.
+ batch_size: Integer, size of a batch.
+ shuffle: Boolean, whether to shuffle the data between epochs.
+ sample_weight: Numpy array of sample weights.
+ seed: Random seed for data shuffling.
+ data_format: String, one of `channels_first`, `channels_last`.
+ save_to_dir: Optional directory where to save the pictures
+ being yielded, in a viewable format. This is useful
+ for visualizing the random transformations being
+ applied, for debugging purposes.
+ save_prefix: String prefix to use for saving sample
+ images (if `save_to_dir` is set).
+ save_format: Format to use for saving sample images
+ (if `save_to_dir` is set).
+ subset: Subset of data (`"training"` or `"validation"`) if
+ validation_split is set in ImageDataGenerator.
+ dtype: Dtype to use for the generated arrays.
+ """
+
+ def __init__(self, x, y, image_data_generator,
+ batch_size=32,
+ shuffle=False,
+ sample_weight=None,
+ seed=None,
+ data_format=None,
+ save_to_dir=None,
+ save_prefix='',
+ save_format='png',
+ subset=None,
+ dtype=None):
+ if data_format is None:
+ data_format = backend.image_data_format()
+ kwargs = {}
+ if 'dtype' in tf_inspect.getfullargspec(
+ image.NumpyArrayIterator.__init__)[0]:
+ if dtype is None:
+ dtype = backend.floatx()
+ kwargs['dtype'] = dtype
+ super(NumpyArrayIterator, self).__init__(
+ x, y, image_data_generator,
+ batch_size=batch_size,
+ shuffle=shuffle,
+ sample_weight=sample_weight,
+ seed=seed,
+ data_format=data_format,
+ save_to_dir=save_to_dir,
+ save_prefix=save_prefix,
+ save_format=save_format,
+ subset=subset,
+ **kwargs)
+
+
+@tf_export('keras.preprocessing.image.ImageDataGenerator')
+class ImageDataGenerator(image.ImageDataGenerator):
+ """Generate batches of tensor image data with real-time data augmentation.
+
+ The data will be looped over (in batches).
+
+ Arguments:
+ featurewise_center: Boolean.
+ Set input mean to 0 over the dataset, feature-wise.
+ samplewise_center: Boolean. Set each sample mean to 0.
+ featurewise_std_normalization: Boolean.
+ Divide inputs by std of the dataset, feature-wise.
+ samplewise_std_normalization: Boolean. Divide each input by its std.
+ zca_epsilon: epsilon for ZCA whitening. Default is 1e-6.
+ zca_whitening: Boolean. Apply ZCA whitening.
+ rotation_range: Int. Degree range for random rotations.
+ width_shift_range: Float, 1-D array-like or int
+ - float: fraction of total width, if < 1, or pixels if >= 1.
+ - 1-D array-like: random elements from the array.
+ - int: integer number of pixels from interval
+ `(-width_shift_range, +width_shift_range)`
+ - With `width_shift_range=2` possible values
+ are integers `[-1, 0, +1]`,
+ same as with `width_shift_range=[-1, 0, +1]`,
+ while with `width_shift_range=1.0` possible values are floats
+ in the interval [-1.0, +1.0).
+ height_shift_range: Float, 1-D array-like or int
+ - float: fraction of total height, if < 1, or pixels if >= 1.
+ - 1-D array-like: random elements from the array.
+ - int: integer number of pixels from interval
+ `(-height_shift_range, +height_shift_range)`
+ - With `height_shift_range=2` possible values
+ are integers `[-1, 0, +1]`,
+ same as with `height_shift_range=[-1, 0, +1]`,
+ while with `height_shift_range=1.0` possible values are floats
+ in the interval [-1.0, +1.0).
+ brightness_range: Tuple or list of two floats. Range for picking
+ a brightness shift value from.
+ shear_range: Float. Shear Intensity
+ (Shear angle in counter-clockwise direction in degrees)
+ zoom_range: Float or [lower, upper]. Range for random zoom.
+ If a float, `[lower, upper] = [1-zoom_range, 1+zoom_range]`.
+ channel_shift_range: Float. Range for random channel shifts.
+ fill_mode: One of {"constant", "nearest", "reflect" or "wrap"}.
+ Default is 'nearest'.
+ Points outside the boundaries of the input are filled
+ according to the given mode:
+ - 'constant': kkkkkkkk|abcd|kkkkkkkk (cval=k)
+ - 'nearest': aaaaaaaa|abcd|dddddddd
+ - 'reflect': abcddcba|abcd|dcbaabcd
+ - 'wrap': abcdabcd|abcd|abcdabcd
+ cval: Float or Int.
+ Value used for points outside the boundaries
+ when `fill_mode = "constant"`.
+ horizontal_flip: Boolean. Randomly flip inputs horizontally.
+ vertical_flip: Boolean. Randomly flip inputs vertically.
+ rescale: rescaling factor. Defaults to None.
+ If None or 0, no rescaling is applied,
+ otherwise we multiply the data by the value provided
+ (after applying all other transformations).
+ preprocessing_function: function that will be applied on each input.
+ The function will run after the image is resized and augmented.
+ The function should take one argument:
+ one image (Numpy tensor with rank 3),
+ and should output a Numpy tensor with the same shape.
+ data_format: Image data format,
+ either "channels_first" or "channels_last".
+ "channels_last" mode means that the images should have shape
+ `(samples, height, width, channels)`,
+ "channels_first" mode means that the images should have shape
+ `(samples, channels, height, width)`.
+ It defaults to the `image_data_format` value found in your
+ Keras config file at `~/.keras/keras.json`.
+ If you never set it, then it will be "channels_last".
+ validation_split: Float. Fraction of images reserved for validation
+ (strictly between 0 and 1).
+ dtype: Dtype to use for the generated arrays.
+
+ Examples:
+
+ Example of using `.flow(x, y)`:
+
+ ```python
+ (x_train, y_train), (x_test, y_test) = cifar10.load_data()
+ y_train = np_utils.to_categorical(y_train, num_classes)
+ y_test = np_utils.to_categorical(y_test, num_classes)
+ datagen = ImageDataGenerator(
+ featurewise_center=True,
+ featurewise_std_normalization=True,
+ rotation_range=20,
+ width_shift_range=0.2,
+ height_shift_range=0.2,
+ horizontal_flip=True)
+ # compute quantities required for featurewise normalization
+ # (std, mean, and principal components if ZCA whitening is applied)
+ datagen.fit(x_train)
+ # fits the model on batches with real-time data augmentation:
+ model.fit_generator(datagen.flow(x_train, y_train, batch_size=32),
+ steps_per_epoch=len(x_train) / 32, epochs=epochs)
+ # here's a more "manual" example
+ for e in range(epochs):
+ print('Epoch', e)
+ batches = 0
+ for x_batch, y_batch in datagen.flow(x_train, y_train, batch_size=32):
+ model.fit(x_batch, y_batch)
+ batches += 1
+ if batches >= len(x_train) / 32:
+ # we need to break the loop by hand because
+ # the generator loops indefinitely
+ break
+ ```
+
+ Example of using `.flow_from_directory(directory)`:
+
+ ```python
+ train_datagen = ImageDataGenerator(
+ rescale=1./255,
+ shear_range=0.2,
+ zoom_range=0.2,
+ horizontal_flip=True)
+ test_datagen = ImageDataGenerator(rescale=1./255)
+ train_generator = train_datagen.flow_from_directory(
+ 'data/train',
+ target_size=(150, 150),
+ batch_size=32,
+ class_mode='binary')
+ validation_generator = test_datagen.flow_from_directory(
+ 'data/validation',
+ target_size=(150, 150),
+ batch_size=32,
+ class_mode='binary')
+ model.fit_generator(
+ train_generator,
+ steps_per_epoch=2000,
+ epochs=50,
+ validation_data=validation_generator,
+ validation_steps=800)
+ ```
+
+ Example of transforming images and masks together.
+
+ ```python
+ # we create two instances with the same arguments
+ data_gen_args = dict(featurewise_center=True,
+ featurewise_std_normalization=True,
+ rotation_range=90,
+ width_shift_range=0.1,
+ height_shift_range=0.1,
+ zoom_range=0.2)
+ image_datagen = ImageDataGenerator(**data_gen_args)
+ mask_datagen = ImageDataGenerator(**data_gen_args)
+ # Provide the same seed and keyword arguments to the fit and flow methods
+ seed = 1
+ image_datagen.fit(images, augment=True, seed=seed)
+ mask_datagen.fit(masks, augment=True, seed=seed)
+ image_generator = image_datagen.flow_from_directory(
+ 'data/images',
+ class_mode=None,
+ seed=seed)
+ mask_generator = mask_datagen.flow_from_directory(
+ 'data/masks',
+ class_mode=None,
+ seed=seed)
+ # combine generators into one which yields image and masks
+ train_generator = zip(image_generator, mask_generator)
+ model.fit_generator(
+ train_generator,
+ steps_per_epoch=2000,
+ epochs=50)
+ ```
+ """
+
+ def __init__(self,
+ featurewise_center=False,
+ samplewise_center=False,
+ featurewise_std_normalization=False,
+ samplewise_std_normalization=False,
+ zca_whitening=False,
+ zca_epsilon=1e-6,
+ rotation_range=0,
+ width_shift_range=0.,
+ height_shift_range=0.,
+ brightness_range=None,
+ shear_range=0.,
+ zoom_range=0.,
+ channel_shift_range=0.,
+ fill_mode='nearest',
+ cval=0.,
+ horizontal_flip=False,
+ vertical_flip=False,
+ rescale=None,
+ preprocessing_function=None,
+ data_format=None,
+ validation_split=0.0,
+ dtype=None):
+ if data_format is None:
+ data_format = backend.image_data_format()
+ kwargs = {}
+ if 'dtype' in tf_inspect.getfullargspec(
+ image.ImageDataGenerator.__init__)[0]:
+ if dtype is None:
+ dtype = backend.floatx()
+ kwargs['dtype'] = dtype
+ super(ImageDataGenerator, self).__init__(
+ featurewise_center=featurewise_center,
+ samplewise_center=samplewise_center,
+ featurewise_std_normalization=featurewise_std_normalization,
+ samplewise_std_normalization=samplewise_std_normalization,
+ zca_whitening=zca_whitening,
+ zca_epsilon=zca_epsilon,
+ rotation_range=rotation_range,
+ width_shift_range=width_shift_range,
+ height_shift_range=height_shift_range,
+ brightness_range=brightness_range,
+ shear_range=shear_range,
+ zoom_range=zoom_range,
+ channel_shift_range=channel_shift_range,
+ fill_mode=fill_mode,
+ cval=cval,
+ horizontal_flip=horizontal_flip,
+ vertical_flip=vertical_flip,
+ rescale=rescale,
+ preprocessing_function=preprocessing_function,
+ data_format=data_format,
+ validation_split=validation_split,
+ **kwargs)
tf_export('keras.preprocessing.image.random_rotation')(random_rotation)
tf_export('keras.preprocessing.image.random_shift')(random_shift)
@@ -59,11 +530,4 @@ tf_export(
tf_export('keras.preprocessing.image.random_brightness')(random_brightness)
tf_export(
'keras.preprocessing.image.apply_affine_transform')(apply_affine_transform)
-tf_export('keras.preprocessing.image.array_to_img')(array_to_img)
-tf_export('keras.preprocessing.image.img_to_array')(img_to_array)
-tf_export('keras.preprocessing.image.save_img')(save_img)
tf_export('keras.preprocessing.image.load_img')(load_img)
-tf_export('keras.preprocessing.image.ImageDataGenerator')(ImageDataGenerator)
-tf_export('keras.preprocessing.image.Iterator')(Iterator)
-tf_export('keras.preprocessing.image.NumpyArrayIterator')(NumpyArrayIterator)
-tf_export('keras.preprocessing.image.DirectoryIterator')(DirectoryIterator)
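The wrappers above (`array_to_img`, `img_to_array`, the iterator classes, and `ImageDataGenerator`) all share one compatibility idiom: forward `dtype` only if the installed `keras_preprocessing` accepts it, since that argument is newer than 1.0.2. A minimal sketch of the guard outside TensorFlow, using `inspect` in place of `tf_inspect`; the `to_array` helper name is hypothetical:

```python
import inspect

from keras_preprocessing import image


def to_array(img, data_format='channels_last', dtype='float32'):
  """Hypothetical helper: pass `dtype` only when the backing function takes it."""
  kwargs = {}
  if 'dtype' in inspect.getfullargspec(image.img_to_array).args:
    kwargs['dtype'] = dtype
  return image.img_to_array(img, data_format=data_format, **kwargs)
```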
diff --git a/tensorflow/python/keras/preprocessing/sequence.py b/tensorflow/python/keras/preprocessing/sequence.py
index 116d3108d9..f014668909 100644
--- a/tensorflow/python/keras/preprocessing/sequence.py
+++ b/tensorflow/python/keras/preprocessing/sequence.py
@@ -21,6 +21,7 @@ from __future__ import print_function
from keras_preprocessing import sequence
+from tensorflow.python.keras import utils
from tensorflow.python.util.tf_export import tf_export
pad_sequences = sequence.pad_sequences
@@ -28,11 +29,67 @@ make_sampling_table = sequence.make_sampling_table
skipgrams = sequence.skipgrams
# TODO(fchollet): consider making `_remove_long_seq` public.
_remove_long_seq = sequence._remove_long_seq # pylint: disable=protected-access
-TimeseriesGenerator = sequence.TimeseriesGenerator
+
+
+@tf_export('keras.preprocessing.sequence.TimeseriesGenerator')
+class TimeseriesGenerator(sequence.TimeseriesGenerator, utils.Sequence):
+ """Utility class for generating batches of temporal data.
+ This class takes in a sequence of data-points gathered at
+ equal intervals, along with time series parameters such as
+ stride, length of history, etc., to produce batches for
+ training/validation.
+ # Arguments
+ data: Indexable generator (such as list or Numpy array)
+ containing consecutive data points (timesteps).
+ The data should be 2D, and axis 0 is expected
+ to be the time dimension.
+ targets: Targets corresponding to timesteps in `data`.
+ It should have same length as `data`.
+ length: Length of the output sequences (in number of timesteps).
+ sampling_rate: Period between successive individual timesteps
+ within sequences. For rate `r`, timesteps
+ `data[i]`, `data[i-r]`, ... `data[i - length]`
+ are used to create a sample sequence.
+ stride: Period between successive output sequences.
+ For stride `s`, consecutive output samples would
+ be centered around `data[i]`, `data[i+s]`, `data[i+2*s]`, etc.
+ start_index: Data points earlier than `start_index` will not be used
+ in the output sequences. This is useful to reserve part of the
+ data for test or validation.
+ end_index: Data points later than `end_index` will not be used
+ in the output sequences. This is useful to reserve part of the
+ data for test or validation.
+ shuffle: Whether to shuffle output samples,
+ or instead draw them in chronological order.
+ reverse: Boolean: if `true`, timesteps in each output sample will be
+ in reverse chronological order.
+ batch_size: Number of timeseries samples in each batch
+ (except maybe the last one).
+ # Returns
+ A [Sequence](/utils/#sequence) instance.
+ # Examples
+ ```python
+ from keras.preprocessing.sequence import TimeseriesGenerator
+ import numpy as np
+ data = np.array([[i] for i in range(50)])
+ targets = np.array([[i] for i in range(50)])
+ data_gen = TimeseriesGenerator(data, targets,
+ length=10, sampling_rate=2,
+ batch_size=2)
+ assert len(data_gen) == 20
+ batch_0 = data_gen[0]
+ x, y = batch_0
+ assert np.array_equal(x,
+ np.array([[[0], [2], [4], [6], [8]],
+ [[1], [3], [5], [7], [9]]]))
+ assert np.array_equal(y,
+ np.array([[10], [11]]))
+ ```
+ """
+ pass
+
tf_export('keras.preprocessing.sequence.pad_sequences')(pad_sequences)
tf_export(
'keras.preprocessing.sequence.make_sampling_table')(make_sampling_table)
tf_export('keras.preprocessing.sequence.skipgrams')(skipgrams)
-tf_export(
- 'keras.preprocessing.sequence.TimeseriesGenerator')(TimeseriesGenerator)
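Because the tf.keras `TimeseriesGenerator` defined above also mixes in `utils.Sequence`, instances behave like any other Keras Sequence. A short usage sketch based on the docstring example (assumes NumPy and a TensorFlow build containing this change):

```python
import numpy as np

from tensorflow.python.keras.preprocessing.sequence import TimeseriesGenerator

data = np.array([[i] for i in range(50)])
targets = np.array([[i] for i in range(50)])
gen = TimeseriesGenerator(data, targets, length=10, sampling_rate=2,
                          batch_size=2)

x, y = gen[0]
assert x.shape == (2, 5, 1)
assert np.array_equal(y, [[10], [11]])
# As a utils.Sequence, `gen` can be passed straight to Model.fit_generator.
```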
diff --git a/tensorflow/python/tools/api/generator/api_init_files.bzl b/tensorflow/python/tools/api/generator/api_init_files.bzl
index 7001e566ce..64f0469482 100644
--- a/tensorflow/python/tools/api/generator/api_init_files.bzl
+++ b/tensorflow/python/tools/api/generator/api_init_files.bzl
@@ -25,6 +25,7 @@ TENSORFLOW_API_INIT_FILES = [
"keras/applications/inception_resnet_v2/__init__.py",
"keras/applications/inception_v3/__init__.py",
"keras/applications/mobilenet/__init__.py",
+ "keras/applications/mobilenet_v2/__init__.py",
"keras/applications/nasnet/__init__.py",
"keras/applications/resnet50/__init__.py",
"keras/applications/vgg16/__init__.py",
diff --git a/tensorflow/python/tools/api/generator/api_init_files_v1.bzl b/tensorflow/python/tools/api/generator/api_init_files_v1.bzl
index 73d11199d9..bc2f3516d1 100644
--- a/tensorflow/python/tools/api/generator/api_init_files_v1.bzl
+++ b/tensorflow/python/tools/api/generator/api_init_files_v1.bzl
@@ -25,6 +25,7 @@ TENSORFLOW_API_INIT_FILES_V1 = [
"keras/applications/inception_resnet_v2/__init__.py",
"keras/applications/inception_v3/__init__.py",
"keras/applications/mobilenet/__init__.py",
+ "keras/applications/mobilenet_v2/__init__.py",
"keras/applications/nasnet/__init__.py",
"keras/applications/resnet50/__init__.py",
"keras/applications/vgg16/__init__.py",
diff --git a/tensorflow/tools/ci_build/Dockerfile.cmake b/tensorflow/tools/ci_build/Dockerfile.cmake
index 4587bcf891..b7450c83de 100644
--- a/tensorflow/tools/ci_build/Dockerfile.cmake
+++ b/tensorflow/tools/ci_build/Dockerfile.cmake
@@ -28,8 +28,8 @@ RUN pip install --upgrade astor
RUN pip install --upgrade gast
RUN pip install --upgrade numpy
RUN pip install --upgrade termcolor
-RUN pip install keras_applications==1.0.4
-RUN pip install keras_preprocessing==1.0.2
+RUN pip install keras_applications==1.0.5
+RUN pip install keras_preprocessing==1.0.3
# Install golang
RUN apt-get install -t xenial-backports -y golang-1.9
diff --git a/tensorflow/tools/ci_build/install/install_pip_packages.sh b/tensorflow/tools/ci_build/install/install_pip_packages.sh
index bb316ecfc9..af478eded4 100755
--- a/tensorflow/tools/ci_build/install/install_pip_packages.sh
+++ b/tensorflow/tools/ci_build/install/install_pip_packages.sh
@@ -115,10 +115,10 @@ pip2 install --upgrade setuptools==39.1.0
pip3 install --upgrade setuptools==39.1.0
# Keras
-pip2 install keras_applications==1.0.4 --no-deps
-pip3 install keras_applications==1.0.4 --no-deps
-pip2 install keras_preprocessing==1.0.2 --no-deps
-pip3 install keras_preprocessing==1.0.2 --no-deps
+pip2 install keras_applications==1.0.5 --no-deps
+pip3 install keras_applications==1.0.5 --no-deps
+pip2 install keras_preprocessing==1.0.3 --no-deps
+pip3 install keras_preprocessing==1.0.3 --no-deps
# Install last working version of setuptools.
pip2 install --upgrade setuptools==39.1.0
diff --git a/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh b/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh
index 15e4396ce3..93ea0c3db6 100755
--- a/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh
+++ b/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh
@@ -85,8 +85,8 @@ pip3.5 install --upgrade termcolor
pip3.5 install --upgrade setuptools==39.1.0
# Keras
-pip3.5 install keras_applications==1.0.4
-pip3.5 install keras_preprocessing==1.0.2
+pip3.5 install keras_applications==1.0.5
+pip3.5 install keras_preprocessing==1.0.3
# Install last working version of setuptools.
pip3.5 install --upgrade setuptools==39.1.0
diff --git a/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh b/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh
index 0fc3eee71c..7a9eef7c64 100755
--- a/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh
+++ b/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh
@@ -101,7 +101,7 @@ pip3 install --upgrade termcolor
pip3 install --upgrade setuptools==39.1.0
# Keras
-pip3 install keras_applications==1.0.4
-pip3 install keras_preprocessing==1.0.2
+pip3 install keras_applications==1.0.5
+pip3 install keras_preprocessing==1.0.3
# LINT.ThenChange(//tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh)
diff --git a/tensorflow/tools/docker/Dockerfile b/tensorflow/tools/docker/Dockerfile
index 2c31d784e5..0114ef9dbf 100644
--- a/tensorflow/tools/docker/Dockerfile
+++ b/tensorflow/tools/docker/Dockerfile
@@ -29,8 +29,8 @@ RUN pip --no-cache-dir install \
h5py \
ipykernel \
jupyter \
- keras_applications==1.0.4 \
- keras_preprocessing==1.0.2 \
+ keras_applications==1.0.5 \
+ keras_preprocessing==1.0.3 \
matplotlib \
numpy==1.14.5 \
pandas \
diff --git a/tensorflow/tools/docker/Dockerfile.devel b/tensorflow/tools/docker/Dockerfile.devel
index bacdea72ce..aec5ca965e 100644
--- a/tensorflow/tools/docker/Dockerfile.devel
+++ b/tensorflow/tools/docker/Dockerfile.devel
@@ -33,8 +33,8 @@ RUN pip --no-cache-dir install \
h5py \
ipykernel \
jupyter \
- keras_applications==1.0.4 \
- keras_preprocessing==1.0.2 \
+ keras_applications==1.0.5 \
+ keras_preprocessing==1.0.3 \
matplotlib \
mock \
numpy==1.14.5 \
diff --git a/tensorflow/tools/docker/Dockerfile.devel-gpu b/tensorflow/tools/docker/Dockerfile.devel-gpu
index 4f89e3f701..ba421d9978 100644
--- a/tensorflow/tools/docker/Dockerfile.devel-gpu
+++ b/tensorflow/tools/docker/Dockerfile.devel-gpu
@@ -49,8 +49,8 @@ RUN pip --no-cache-dir install \
h5py \
ipykernel \
jupyter \
- keras_applications==1.0.4 \
- keras_preprocessing==1.0.2 \
+ keras_applications==1.0.5 \
+ keras_preprocessing==1.0.3 \
matplotlib \
mock \
numpy==1.14.5 \
diff --git a/tensorflow/tools/docker/Dockerfile.devel-gpu-cuda9-cudnn7 b/tensorflow/tools/docker/Dockerfile.devel-gpu-cuda9-cudnn7
index 056b4755f4..eb139ec5f8 100644
--- a/tensorflow/tools/docker/Dockerfile.devel-gpu-cuda9-cudnn7
+++ b/tensorflow/tools/docker/Dockerfile.devel-gpu-cuda9-cudnn7
@@ -37,8 +37,8 @@ RUN pip --no-cache-dir install --upgrade \
RUN pip --no-cache-dir install \
ipykernel \
jupyter \
- keras_applications==1.0.4 \
- keras_preprocessing==1.0.2 \
+ keras_applications==1.0.5 \
+ keras_preprocessing==1.0.3 \
matplotlib \
numpy \
scipy \
diff --git a/tensorflow/tools/docker/Dockerfile.devel-mkl b/tensorflow/tools/docker/Dockerfile.devel-mkl
index 2df770e525..371451d2aa 100755
--- a/tensorflow/tools/docker/Dockerfile.devel-mkl
+++ b/tensorflow/tools/docker/Dockerfile.devel-mkl
@@ -52,8 +52,8 @@ RUN ${PIP} --no-cache-dir install \
h5py \
ipykernel \
jupyter \
- keras_applications==1.0.4 \
- keras_preprocessing==1.0.2 \
+ keras_applications==1.0.5 \
+ keras_preprocessing==1.0.3 \
matplotlib \
mock \
numpy \
diff --git a/tensorflow/tools/docker/Dockerfile.devel-mkl-horovod b/tensorflow/tools/docker/Dockerfile.devel-mkl-horovod
index ab2eec1728..987b582d10 100755
--- a/tensorflow/tools/docker/Dockerfile.devel-mkl-horovod
+++ b/tensorflow/tools/docker/Dockerfile.devel-mkl-horovod
@@ -45,8 +45,8 @@ RUN ${PIP} --no-cache-dir install \
h5py \
ipykernel \
jupyter \
- keras_applications==1.0.4 \
- keras_preprocessing==1.0.2 \
+ keras_applications==1.0.5 \
+ keras_preprocessing==1.0.3 \
matplotlib \
mock \
numpy \
diff --git a/tensorflow/tools/docker/Dockerfile.gpu b/tensorflow/tools/docker/Dockerfile.gpu
index aa0e0face1..806b8836c7 100644
--- a/tensorflow/tools/docker/Dockerfile.gpu
+++ b/tensorflow/tools/docker/Dockerfile.gpu
@@ -37,8 +37,8 @@ RUN pip --no-cache-dir install \
h5py \
ipykernel \
jupyter \
- keras_applications==1.0.4 \
- keras_preprocessing==1.0.2 \
+ keras_applications==1.0.5 \
+ keras_preprocessing==1.0.3 \
matplotlib \
numpy==1.14.5 \
pandas \
diff --git a/tensorflow/tools/docker/Dockerfile.mkl b/tensorflow/tools/docker/Dockerfile.mkl
index 69553302d8..641c9e3b16 100755
--- a/tensorflow/tools/docker/Dockerfile.mkl
+++ b/tensorflow/tools/docker/Dockerfile.mkl
@@ -38,8 +38,8 @@ RUN ${PIP} --no-cache-dir install \
h5py \
ipykernel \
jupyter \
- keras_applications==1.0.4 \
- keras_preprocessing==1.0.2 \
+ keras_applications==1.0.5 \
+ keras_preprocessing==1.0.3 \
matplotlib \
numpy \
pandas \
diff --git a/tensorflow/tools/docker/Dockerfile.mkl-horovod b/tensorflow/tools/docker/Dockerfile.mkl-horovod
index 756716ee0e..2b11679f54 100755
--- a/tensorflow/tools/docker/Dockerfile.mkl-horovod
+++ b/tensorflow/tools/docker/Dockerfile.mkl-horovod
@@ -38,8 +38,8 @@ RUN ${PIP} --no-cache-dir install \
h5py \
ipykernel \
jupyter \
- keras_applications==1.0.4 \
- keras_preprocessing==1.0.2 \
+ keras_applications==1.0.5 \
+ keras_preprocessing==1.0.3 \
matplotlib \
numpy \
pandas \
diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py
index 5e179079c5..8cefbef82d 100644
--- a/tensorflow/tools/pip_package/setup.py
+++ b/tensorflow/tools/pip_package/setup.py
@@ -51,8 +51,8 @@ REQUIRED_PACKAGES = [
'absl-py >= 0.1.6',
'astor >= 0.6.0',
'gast >= 0.2.0',
- 'keras_applications == 1.0.4',
- 'keras_preprocessing == 1.0.2',
+ 'keras_applications >= 1.0.5',
+ 'keras_preprocessing >= 1.0.3',
'numpy >= 1.13.3, <= 1.14.5',
'six >= 1.10.0',
'protobuf >= 3.6.0',