author     Pavithra Vijay <psv@google.com>  2018-05-17 21:36:39 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-05-17 21:40:10 -0700
commit     609b2ce3fe8ebecf4031670b8c2186468369b0ba (patch)
tree       59d5eb7308ffc67a4602f9b028cdd45450f56777 /tensorflow/python/keras/utils
parent     aca0458707fa63626c78acfeae2ade9ee78c54d1 (diff)
Move Keras code out of _impl folder and remove API files.
PiperOrigin-RevId: 197097430
Diffstat (limited to 'tensorflow/python/keras/utils')
-rw-r--r--  tensorflow/python/keras/utils/__init__.py              32
-rw-r--r--  tensorflow/python/keras/utils/conv_utils.py            201
-rw-r--r--  tensorflow/python/keras/utils/data_utils.py            824
-rw-r--r--  tensorflow/python/keras/utils/data_utils_test.py       311
-rw-r--r--  tensorflow/python/keras/utils/generic_utils.py         561
-rw-r--r--  tensorflow/python/keras/utils/generic_utils_test.py     75
-rw-r--r--  tensorflow/python/keras/utils/io_utils.py              171
-rw-r--r--  tensorflow/python/keras/utils/io_utils_test.py         100
-rw-r--r--  tensorflow/python/keras/utils/layer_utils.py           266
-rw-r--r--  tensorflow/python/keras/utils/multi_gpu_utils.py       252
-rw-r--r--  tensorflow/python/keras/utils/multi_gpu_utils_test.py  185
-rw-r--r--  tensorflow/python/keras/utils/np_utils.py               67
-rw-r--r--  tensorflow/python/keras/utils/np_utils_test.py          53
-rw-r--r--  tensorflow/python/keras/utils/tf_utils.py              154
-rw-r--r--  tensorflow/python/keras/utils/vis_utils.py             155
15 files changed, 3391 insertions, 16 deletions
diff --git a/tensorflow/python/keras/utils/__init__.py b/tensorflow/python/keras/utils/__init__.py
index 2f74cf031d..7b5eecc153 100644
--- a/tensorflow/python/keras/utils/__init__.py
+++ b/tensorflow/python/keras/utils/__init__.py
@@ -18,22 +18,22 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.python.keras._impl.keras.utils.data_utils import GeneratorEnqueuer
-from tensorflow.python.keras._impl.keras.utils.data_utils import get_file
-from tensorflow.python.keras._impl.keras.utils.data_utils import Sequence
-from tensorflow.python.keras._impl.keras.utils.data_utils import SequenceEnqueuer
-from tensorflow.python.keras._impl.keras.utils.generic_utils import custom_object_scope
-from tensorflow.python.keras._impl.keras.utils.generic_utils import CustomObjectScope
-from tensorflow.python.keras._impl.keras.utils.generic_utils import deserialize_keras_object
-from tensorflow.python.keras._impl.keras.utils.generic_utils import get_custom_objects
-from tensorflow.python.keras._impl.keras.utils.generic_utils import Progbar
-from tensorflow.python.keras._impl.keras.utils.generic_utils import serialize_keras_object
-from tensorflow.python.keras._impl.keras.utils.io_utils import HDF5Matrix
-from tensorflow.python.keras._impl.keras.utils.layer_utils import convert_all_kernels_in_model
-from tensorflow.python.keras._impl.keras.utils.multi_gpu_utils import multi_gpu_model
-from tensorflow.python.keras._impl.keras.utils.np_utils import normalize
-from tensorflow.python.keras._impl.keras.utils.np_utils import to_categorical
-from tensorflow.python.keras._impl.keras.utils.vis_utils import plot_model
+from tensorflow.python.keras.utils.data_utils import GeneratorEnqueuer
+from tensorflow.python.keras.utils.data_utils import get_file
+from tensorflow.python.keras.utils.data_utils import Sequence
+from tensorflow.python.keras.utils.data_utils import SequenceEnqueuer
+from tensorflow.python.keras.utils.generic_utils import custom_object_scope
+from tensorflow.python.keras.utils.generic_utils import CustomObjectScope
+from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
+from tensorflow.python.keras.utils.generic_utils import get_custom_objects
+from tensorflow.python.keras.utils.generic_utils import Progbar
+from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
+from tensorflow.python.keras.utils.io_utils import HDF5Matrix
+from tensorflow.python.keras.utils.layer_utils import convert_all_kernels_in_model
+from tensorflow.python.keras.utils.multi_gpu_utils import multi_gpu_model
+from tensorflow.python.keras.utils.np_utils import normalize
+from tensorflow.python.keras.utils.np_utils import to_categorical
+from tensorflow.python.keras.utils.vis_utils import plot_model
del absolute_import
del division
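With this change, the re-exports above come straight from `tensorflow.python.keras.utils` rather than the removed `_impl` package. A minimal usage sketch of the new import path (illustrative only; the sample array below is made up):

```python
import numpy as np

# Names re-exported by the __init__.py shown above.
from tensorflow.python.keras.utils import normalize
from tensorflow.python.keras.utils import to_categorical

labels = np.array([0, 2, 1])
one_hot = to_categorical(labels, num_classes=3)   # shape (3, 3) one-hot matrix
unit_rows = normalize(one_hot, axis=-1, order=2)  # L2-normalize each row
```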
diff --git a/tensorflow/python/keras/utils/conv_utils.py b/tensorflow/python/keras/utils/conv_utils.py
new file mode 100644
index 0000000000..5419e7ae05
--- /dev/null
+++ b/tensorflow/python/keras/utils/conv_utils.py
@@ -0,0 +1,201 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities used by convolution layers.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+from six.moves import range # pylint: disable=redefined-builtin
+
+from tensorflow.python.keras import backend
+
+
+def convert_data_format(data_format, ndim):
+ if data_format == 'channels_last':
+ if ndim == 3:
+ return 'NWC'
+ elif ndim == 4:
+ return 'NHWC'
+ elif ndim == 5:
+ return 'NDHWC'
+ else:
+ raise ValueError('Input rank not supported:', ndim)
+ elif data_format == 'channels_first':
+ if ndim == 3:
+ return 'NCW'
+ elif ndim == 4:
+ return 'NCHW'
+ elif ndim == 5:
+ return 'NCDHW'
+ else:
+ raise ValueError('Input rank not supported:', ndim)
+ else:
+ raise ValueError('Invalid data_format:', data_format)
+
+
+def normalize_tuple(value, n, name):
+ """Transforms a single integer or iterable of integers into an integer tuple.
+
+ Arguments:
+    value: The value to validate and convert. Could be an int, or any
+      iterable of ints.
+ n: The size of the tuple to be returned.
+ name: The name of the argument being validated, e.g. "strides" or
+ "kernel_size". This is only used to format error messages.
+
+ Returns:
+ A tuple of n integers.
+
+ Raises:
+    ValueError: If something other than an int/long or iterable thereof was
+      passed.
+ """
+ if isinstance(value, int):
+ return (value,) * n
+ else:
+ try:
+ value_tuple = tuple(value)
+ except TypeError:
+ raise ValueError('The `' + name + '` argument must be a tuple of ' +
+ str(n) + ' integers. Received: ' + str(value))
+ if len(value_tuple) != n:
+ raise ValueError('The `' + name + '` argument must be a tuple of ' +
+ str(n) + ' integers. Received: ' + str(value))
+ for single_value in value_tuple:
+ try:
+ int(single_value)
+ except (ValueError, TypeError):
+ raise ValueError('The `' + name + '` argument must be a tuple of ' +
+ str(n) + ' integers. Received: ' + str(value) + ' '
+ 'including element ' + str(single_value) + ' of type' +
+ ' ' + str(type(single_value)))
+ return value_tuple
+
+
+def conv_output_length(input_length, filter_size, padding, stride, dilation=1):
+ """Determines output length of a convolution given input length.
+
+ Arguments:
+ input_length: integer.
+ filter_size: integer.
+ padding: one of "same", "valid", "full".
+ stride: integer.
+ dilation: dilation rate, integer.
+
+ Returns:
+ The output length (integer).
+ """
+ if input_length is None:
+ return None
+ assert padding in {'same', 'valid', 'full'}
+ dilated_filter_size = filter_size + (filter_size - 1) * (dilation - 1)
+ if padding == 'same':
+ output_length = input_length
+ elif padding == 'valid':
+ output_length = input_length - dilated_filter_size + 1
+ elif padding == 'full':
+ output_length = input_length + dilated_filter_size - 1
+ return (output_length + stride - 1) // stride
+
+
+def conv_input_length(output_length, filter_size, padding, stride):
+ """Determines input length of a convolution given output length.
+
+ Arguments:
+ output_length: integer.
+ filter_size: integer.
+ padding: one of "same", "valid", "full".
+ stride: integer.
+
+ Returns:
+ The input length (integer).
+ """
+ if output_length is None:
+ return None
+ assert padding in {'same', 'valid', 'full'}
+ if padding == 'same':
+ pad = filter_size // 2
+ elif padding == 'valid':
+ pad = 0
+ elif padding == 'full':
+ pad = filter_size - 1
+ return (output_length - 1) * stride - 2 * pad + filter_size
+
+
+def deconv_output_length(input_length, filter_size, padding, stride):
+ """Determines output length of a transposed convolution given input length.
+
+ Arguments:
+ input_length: integer.
+ filter_size: integer.
+ padding: one of "same", "valid", "full".
+ stride: integer.
+
+ Returns:
+ The output length (integer).
+ """
+ if input_length is None:
+ return None
+ input_length *= stride
+ if padding == 'valid':
+ input_length += max(filter_size - stride, 0)
+ elif padding == 'full':
+ input_length -= (stride + filter_size - 2)
+ return input_length
+
+
+def normalize_data_format(value):
+ if value is None:
+ value = backend.image_data_format()
+ data_format = value.lower()
+ if data_format not in {'channels_first', 'channels_last'}:
+ raise ValueError('The `data_format` argument must be one of '
+ '"channels_first", "channels_last". Received: ' +
+ str(value))
+ return data_format
+
+
+def normalize_padding(value):
+ padding = value.lower()
+ if padding not in {'valid', 'same', 'causal'}:
+ raise ValueError('The `padding` argument must be one of '
+                     '"valid", "same" (or "causal", only for `Conv1D`). '
+ 'Received: ' + str(padding))
+ return padding
+
+
+def convert_kernel(kernel):
+ """Converts a Numpy kernel matrix from Theano format to TensorFlow format.
+
+ Also works reciprocally, since the transformation is its own inverse.
+
+ Arguments:
+ kernel: Numpy array (3D, 4D or 5D).
+
+ Returns:
+ The converted kernel.
+
+ Raises:
+ ValueError: in case of invalid kernel shape or invalid data_format.
+ """
+ kernel = np.asarray(kernel)
+ if not 3 <= kernel.ndim <= 5:
+ raise ValueError('Invalid kernel shape:', kernel.shape)
+ slices = [slice(None, None, -1) for _ in range(kernel.ndim)]
+ no_flip = (slice(None, None), slice(None, None))
+ slices[-2:] = no_flip
+ return np.copy(kernel[slices])
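A quick, hand-checked sketch (not part of the diff) of the shape arithmetic implemented above, assuming the module is importable at the path added by this change:

```python
from tensorflow.python.keras.utils import conv_utils

# 'valid' drops (dilated_filter_size - 1) positions; 'same' preserves length.
assert conv_utils.conv_output_length(10, 3, 'valid', stride=1) == 8
assert conv_utils.conv_output_length(10, 3, 'same', stride=1) == 10
assert conv_utils.conv_output_length(10, 3, 'valid', stride=2) == 4

# A scalar is broadcast to an n-tuple; iterables are validated element-wise.
assert conv_utils.normalize_tuple(3, 2, 'kernel_size') == (3, 3)

# Rank-4 'channels_last' maps to the TensorFlow 'NHWC' layout string.
assert conv_utils.convert_data_format('channels_last', 4) == 'NHWC'
```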
diff --git a/tensorflow/python/keras/utils/data_utils.py b/tensorflow/python/keras/utils/data_utils.py
new file mode 100644
index 0000000000..a1f89d9d43
--- /dev/null
+++ b/tensorflow/python/keras/utils/data_utils.py
@@ -0,0 +1,824 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# pylint: disable=g-import-not-at-top
+"""Utilities for file download and caching."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from abc import abstractmethod
+from contextlib import closing
+import hashlib
+import multiprocessing
+from multiprocessing.pool import ThreadPool
+import os
+import random
+import shutil
+import sys
+import tarfile
+import threading
+import time
+import traceback
+import zipfile
+
+import numpy as np
+import six
+from six.moves.urllib.error import HTTPError
+from six.moves.urllib.error import URLError
+from six.moves.urllib.request import urlopen
+
+from tensorflow.python.keras.utils.generic_utils import Progbar
+from tensorflow.python.util.tf_export import tf_export
+
+
+try:
+ import queue
+except ImportError:
+ import Queue as queue
+
+
+if sys.version_info[0] == 2:
+
+ def urlretrieve(url, filename, reporthook=None, data=None):
+    """Replacement for `urlretrieve` for Python 2.
+
+    Under Python 2, `urlretrieve` relies on `FancyURLopener` from the legacy
+    `urllib` module, which is known to have issues with proxy management.
+
+ Arguments:
+ url: url to retrieve.
+ filename: where to store the retrieved data locally.
+ reporthook: a hook function that will be called once
+ on establishment of the network connection and once
+ after each block read thereafter.
+ The hook will be passed three arguments;
+ a count of blocks transferred so far,
+ a block size in bytes, and the total size of the file.
+ data: `data` argument passed to `urlopen`.
+ """
+
+ def chunk_read(response, chunk_size=8192, reporthook=None):
+ content_type = response.info().get('Content-Length')
+ total_size = -1
+ if content_type is not None:
+ total_size = int(content_type.strip())
+ count = 0
+ while True:
+ chunk = response.read(chunk_size)
+ count += 1
+ if reporthook is not None:
+ reporthook(count, chunk_size, total_size)
+ if chunk:
+ yield chunk
+ else:
+ break
+
+ response = urlopen(url, data)
+ with open(filename, 'wb') as fd:
+ for chunk in chunk_read(response, reporthook=reporthook):
+ fd.write(chunk)
+else:
+ from six.moves.urllib.request import urlretrieve
+
+
+def _extract_archive(file_path, path='.', archive_format='auto'):
+ """Extracts an archive if it matches tar, tar.gz, tar.bz, or zip formats.
+
+ Arguments:
+ file_path: path to the archive file
+ path: path to extract the archive file
+ archive_format: Archive format to try for extracting the file.
+ Options are 'auto', 'tar', 'zip', and None.
+ 'tar' includes tar, tar.gz, and tar.bz files.
+ The default 'auto' is ['tar', 'zip'].
+ None or an empty list will return no matches found.
+
+ Returns:
+ True if a match was found and an archive extraction was completed,
+ False otherwise.
+ """
+ if archive_format is None:
+ return False
+  if archive_format == 'auto':
+ archive_format = ['tar', 'zip']
+ if isinstance(archive_format, six.string_types):
+ archive_format = [archive_format]
+
+ for archive_type in archive_format:
+    if archive_type == 'tar':
+      open_fn = tarfile.open
+      is_match_fn = tarfile.is_tarfile
+    if archive_type == 'zip':
+ open_fn = zipfile.ZipFile
+ is_match_fn = zipfile.is_zipfile
+
+ if is_match_fn(file_path):
+ with open_fn(file_path) as archive:
+ try:
+ archive.extractall(path)
+ except (tarfile.TarError, RuntimeError, KeyboardInterrupt):
+ if os.path.exists(path):
+ if os.path.isfile(path):
+ os.remove(path)
+ else:
+ shutil.rmtree(path)
+ raise
+ return True
+ return False
+
+
+@tf_export('keras.utils.get_file')
+def get_file(fname,
+ origin,
+ untar=False,
+ md5_hash=None,
+ file_hash=None,
+ cache_subdir='datasets',
+ hash_algorithm='auto',
+ extract=False,
+ archive_format='auto',
+ cache_dir=None):
+  """Downloads a file from a URL if it is not already in the cache.
+
+ By default the file at the url `origin` is downloaded to the
+ cache_dir `~/.keras`, placed in the cache_subdir `datasets`,
+ and given the filename `fname`. The final location of a file
+ `example.txt` would therefore be `~/.keras/datasets/example.txt`.
+
+ Files in tar, tar.gz, tar.bz, and zip formats can also be extracted.
+ Passing a hash will verify the file after download. The command line
+ programs `shasum` and `sha256sum` can compute the hash.
+
+ Arguments:
+ fname: Name of the file. If an absolute path `/path/to/file.txt` is
+ specified the file will be saved at that location.
+ origin: Original URL of the file.
+ untar: Deprecated in favor of 'extract'.
+ boolean, whether the file should be decompressed
+ md5_hash: Deprecated in favor of 'file_hash'.
+ md5 hash of the file for verification
+ file_hash: The expected hash string of the file after download.
+ The sha256 and md5 hash algorithms are both supported.
+ cache_subdir: Subdirectory under the Keras cache dir where the file is
+ saved. If an absolute path `/path/to/folder` is
+ specified the file will be saved at that location.
+ hash_algorithm: Select the hash algorithm to verify the file.
+ options are 'md5', 'sha256', and 'auto'.
+ The default 'auto' detects the hash algorithm in use.
+ extract: True tries extracting the file as an Archive, like tar or zip.
+ archive_format: Archive format to try for extracting the file.
+ Options are 'auto', 'tar', 'zip', and None.
+ 'tar' includes tar, tar.gz, and tar.bz files.
+ The default 'auto' is ['tar', 'zip'].
+ None or an empty list will return no matches found.
+ cache_dir: Location to store cached files, when None it
+ defaults to the [Keras
+ Directory](/faq/#where-is-the-keras-configuration-filed-stored).
+
+ Returns:
+ Path to the downloaded file
+ """
+ if cache_dir is None:
+ cache_dir = os.path.join(os.path.expanduser('~'), '.keras')
+ if md5_hash is not None and file_hash is None:
+ file_hash = md5_hash
+ hash_algorithm = 'md5'
+ datadir_base = os.path.expanduser(cache_dir)
+ if not os.access(datadir_base, os.W_OK):
+ datadir_base = os.path.join('/tmp', '.keras')
+ datadir = os.path.join(datadir_base, cache_subdir)
+ if not os.path.exists(datadir):
+ os.makedirs(datadir)
+
+ if untar:
+ untar_fpath = os.path.join(datadir, fname)
+ fpath = untar_fpath + '.tar.gz'
+ else:
+ fpath = os.path.join(datadir, fname)
+
+ download = False
+ if os.path.exists(fpath):
+ # File found; verify integrity if a hash was provided.
+ if file_hash is not None:
+ if not validate_file(fpath, file_hash, algorithm=hash_algorithm):
+ print('A local file was found, but it seems to be '
+ 'incomplete or outdated because the ' + hash_algorithm +
+ ' file hash does not match the original value of ' + file_hash +
+ ' so we will re-download the data.')
+ download = True
+ else:
+ download = True
+
+ if download:
+ print('Downloading data from', origin)
+
+ class ProgressTracker(object):
+ # Maintain progbar for the lifetime of download.
+ # This design was chosen for Python 2.7 compatibility.
+ progbar = None
+
+ def dl_progress(count, block_size, total_size):
+ if ProgressTracker.progbar is None:
+        if total_size == -1:
+ total_size = None
+ ProgressTracker.progbar = Progbar(total_size)
+ else:
+ ProgressTracker.progbar.update(count * block_size)
+
+ error_msg = 'URL fetch failure on {}: {} -- {}'
+ try:
+ try:
+ urlretrieve(origin, fpath, dl_progress)
+ except URLError as e:
+ raise Exception(error_msg.format(origin, e.errno, e.reason))
+ except HTTPError as e:
+ raise Exception(error_msg.format(origin, e.code, e.msg))
+ except (Exception, KeyboardInterrupt) as e:
+ if os.path.exists(fpath):
+ os.remove(fpath)
+ raise
+ ProgressTracker.progbar = None
+
+ if untar:
+ if not os.path.exists(untar_fpath):
+ _extract_archive(fpath, datadir, archive_format='tar')
+ return untar_fpath
+
+ if extract:
+ _extract_archive(fpath, datadir, archive_format)
+
+ return fpath
+
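A small usage sketch for `get_file` (not part of the diff): it builds a local tarball and fetches it through a `file://` origin, mirroring what `data_utils_test.py` below does; all paths here are temporary and invented for illustration, and `_hash_file` is the private helper defined next.

```python
import os
import tarfile
import tempfile

from six.moves.urllib.parse import urljoin
from six.moves.urllib.request import pathname2url

from tensorflow.python.keras.utils.data_utils import _hash_file, get_file

work_dir = tempfile.mkdtemp()
text_path = os.path.join(work_dir, 'hello.txt')
with open(text_path, 'w') as f:
  f.write('hello')

tar_path = os.path.join(work_dir, 'hello.tar.gz')
with tarfile.open(tar_path, 'w:gz') as tar:
  tar.add(text_path, arcname='hello.txt')

# Fetch the archive via file://, extract it, and cache the result; the hash
# is checked against an already-cached copy on subsequent calls.
origin = urljoin('file://', pathname2url(os.path.abspath(tar_path)))
fetched = get_file('hello.txt', origin, untar=True,
                   file_hash=_hash_file(tar_path),  # sha256 by default
                   cache_subdir=tempfile.mkdtemp())
print(fetched)  # .../hello.txt inside the cache directory
```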
+
+def _hash_file(fpath, algorithm='sha256', chunk_size=65535):
+ """Calculates a file sha256 or md5 hash.
+
+ Example:
+
+ ```python
+ >>> from keras.data_utils import _hash_file
+ >>> _hash_file('/path/to/file.zip')
+ 'e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'
+ ```
+
+ Arguments:
+ fpath: path to the file being validated
+ algorithm: hash algorithm, one of 'auto', 'sha256', or 'md5'.
+ The default 'auto' detects the hash algorithm in use.
+ chunk_size: Bytes to read at a time, important for large files.
+
+ Returns:
+ The file hash
+ """
+  if algorithm in ('sha256', 'auto'):
+ hasher = hashlib.sha256()
+ else:
+ hasher = hashlib.md5()
+
+ with open(fpath, 'rb') as fpath_file:
+ for chunk in iter(lambda: fpath_file.read(chunk_size), b''):
+ hasher.update(chunk)
+
+ return hasher.hexdigest()
+
+
+def validate_file(fpath, file_hash, algorithm='auto', chunk_size=65535):
+ """Validates a file against a sha256 or md5 hash.
+
+ Arguments:
+ fpath: path to the file being validated
+ file_hash: The expected hash string of the file.
+ The sha256 and md5 hash algorithms are both supported.
+ algorithm: Hash algorithm, one of 'auto', 'sha256', or 'md5'.
+ The default 'auto' detects the hash algorithm in use.
+ chunk_size: Bytes to read at a time, important for large files.
+
+ Returns:
+ Whether the file is valid
+ """
+  if (algorithm == 'sha256' or
+      (algorithm == 'auto' and len(file_hash) == 64)):
+ hasher = 'sha256'
+ else:
+ hasher = 'md5'
+
+  return str(_hash_file(fpath, hasher, chunk_size)) == str(file_hash)
+
+
+@tf_export('keras.utils.Sequence')
+class Sequence(object):
+ """Base object for fitting to a sequence of data, such as a dataset.
+
+  Every `Sequence` must implement the `__getitem__` and the `__len__` methods.
+ If you want to modify your dataset between epochs you may implement
+ `on_epoch_end`.
+ The method `__getitem__` should return a complete batch.
+
+ # Notes
+
+  `Sequence` is a safer way to do multiprocessing. This structure guarantees
+  that the network will only train once on each sample per epoch, which is not
+  the case with generators.
+
+ Examples:
+
+ ```python
+ from skimage.io import imread
+ from skimage.transform import resize
+ import numpy as np
+ import math
+
+  # Here, `x_set` is a list of paths to the images
+ # and `y_set` are the associated classes.
+
+ class CIFAR10Sequence(Sequence):
+
+ def __init__(self, x_set, y_set, batch_size):
+ self.x, self.y = x_set, y_set
+ self.batch_size = batch_size
+
+ def __len__(self):
+ return math.ceil(len(self.x) / self.batch_size)
+
+ def __getitem__(self, idx):
+ batch_x = self.x[idx * self.batch_size:(idx + 1) *
+ self.batch_size]
+ batch_y = self.y[idx * self.batch_size:(idx + 1) *
+ self.batch_size]
+
+ return np.array([
+ resize(imread(file_name), (200, 200))
+ for file_name in batch_x]), np.array(batch_y)
+ ```
+ """
+
+ @abstractmethod
+ def __getitem__(self, index):
+ """Gets batch at position `index`.
+
+ Arguments:
+ index: position of the batch in the Sequence.
+
+ Returns:
+ A batch
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def __len__(self):
+    """Number of batches in the Sequence.
+
+ Returns:
+ The number of batches in the Sequence.
+ """
+ raise NotImplementedError
+
+ def on_epoch_end(self):
+ """Method called at the end of every epoch.
+ """
+ pass
+
+ def __iter__(self):
+    """Creates an infinite generator that iterates over the Sequence.
+
+ Yields:
+ Sequence items.
+ """
+ while True:
+ for item in (self[i] for i in range(len(self))):
+ yield item
+
+
+# Global variables to be shared across processes
+_SHARED_SEQUENCES = {}
+# We use a Value to provide unique id to different processes.
+_SEQUENCE_COUNTER = None
+
+
+def init_pool(seqs):
+ global _SHARED_SEQUENCES
+ _SHARED_SEQUENCES = seqs
+
+
+def get_index(uid, i):
+ """Get the value from the Sequence `uid` at index `i`.
+
+ To allow multiple Sequences to be used at the same time, we use `uid` to
+ get a specific one. A single Sequence would cause the validation to
+ overwrite the training Sequence.
+
+ Arguments:
+ uid: int, Sequence identifier
+ i: index
+
+ Returns:
+ The value at index `i`.
+ """
+ return _SHARED_SEQUENCES[uid][i]
+
+
+@tf_export('keras.utils.SequenceEnqueuer')
+class SequenceEnqueuer(object):
+ """Base class to enqueue inputs.
+
+ The task of an Enqueuer is to use parallelism to speed up preprocessing.
+ This is done with processes or threads.
+
+ Examples:
+
+ ```python
+ enqueuer = SequenceEnqueuer(...)
+ enqueuer.start()
+ datas = enqueuer.get()
+ for data in datas:
+ # Use the inputs; training, evaluating, predicting.
+ # ... stop sometime.
+ enqueuer.close()
+ ```
+
+  The `enqueuer.get()` should be an infinite stream of data.
+
+ """
+
+ @abstractmethod
+ def is_running(self):
+ raise NotImplementedError
+
+ @abstractmethod
+ def start(self, workers=1, max_queue_size=10):
+ """Starts the handler's workers.
+
+ Arguments:
+ workers: number of worker threads
+ max_queue_size: queue size
+ (when full, threads could block on `put()`).
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def stop(self, timeout=None):
+ """Stop running threads and wait for them to exit, if necessary.
+
+ Should be called by the same thread which called start().
+
+ Arguments:
+ timeout: maximum time to wait on thread.join()
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def get(self):
+ """Creates a generator to extract data from the queue.
+
+ Skip the data if it is `None`.
+
+ Returns:
+ Generator yielding tuples `(inputs, targets)`
+ or `(inputs, targets, sample_weights)`.
+ """
+ raise NotImplementedError
+
+
+class OrderedEnqueuer(SequenceEnqueuer):
+  """Builds an Enqueuer from a Sequence.
+
+ Used in `fit_generator`, `evaluate_generator`, `predict_generator`.
+
+ Arguments:
+ sequence: A `keras.utils.data_utils.Sequence` object.
+ use_multiprocessing: use multiprocessing if True, otherwise threading
+ shuffle: whether to shuffle the data at the beginning of each epoch
+ """
+
+ def __init__(self, sequence, use_multiprocessing=False, shuffle=False):
+ self.sequence = sequence
+ self.use_multiprocessing = use_multiprocessing
+
+ global _SEQUENCE_COUNTER
+ if _SEQUENCE_COUNTER is None:
+ try:
+ _SEQUENCE_COUNTER = multiprocessing.Value('i', 0)
+ except OSError:
+ # In this case the OS does not allow us to use
+ # multiprocessing. We resort to an int
+ # for enqueuer indexing.
+ _SEQUENCE_COUNTER = 0
+
+ if isinstance(_SEQUENCE_COUNTER, int):
+ self.uid = _SEQUENCE_COUNTER
+ _SEQUENCE_COUNTER += 1
+ else:
+ # Doing Multiprocessing.Value += x is not process-safe.
+ with _SEQUENCE_COUNTER.get_lock():
+ self.uid = _SEQUENCE_COUNTER.value
+ _SEQUENCE_COUNTER.value += 1
+
+ self.shuffle = shuffle
+ self.workers = 0
+ self.executor_fn = None
+ self.queue = None
+ self.run_thread = None
+ self.stop_signal = None
+
+ def is_running(self):
+ return self.stop_signal is not None and not self.stop_signal.is_set()
+
+ def start(self, workers=1, max_queue_size=10):
+ """Start the handler's workers.
+
+ Arguments:
+ workers: number of worker threads
+ max_queue_size: queue size
+ (when full, workers could block on `put()`)
+ """
+ if self.use_multiprocessing:
+ self.executor_fn = lambda seqs: multiprocessing.Pool( # pylint: disable=g-long-lambda
+ workers, initializer=init_pool, initargs=(seqs,))
+ else:
+ # We do not need the init since it's threads.
+ self.executor_fn = lambda _: ThreadPool(workers)
+ self.workers = workers
+ self.queue = queue.Queue(max_queue_size)
+ self.stop_signal = threading.Event()
+ self.run_thread = threading.Thread(target=self._run)
+ self.run_thread.daemon = True
+ self.run_thread.start()
+
+ def _wait_queue(self):
+ """Wait for the queue to be empty."""
+ while True:
+ time.sleep(0.1)
+ if self.queue.unfinished_tasks == 0 or self.stop_signal.is_set():
+ return
+
+ def _run(self):
+    """Submits requests to the executor and queues the `Future` objects."""
+ sequence = list(range(len(self.sequence)))
+ self._send_sequence() # Share the initial sequence
+ while True:
+ if self.shuffle:
+ random.shuffle(sequence)
+
+ with closing(self.executor_fn(_SHARED_SEQUENCES)) as executor:
+ for i in sequence:
+ if self.stop_signal.is_set():
+ return
+ self.queue.put(
+ executor.apply_async(get_index, (self.uid, i)), block=True)
+
+ # Done with the current epoch, waiting for the final batches
+ self._wait_queue()
+
+ if self.stop_signal.is_set():
+ # We're done
+ return
+
+ # Call the internal on epoch end.
+ self.sequence.on_epoch_end()
+ self._send_sequence() # Update the pool
+
+ def get(self):
+ """Creates a generator to extract data from the queue.
+
+ Skip the data if it is `None`.
+
+ Yields:
+ The next element in the queue, i.e. a tuple
+ `(inputs, targets)` or
+ `(inputs, targets, sample_weights)`.
+ """
+ try:
+ while self.is_running():
+ inputs = self.queue.get(block=True).get()
+ self.queue.task_done()
+ if inputs is not None:
+ yield inputs
+ except Exception as e: # pylint: disable=broad-except
+ self.stop()
+ six.raise_from(StopIteration(e), e)
+
+ def _send_sequence(self):
+ """Send current Sequence to all workers."""
+ # For new processes that may spawn
+ _SHARED_SEQUENCES[self.uid] = self.sequence
+
+ def stop(self, timeout=None):
+    """Stops running threads and waits for them to exit, if necessary.
+
+ Should be called by the same thread which called `start()`.
+
+ Arguments:
+ timeout: maximum time to wait on `thread.join()`
+ """
+ self.stop_signal.set()
+ with self.queue.mutex:
+ self.queue.queue.clear()
+ self.queue.unfinished_tasks = 0
+ self.queue.not_full.notify()
+ self.run_thread.join(timeout)
+ _SHARED_SEQUENCES[self.uid] = None
+
+
+@tf_export('keras.utils.GeneratorEnqueuer')
+class GeneratorEnqueuer(SequenceEnqueuer):
+ """Builds a queue out of a data generator.
+
+  The provided generator can be finite, in which case the class will throw
+  a `StopIteration` exception.
+
+ Used in `fit_generator`, `evaluate_generator`, `predict_generator`.
+
+ Arguments:
+ generator: a generator function which yields data
+ use_multiprocessing: use multiprocessing if True, otherwise threading
+ wait_time: time to sleep in-between calls to `put()`
+    seed: Initial seed for workers,
+      will be incremented by one for each worker.
+ """
+
+ def __init__(self,
+ generator,
+ use_multiprocessing=False,
+ wait_time=0.05,
+ seed=None):
+ self.wait_time = wait_time
+ self._generator = generator
+    if os.name == 'nt' and use_multiprocessing:
+ # On Windows, avoid **SYSTEMATIC** error in `multiprocessing`:
+ # `TypeError: can't pickle generator objects`
+ # => Suggest multithreading instead of multiprocessing on Windows
+ raise ValueError('Using a generator with `use_multiprocessing=True`'
+ ' is not supported on Windows (no marshalling of'
+ ' generators across process boundaries). Instead,'
+ ' use single thread/process or multithreading.')
+ else:
+ self._use_multiprocessing = use_multiprocessing
+ self._threads = []
+ self._stop_event = None
+ self._manager = None
+ self.queue = None
+ self.seed = seed
+
+ def _data_generator_task(self):
+    if not self._use_multiprocessing:
+ while not self._stop_event.is_set():
+ with self.genlock:
+ try:
+ if (self.queue is not None and
+ self.queue.qsize() < self.max_queue_size):
+ # On all OSes, avoid **SYSTEMATIC** error
+ # in multithreading mode:
+ # `ValueError: generator already executing`
+ # => Serialize calls to
+ # infinite iterator/generator's next() function
+ generator_output = next(self._generator)
+ self.queue.put((True, generator_output))
+ else:
+ time.sleep(self.wait_time)
+ except StopIteration:
+ break
+ except Exception as e: # pylint: disable=broad-except
+ # Can't pickle tracebacks.
+ # As a compromise, print the traceback and pickle None instead.
+ if not hasattr(e, '__traceback__'):
+ setattr(e, '__traceback__', sys.exc_info()[2])
+ self.queue.put((False, e))
+ self._stop_event.set()
+ break
+ else:
+ while not self._stop_event.is_set():
+ try:
+ if (self.queue is not None and
+ self.queue.qsize() < self.max_queue_size):
+ generator_output = next(self._generator)
+ self.queue.put((True, generator_output))
+ else:
+ time.sleep(self.wait_time)
+ except StopIteration:
+ break
+ except Exception as e: # pylint: disable=broad-except
+ # Can't pickle tracebacks.
+ # As a compromise, print the traceback and pickle None instead.
+ traceback.print_exc()
+ setattr(e, '__traceback__', None)
+ self.queue.put((False, e))
+ self._stop_event.set()
+ break
+
+ def start(self, workers=1, max_queue_size=10):
+ """Kicks off threads which add data from the generator into the queue.
+
+ Arguments:
+ workers: number of worker threads
+ max_queue_size: queue size
+ (when full, threads could block on `put()`)
+ """
+ try:
+ self.max_queue_size = max_queue_size
+ if self._use_multiprocessing:
+ self._manager = multiprocessing.Manager()
+ self.queue = self._manager.Queue(maxsize=max_queue_size)
+ self._stop_event = multiprocessing.Event()
+ else:
+ # On all OSes, avoid **SYSTEMATIC** error in multithreading mode:
+ # `ValueError: generator already executing`
+ # => Serialize calls to infinite iterator/generator's next() function
+ self.genlock = threading.Lock()
+ self.queue = queue.Queue(maxsize=max_queue_size)
+ self._stop_event = threading.Event()
+
+ for _ in range(workers):
+ if self._use_multiprocessing:
+          # Reset random seed else all child processes
+          # share the same seed
+ np.random.seed(self.seed)
+ thread = multiprocessing.Process(target=self._data_generator_task)
+ thread.daemon = True
+ if self.seed is not None:
+ self.seed += 1
+ else:
+ thread = threading.Thread(target=self._data_generator_task)
+ self._threads.append(thread)
+ thread.start()
+ except:
+ self.stop()
+ raise
+
+ def is_running(self):
+ return self._stop_event is not None and not self._stop_event.is_set()
+
+ def stop(self, timeout=None):
+    """Stops running threads and waits for them to exit, if necessary.
+
+ Should be called by the same thread which called `start()`.
+
+ Arguments:
+ timeout: maximum time to wait on `thread.join()`.
+ """
+ if self.is_running():
+ self._stop_event.set()
+
+ for thread in self._threads:
+ if self._use_multiprocessing:
+ if thread.is_alive():
+ thread.terminate()
+ else:
+ # The thread.is_alive() test is subject to a race condition:
+ # the thread could terminate right after the test and before the
+ # join, rendering this test meaningless -> Call thread.join()
+ # always, which is ok no matter what the status of the thread.
+ thread.join(timeout)
+
+ if self._manager:
+ self._manager.shutdown()
+
+ self._threads = []
+ self._stop_event = None
+ self.queue = None
+
+ def get(self):
+ """Creates a generator to extract data from the queue.
+
+ Skip the data if it is `None`.
+
+ Yields:
+ The next element in the queue, i.e. a tuple
+ `(inputs, targets)` or
+ `(inputs, targets, sample_weights)`.
+ """
+ while self.is_running():
+ if not self.queue.empty():
+ success, value = self.queue.get()
+ # Rethrow any exceptions found in the queue
+ if not success:
+ six.reraise(value.__class__, value, value.__traceback__)
+ # Yield regular values
+ if value is not None:
+ yield value
+ else:
+ all_finished = all([not thread.is_alive() for thread in self._threads])
+ if all_finished and self.queue.empty():
+ raise StopIteration()
+ else:
+ time.sleep(self.wait_time)
+
+ # Make sure to rethrow the first exception in the queue, if any
+ while not self.queue.empty():
+ success, value = self.queue.get()
+ if not success:
+ six.reraise(value.__class__, value, value.__traceback__)
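To tie the pieces above together, a minimal sketch (not part of the diff) of driving an `OrderedEnqueuer` with a toy `Sequence`; the `RangeSequence` class is invented for illustration:

```python
import numpy as np

from tensorflow.python.keras.utils.data_utils import OrderedEnqueuer, Sequence


class RangeSequence(Sequence):
  """Toy Sequence: batch `idx` is a (2, 2) array filled with the value idx."""

  def __init__(self, num_batches=10):
    self.num_batches = num_batches

  def __len__(self):
    return self.num_batches

  def __getitem__(self, idx):
    return np.full((2, 2), idx)


enqueuer = OrderedEnqueuer(RangeSequence(), use_multiprocessing=False)
enqueuer.start(workers=2, max_queue_size=4)
stream = enqueuer.get()  # infinite, order-preserving generator over batches
first_epoch = [int(next(stream)[0, 0]) for _ in range(10)]
assert first_epoch == list(range(10))
enqueuer.stop()
```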
diff --git a/tensorflow/python/keras/utils/data_utils_test.py b/tensorflow/python/keras/utils/data_utils_test.py
new file mode 100644
index 0000000000..395df7e0e7
--- /dev/null
+++ b/tensorflow/python/keras/utils/data_utils_test.py
@@ -0,0 +1,311 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for data_utils."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from itertools import cycle
+import os
+import tarfile
+import threading
+import unittest
+import zipfile
+
+import numpy as np
+from six.moves.urllib.parse import urljoin
+from six.moves.urllib.request import pathname2url
+
+from tensorflow.python import keras
+from tensorflow.python.platform import test
+
+
+class TestGetFileAndValidateIt(test.TestCase):
+
+ def test_get_file_and_validate_it(self):
+ """Tests get_file from a url, plus extraction and validation.
+ """
+ dest_dir = self.get_temp_dir()
+ orig_dir = self.get_temp_dir()
+
+ text_file_path = os.path.join(orig_dir, 'test.txt')
+ zip_file_path = os.path.join(orig_dir, 'test.zip')
+ tar_file_path = os.path.join(orig_dir, 'test.tar.gz')
+
+ with open(text_file_path, 'w') as text_file:
+ text_file.write('Float like a butterfly, sting like a bee.')
+
+ with tarfile.open(tar_file_path, 'w:gz') as tar_file:
+ tar_file.add(text_file_path)
+
+ with zipfile.ZipFile(zip_file_path, 'w') as zip_file:
+ zip_file.write(text_file_path)
+
+ origin = urljoin('file://', pathname2url(os.path.abspath(tar_file_path)))
+
+ path = keras.utils.data_utils.get_file('test.txt', origin,
+ untar=True, cache_subdir=dest_dir)
+ filepath = path + '.tar.gz'
+ hashval_sha256 = keras.utils.data_utils._hash_file(filepath)
+ hashval_md5 = keras.utils.data_utils._hash_file(filepath, algorithm='md5')
+ path = keras.utils.data_utils.get_file(
+ 'test.txt', origin, md5_hash=hashval_md5,
+ untar=True, cache_subdir=dest_dir)
+ path = keras.utils.data_utils.get_file(
+ filepath, origin, file_hash=hashval_sha256,
+ extract=True, cache_subdir=dest_dir)
+ self.assertTrue(os.path.exists(filepath))
+ self.assertTrue(keras.utils.data_utils.validate_file(filepath,
+ hashval_sha256))
+ self.assertTrue(keras.utils.data_utils.validate_file(filepath, hashval_md5))
+ os.remove(filepath)
+
+ origin = urljoin('file://', pathname2url(os.path.abspath(zip_file_path)))
+
+ hashval_sha256 = keras.utils.data_utils._hash_file(zip_file_path)
+ hashval_md5 = keras.utils.data_utils._hash_file(zip_file_path,
+ algorithm='md5')
+ path = keras.utils.data_utils.get_file(
+ 'test', origin, md5_hash=hashval_md5,
+ extract=True, cache_subdir=dest_dir)
+ path = keras.utils.data_utils.get_file(
+ 'test', origin, file_hash=hashval_sha256,
+ extract=True, cache_subdir=dest_dir)
+ self.assertTrue(os.path.exists(path))
+ self.assertTrue(keras.utils.data_utils.validate_file(path, hashval_sha256))
+ self.assertTrue(keras.utils.data_utils.validate_file(path, hashval_md5))
+
+
+class ThreadsafeIter(object):
+
+ def __init__(self, it):
+ self.it = it
+ self.lock = threading.Lock()
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ return self.next()
+
+ def next(self):
+ with self.lock:
+ return next(self.it)
+
+
+def threadsafe_generator(f):
+
+ def g(*a, **kw):
+ return ThreadsafeIter(f(*a, **kw))
+
+ return g
+
+
+class TestSequence(keras.utils.data_utils.Sequence):
+
+ def __init__(self, shape, value=1.):
+ self.shape = shape
+ self.inner = value
+
+ def __getitem__(self, item):
+ return np.ones(self.shape, dtype=np.uint32) * item * self.inner
+
+ def __len__(self):
+ return 100
+
+ def on_epoch_end(self):
+ self.inner *= 5.0
+
+
+class FaultSequence(keras.utils.data_utils.Sequence):
+
+ def __getitem__(self, item):
+ raise IndexError(item, 'item is not present')
+
+ def __len__(self):
+ return 100
+
+
+@threadsafe_generator
+def create_generator_from_sequence_threads(ds):
+ for i in cycle(range(len(ds))):
+ yield ds[i]
+
+
+def create_generator_from_sequence_pcs(ds):
+ for i in cycle(range(len(ds))):
+ yield ds[i]
+
+
+class TestEnqueuers(test.TestCase):
+
+ def test_generator_enqueuer_threads(self):
+ enqueuer = keras.utils.data_utils.GeneratorEnqueuer(
+ create_generator_from_sequence_threads(TestSequence([3, 200, 200, 3])),
+ use_multiprocessing=False)
+ enqueuer.start(3, 10)
+ gen_output = enqueuer.get()
+ acc = []
+ for _ in range(100):
+ acc.append(int(next(gen_output)[0, 0, 0, 0]))
+
+ self.assertEqual(len(set(acc) - set(range(100))), 0)
+ enqueuer.stop()
+
+ @unittest.skipIf(
+ os.name == 'nt',
+ 'use_multiprocessing=True does not work on windows properly.')
+ def test_generator_enqueuer_processes(self):
+ enqueuer = keras.utils.data_utils.GeneratorEnqueuer(
+ create_generator_from_sequence_pcs(TestSequence([3, 200, 200, 3])),
+ use_multiprocessing=True)
+ enqueuer.start(3, 10)
+ gen_output = enqueuer.get()
+ acc = []
+ for _ in range(100):
+ acc.append(int(next(gen_output)[0, 0, 0, 0]))
+ self.assertNotEqual(acc, list(range(100)))
+ enqueuer.stop()
+
+ def test_generator_enqueuer_fail_threads(self):
+ enqueuer = keras.utils.data_utils.GeneratorEnqueuer(
+ create_generator_from_sequence_threads(FaultSequence()),
+ use_multiprocessing=False)
+ enqueuer.start(3, 10)
+ gen_output = enqueuer.get()
+ with self.assertRaises(IndexError):
+ next(gen_output)
+
+ @unittest.skipIf(
+ os.name == 'nt',
+ 'use_multiprocessing=True does not work on windows properly.')
+ def test_generator_enqueuer_fail_processes(self):
+ enqueuer = keras.utils.data_utils.GeneratorEnqueuer(
+ create_generator_from_sequence_pcs(FaultSequence()),
+ use_multiprocessing=True)
+ enqueuer.start(3, 10)
+ gen_output = enqueuer.get()
+ with self.assertRaises(IndexError):
+ next(gen_output)
+
+ def test_ordered_enqueuer_threads(self):
+ enqueuer = keras.utils.data_utils.OrderedEnqueuer(
+ TestSequence([3, 200, 200, 3]), use_multiprocessing=False)
+ enqueuer.start(3, 10)
+ gen_output = enqueuer.get()
+ acc = []
+ for _ in range(100):
+ acc.append(next(gen_output)[0, 0, 0, 0])
+ self.assertEqual(acc, list(range(100)))
+ enqueuer.stop()
+
+ def test_ordered_enqueuer_processes(self):
+ enqueuer = keras.utils.data_utils.OrderedEnqueuer(
+ TestSequence([3, 200, 200, 3]), use_multiprocessing=True)
+ enqueuer.start(3, 10)
+ gen_output = enqueuer.get()
+ acc = []
+ for _ in range(100):
+ acc.append(next(gen_output)[0, 0, 0, 0])
+ self.assertEqual(acc, list(range(100)))
+ enqueuer.stop()
+
+ def test_ordered_enqueuer_fail_threads(self):
+ enqueuer = keras.utils.data_utils.OrderedEnqueuer(
+ FaultSequence(), use_multiprocessing=False)
+ enqueuer.start(3, 10)
+ gen_output = enqueuer.get()
+ with self.assertRaises(StopIteration):
+ next(gen_output)
+
+ def test_ordered_enqueuer_fail_processes(self):
+ enqueuer = keras.utils.data_utils.OrderedEnqueuer(
+ FaultSequence(), use_multiprocessing=True)
+ enqueuer.start(3, 10)
+ gen_output = enqueuer.get()
+ with self.assertRaises(StopIteration):
+ next(gen_output)
+
+ def test_on_epoch_end_processes(self):
+ enqueuer = keras.utils.data_utils.OrderedEnqueuer(
+ TestSequence([3, 200, 200, 3]), use_multiprocessing=True)
+ enqueuer.start(3, 10)
+ gen_output = enqueuer.get()
+ acc = []
+ for _ in range(200):
+ acc.append(next(gen_output)[0, 0, 0, 0])
+    # Check that order was kept in OrderedEnqueuer with processes
+ self.assertEqual(acc[100:], list([k * 5 for k in range(100)]))
+ enqueuer.stop()
+
+ def test_context_switch(self):
+ enqueuer = keras.utils.data_utils.OrderedEnqueuer(
+ TestSequence([3, 200, 200, 3]), use_multiprocessing=True)
+ enqueuer2 = keras.utils.data_utils.OrderedEnqueuer(
+ TestSequence([3, 200, 200, 3], value=15), use_multiprocessing=True)
+ enqueuer.start(3, 10)
+ enqueuer2.start(3, 10)
+ gen_output = enqueuer.get()
+ gen_output2 = enqueuer2.get()
+ acc = []
+ for _ in range(100):
+ acc.append(next(gen_output)[0, 0, 0, 0])
+ self.assertEqual(acc[-1], 99)
+ # One epoch is completed so enqueuer will switch the Sequence
+
+ acc = []
+ for _ in range(100):
+ acc.append(next(gen_output2)[0, 0, 0, 0])
+ self.assertEqual(acc[-1], 99 * 15)
+ # One epoch has been completed so enqueuer2 will switch
+
+    # Be sure that both Sequences were updated
+ self.assertEqual(next(gen_output)[0, 0, 0, 0], 0)
+ self.assertEqual(next(gen_output)[0, 0, 0, 0], 5)
+ self.assertEqual(next(gen_output2)[0, 0, 0, 0], 0)
+ self.assertEqual(next(gen_output2)[0, 0, 0, 0], 15 * 5)
+
+ # Tear down everything
+ enqueuer.stop()
+ enqueuer2.stop()
+
+ def test_on_epoch_end_threads(self):
+ enqueuer = keras.utils.data_utils.OrderedEnqueuer(
+ TestSequence([3, 200, 200, 3]), use_multiprocessing=False)
+ enqueuer.start(3, 10)
+ gen_output = enqueuer.get()
+ acc = []
+ for _ in range(100):
+ acc.append(next(gen_output)[0, 0, 0, 0])
+ acc = []
+ for _ in range(100):
+ acc.append(next(gen_output)[0, 0, 0, 0])
+    # Check that order was kept in OrderedEnqueuer with threads
+ self.assertEqual(acc, list([k * 5 for k in range(100)]))
+ enqueuer.stop()
+
+
+if __name__ == '__main__':
+ # Bazel sets these environment variables to very long paths.
+ # Tempfile uses them to create long paths, and in turn multiprocessing
+ # library tries to create sockets named after paths. Delete whatever bazel
+ # writes to these to avoid tests failing due to socket addresses being too
+ # long.
+ for var in ('TMPDIR', 'TMP', 'TEMP'):
+ if var in os.environ:
+ del os.environ[var]
+
+ test.main()
diff --git a/tensorflow/python/keras/utils/generic_utils.py b/tensorflow/python/keras/utils/generic_utils.py
new file mode 100644
index 0000000000..a69893955f
--- /dev/null
+++ b/tensorflow/python/keras/utils/generic_utils.py
@@ -0,0 +1,561 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Python utilities required by Keras."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import binascii
+import codecs
+import marshal
+import os
+import re
+import sys
+import time
+import types as python_types
+
+import numpy as np
+import six
+
+from tensorflow.python.util import nest
+from tensorflow.python.util import tf_decorator
+from tensorflow.python.util import tf_inspect
+from tensorflow.python.util.tf_export import tf_export
+
+_GLOBAL_CUSTOM_OBJECTS = {}
+
+
+@tf_export('keras.utils.CustomObjectScope')
+class CustomObjectScope(object):
+ """Provides a scope that changes to `_GLOBAL_CUSTOM_OBJECTS` cannot escape.
+
+ Code within a `with` statement will be able to access custom objects
+ by name. Changes to global custom objects persist
+  within the enclosing `with` statement. At the end of the `with` statement,
+  global custom objects are reverted to the state at the beginning of the
+  `with` statement.
+
+ Example:
+
+ Consider a custom object `MyObject` (e.g. a class):
+
+ ```python
+ with CustomObjectScope({'MyObject':MyObject}):
+ layer = Dense(..., kernel_regularizer='MyObject')
+ # save, load, etc. will recognize custom object by name
+ ```
+ """
+
+ def __init__(self, *args):
+ self.custom_objects = args
+ self.backup = None
+
+ def __enter__(self):
+ self.backup = _GLOBAL_CUSTOM_OBJECTS.copy()
+ for objects in self.custom_objects:
+ _GLOBAL_CUSTOM_OBJECTS.update(objects)
+ return self
+
+ def __exit__(self, *args, **kwargs):
+ _GLOBAL_CUSTOM_OBJECTS.clear()
+ _GLOBAL_CUSTOM_OBJECTS.update(self.backup)
+
+
+@tf_export('keras.utils.custom_object_scope')
+def custom_object_scope(*args):
+ """Provides a scope that changes to `_GLOBAL_CUSTOM_OBJECTS` cannot escape.
+
+ Convenience wrapper for `CustomObjectScope`.
+ Code within a `with` statement will be able to access custom objects
+ by name. Changes to global custom objects persist
+  within the enclosing `with` statement. At the end of the `with` statement,
+  global custom objects are reverted to the state at the beginning of the
+  `with` statement.
+
+ Example:
+
+ Consider a custom object `MyObject`
+
+ ```python
+ with custom_object_scope({'MyObject':MyObject}):
+ layer = Dense(..., kernel_regularizer='MyObject')
+ # save, load, etc. will recognize custom object by name
+ ```
+
+ Arguments:
+ *args: Variable length list of dictionaries of name,
+ class pairs to add to custom objects.
+
+ Returns:
+ Object of type `CustomObjectScope`.
+ """
+ return CustomObjectScope(*args)
+
+
+@tf_export('keras.utils.get_custom_objects')
+def get_custom_objects():
+ """Retrieves a live reference to the global dictionary of custom objects.
+
+ Updating and clearing custom objects using `custom_object_scope`
+ is preferred, but `get_custom_objects` can
+ be used to directly access `_GLOBAL_CUSTOM_OBJECTS`.
+
+ Example:
+
+ ```python
+ get_custom_objects().clear()
+ get_custom_objects()['MyObject'] = MyObject
+ ```
+
+ Returns:
+ Global dictionary of names to classes (`_GLOBAL_CUSTOM_OBJECTS`).
+ """
+ return _GLOBAL_CUSTOM_OBJECTS
+
+
+@tf_export('keras.utils.serialize_keras_object')
+def serialize_keras_object(instance):
+ _, instance = tf_decorator.unwrap(instance)
+ if instance is None:
+ return None
+ if hasattr(instance, 'get_config'):
+ return {
+ 'class_name': instance.__class__.__name__,
+ 'config': instance.get_config()
+ }
+ if hasattr(instance, '__name__'):
+ return instance.__name__
+ else:
+ raise ValueError('Cannot serialize', instance)
+
+
+@tf_export('keras.utils.deserialize_keras_object')
+def deserialize_keras_object(identifier,
+ module_objects=None,
+ custom_objects=None,
+ printable_module_name='object'):
+ if isinstance(identifier, dict):
+ # In this case we are dealing with a Keras config dictionary.
+ config = identifier
+ if 'class_name' not in config or 'config' not in config:
+ raise ValueError('Improper config format: ' + str(config))
+ class_name = config['class_name']
+ if custom_objects and class_name in custom_objects:
+ cls = custom_objects[class_name]
+ elif class_name in _GLOBAL_CUSTOM_OBJECTS:
+ cls = _GLOBAL_CUSTOM_OBJECTS[class_name]
+ else:
+ module_objects = module_objects or {}
+ cls = module_objects.get(class_name)
+ if cls is None:
+ raise ValueError('Unknown ' + printable_module_name + ': ' + class_name)
+ if hasattr(cls, 'from_config'):
+ arg_spec = tf_inspect.getargspec(cls.from_config)
+ custom_objects = custom_objects or {}
+
+ if 'custom_objects' in arg_spec.args:
+ return cls.from_config(
+ config['config'],
+ custom_objects=dict(
+ list(_GLOBAL_CUSTOM_OBJECTS.items()) +
+ list(custom_objects.items())))
+ with CustomObjectScope(custom_objects):
+ return cls.from_config(config['config'])
+ else:
+ # Then `cls` may be a function returning a class.
+      # In this case, by convention, `config` holds
+      # the kwargs of the function.
+ custom_objects = custom_objects or {}
+ with CustomObjectScope(custom_objects):
+ return cls(**config['config'])
+ elif isinstance(identifier, six.string_types):
+ function_name = identifier
+ if custom_objects and function_name in custom_objects:
+ fn = custom_objects.get(function_name)
+ elif function_name in _GLOBAL_CUSTOM_OBJECTS:
+ fn = _GLOBAL_CUSTOM_OBJECTS[function_name]
+ else:
+ fn = module_objects.get(function_name)
+ if fn is None:
+ raise ValueError('Unknown ' + printable_module_name + ':' +
+ function_name)
+ return fn
+ else:
+ raise ValueError('Could not interpret serialized ' + printable_module_name +
+ ': ' + identifier)
+
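A minimal round-trip sketch for `serialize_keras_object` / `deserialize_keras_object` (not part of the diff); the `ScaledActivation` class is invented for illustration and only needs `get_config()` / `from_config()`:

```python
from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
from tensorflow.python.keras.utils.generic_utils import serialize_keras_object


class ScaledActivation(object):
  """Invented example object exposing the get_config/from_config protocol."""

  def __init__(self, scale=1.0):
    self.scale = scale

  def get_config(self):
    return {'scale': self.scale}

  @classmethod
  def from_config(cls, config):
    return cls(**config)


config = serialize_keras_object(ScaledActivation(scale=0.5))
# config == {'class_name': 'ScaledActivation', 'config': {'scale': 0.5}}

restored = deserialize_keras_object(
    config, custom_objects={'ScaledActivation': ScaledActivation})
assert restored.scale == 0.5
```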
+
+def func_dump(func):
+ """Serializes a user defined function.
+
+ Arguments:
+ func: the function to serialize.
+
+ Returns:
+ A tuple `(code, defaults, closure)`.
+ """
+ if os.name == 'nt':
+ raw_code = marshal.dumps(func.__code__).replace(b'\\', b'/')
+ code = codecs.encode(raw_code, 'base64').decode('ascii')
+ else:
+ raw_code = marshal.dumps(func.__code__)
+ code = codecs.encode(raw_code, 'base64').decode('ascii')
+ defaults = func.__defaults__
+ if func.__closure__:
+ closure = tuple(c.cell_contents for c in func.__closure__)
+ else:
+ closure = None
+ return code, defaults, closure
+
+
+def func_load(code, defaults=None, closure=None, globs=None):
+ """Deserializes a user defined function.
+
+ Arguments:
+ code: bytecode of the function.
+ defaults: defaults of the function.
+ closure: closure of the function.
+ globs: dictionary of global objects.
+
+ Returns:
+ A function object.
+ """
+ if isinstance(code, (tuple, list)): # unpack previous dump
+ code, defaults, closure = code
+ if isinstance(defaults, list):
+ defaults = tuple(defaults)
+
+ def ensure_value_to_cell(value):
+ """Ensures that a value is converted to a python cell object.
+
+ Arguments:
+      value: Any value that needs to be cast to the cell type
+
+ Returns:
+ A value wrapped as a cell object (see function "func_load")
+ """
+ def dummy_fn():
+ # pylint: disable=pointless-statement
+ value # just access it so it gets captured in .__closure__
+
+ cell_value = dummy_fn.__closure__[0]
+ if not isinstance(value, type(cell_value)):
+ return cell_value
+ else:
+ return value
+
+ if closure is not None:
+ closure = tuple(ensure_value_to_cell(_) for _ in closure)
+ try:
+ raw_code = codecs.decode(code.encode('ascii'), 'base64')
+ except (UnicodeEncodeError, binascii.Error):
+ raw_code = code.encode('raw_unicode_escape')
+ code = marshal.loads(raw_code)
+ if globs is None:
+ globs = globals()
+ return python_types.FunctionType(
+ code, globs, name=code.co_name, argdefs=defaults, closure=closure)
+
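A short round-trip sketch for `func_dump` / `func_load` (not part of the diff; the same Python version on both ends is assumed, since the `marshal` bytecode format is version-specific):

```python
from tensorflow.python.keras.utils.generic_utils import func_dump, func_load


def scale(x, factor=2):
  return x * factor


# `code` is a base64 string; `defaults` is (2,); `closure` is None here.
code, defaults, closure = func_dump(scale)

restored = func_load(code, defaults=defaults, closure=closure)
assert restored(3) == 6              # the default factor=2 survives the trip
assert restored(3, factor=10) == 30
```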
+
+def has_arg(fn, name, accept_all=False):
+ """Checks if a callable accepts a given keyword argument.
+
+ Arguments:
+ fn: Callable to inspect.
+ name: Check if `fn` can be called with `name` as a keyword argument.
+ accept_all: What to return if there is no parameter called `name`
+ but the function accepts a `**kwargs` argument.
+
+ Returns:
+ bool, whether `fn` accepts a `name` keyword argument.
+ """
+ arg_spec = tf_inspect.getargspec(fn)
+ if accept_all and arg_spec.keywords is not None:
+ return True
+ return name in arg_spec.args
+
+
+@tf_export('keras.utils.Progbar')
+class Progbar(object):
+ """Displays a progress bar.
+
+ Arguments:
+ target: Total number of steps expected, None if unknown.
+ width: Progress bar width on screen.
+ verbose: Verbosity mode, 0 (silent), 1 (verbose), 2 (semi-verbose)
+ stateful_metrics: Iterable of string names of metrics that
+ should *not* be averaged over time. Metrics in this list
+ will be displayed as-is. All others will be averaged
+ by the progbar before display.
+ interval: Minimum visual progress update interval (in seconds).
+ """
+
+ def __init__(self, target, width=30, verbose=1, interval=0.05,
+ stateful_metrics=None):
+ self.target = target
+ self.width = width
+ self.verbose = verbose
+ self.interval = interval
+ if stateful_metrics:
+ self.stateful_metrics = set(stateful_metrics)
+ else:
+ self.stateful_metrics = set()
+
+ self._dynamic_display = ((hasattr(sys.stdout, 'isatty') and
+ sys.stdout.isatty()) or
+ 'ipykernel' in sys.modules or
+ 'posix' in sys.modules)
+ self._total_width = 0
+ self._seen_so_far = 0
+ # We use a dict + list to avoid garbage collection
+ # issues found in OrderedDict
+ self._values = {}
+ self._values_order = []
+ self._start = time.time()
+ self._last_update = 0
+
+ def update(self, current, values=None):
+ """Updates the progress bar.
+
+ Arguments:
+ current: Index of current step.
+ values: List of tuples:
+ `(name, value_for_last_step)`.
+ If `name` is in `stateful_metrics`,
+ `value_for_last_step` will be displayed as-is.
+ Else, an average of the metric over time will be displayed.
+ """
+ values = values or []
+ for k, v in values:
+ if k not in self._values_order:
+ self._values_order.append(k)
+ if k not in self.stateful_metrics:
+ if k not in self._values:
+ self._values[k] = [v * (current - self._seen_so_far),
+ current - self._seen_so_far]
+ else:
+ self._values[k][0] += v * (current - self._seen_so_far)
+ self._values[k][1] += (current - self._seen_so_far)
+ else:
+ # Stateful metrics output a numeric value. This representation
+ # means "take an average from a single value" but keeps the
+ # numeric formatting.
+ self._values[k] = [v, 1]
+ self._seen_so_far = current
+
+ now = time.time()
+ info = ' - %.0fs' % (now - self._start)
+ if self.verbose == 1:
+ if (now - self._last_update < self.interval and
+ self.target is not None and current < self.target):
+ return
+
+ prev_total_width = self._total_width
+ if self._dynamic_display:
+ sys.stdout.write('\b' * prev_total_width)
+ sys.stdout.write('\r')
+ else:
+ sys.stdout.write('\n')
+
+ if self.target is not None:
+ numdigits = int(np.floor(np.log10(self.target))) + 1
+ barstr = '%%%dd/%d [' % (numdigits, self.target)
+ bar = barstr % current
+ prog = float(current) / self.target
+ prog_width = int(self.width * prog)
+ if prog_width > 0:
+ bar += ('=' * (prog_width - 1))
+ if current < self.target:
+ bar += '>'
+ else:
+ bar += '='
+ bar += ('.' * (self.width - prog_width))
+ bar += ']'
+ else:
+ bar = '%7d/Unknown' % current
+
+ self._total_width = len(bar)
+ sys.stdout.write(bar)
+
+ if current:
+ time_per_unit = (now - self._start) / current
+ else:
+ time_per_unit = 0
+ if self.target is not None and current < self.target:
+ eta = time_per_unit * (self.target - current)
+ if eta > 3600:
+ eta_format = '%d:%02d:%02d' % (eta // 3600,
+ (eta % 3600) // 60,
+ eta % 60)
+ elif eta > 60:
+ eta_format = '%d:%02d' % (eta // 60, eta % 60)
+ else:
+ eta_format = '%ds' % eta
+
+ info = ' - ETA: %s' % eta_format
+ else:
+ if time_per_unit >= 1:
+ info += ' %.0fs/step' % time_per_unit
+ elif time_per_unit >= 1e-3:
+ info += ' %.0fms/step' % (time_per_unit * 1e3)
+ else:
+ info += ' %.0fus/step' % (time_per_unit * 1e6)
+
+ for k in self._values_order:
+ info += ' - %s:' % k
+ if isinstance(self._values[k], list):
+ avg = np.mean(self._values[k][0] / max(1, self._values[k][1]))
+ if abs(avg) > 1e-3:
+ info += ' %.4f' % avg
+ else:
+ info += ' %.4e' % avg
+ else:
+ info += ' %s' % self._values[k]
+
+ self._total_width += len(info)
+ if prev_total_width > self._total_width:
+ info += (' ' * (prev_total_width - self._total_width))
+
+ if self.target is not None and current >= self.target:
+ info += '\n'
+
+ sys.stdout.write(info)
+ sys.stdout.flush()
+
+ elif self.verbose == 2:
+ if self.target is None or current >= self.target:
+ for k in self._values_order:
+ info += ' - %s:' % k
+ avg = np.mean(self._values[k][0] / max(1, self._values[k][1]))
+          if abs(avg) > 1e-3:
+ info += ' %.4f' % avg
+ else:
+ info += ' %.4e' % avg
+ info += '\n'
+
+ sys.stdout.write(info)
+ sys.stdout.flush()
+
+ self._last_update = now
+
+ def add(self, n, values=None):
+ self.update(self._seen_so_far + n, values)
+
+
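+# A minimal usage sketch (hypothetical driver loop): 'loss' is averaged
+# across steps, while 'lr', declared stateful, is shown as its latest value.
+def _progbar_example():
+  """Minimal sketch of driving a `Progbar` over 100 steps."""
+  progbar = Progbar(target=100, stateful_metrics=['lr'])
+  for step in range(100):
+    progbar.update(step + 1, values=[('loss', 0.1), ('lr', 0.01)])
+
+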
+def make_batches(size, batch_size):
+ """Returns a list of batch indices (tuples of indices).
+
+ Arguments:
+ size: Integer, total size of the data to slice into batches.
+ batch_size: Integer, batch size.
+
+ Returns:
+ A list of tuples of array indices.
+ """
+ num_batches = int(np.ceil(size / float(batch_size)))
+ return [(i * batch_size, min(size, (i + 1) * batch_size))
+ for i in range(0, num_batches)]
+
+
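+# A minimal usage sketch: 10 samples with a batch size of 4 yield the index
+# ranges (0, 4), (4, 8), (8, 10).
+def _make_batches_example():
+  """Minimal sketch of splitting 10 samples into batches of 4."""
+  return make_batches(10, 4)  # [(0, 4), (4, 8), (8, 10)]
+
+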
+def slice_arrays(arrays, start=None, stop=None):
+ """Slice an array or list of arrays.
+
+ This takes an array-like, or a list of
+ array-likes, and outputs:
+ - arrays[start:stop] if `arrays` is an array-like
+ - [x[start:stop] for x in arrays] if `arrays` is a list
+
+ Can also work on list/array of indices: `slice_arrays(x, indices)`
+
+ Arguments:
+ arrays: Single array or list of arrays.
+ start: can be an integer index (start index)
+ or a list/array of indices
+ stop: integer (stop index); should be None if
+ `start` was a list.
+
+ Returns:
+ A slice of the array(s).
+
+ Raises:
+ ValueError: If the value of start is a list and stop is not None.
+ """
+ if arrays is None:
+ return [None]
+ if isinstance(start, list) and stop is not None:
+ raise ValueError('The stop argument has to be None if the value of start '
+ 'is a list.')
+ elif isinstance(arrays, list):
+ if hasattr(start, '__len__'):
+ # hdf5 datasets only support list objects as indices
+ if hasattr(start, 'shape'):
+ start = start.tolist()
+ return [None if x is None else x[start] for x in arrays]
+ else:
+ return [None if x is None else x[start:stop] for x in arrays]
+ else:
+ if hasattr(start, '__len__'):
+ if hasattr(start, 'shape'):
+ start = start.tolist()
+ return arrays[start]
+ elif hasattr(start, '__getitem__'):
+ return arrays[start:stop]
+ else:
+ return [None]
+
+
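+# A minimal usage sketch covering both calling conventions: a range slice
+# over a list of arrays, and an index-list slice over a single array.
+def _slice_arrays_example():
+  """Minimal sketch of the two `slice_arrays` calling conventions."""
+  a = np.arange(5)
+  by_range = slice_arrays([a, a], start=1, stop=3)  # [array([1, 2])] x 2
+  by_indices = slice_arrays(a, start=[0, 2, 4])     # array([0, 2, 4])
+  return by_range, by_indices
+
+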
+def to_list(x):
+ """Normalizes a list/tensor into a list.
+
+ If a tensor is passed, we return
+ a list of size 1 containing the tensor.
+
+ Arguments:
+ x: target object to be normalized.
+
+ Returns:
+ A list.
+ """
+ if isinstance(x, list):
+ return x
+ return [x]
+
+
+def object_list_uid(object_list):
+ """Creates a single string from object ids."""
+ object_list = nest.flatten(object_list)
+ return ', '.join([str(abs(id(x))) for x in object_list])
+
+
+def to_snake_case(name):
+ intermediate = re.sub('(.)([A-Z][a-z0-9]+)', r'\1_\2', name)
+ insecure = re.sub('([a-z])([A-Z])', r'\1_\2', intermediate).lower()
+  # If the class is private, the name starts with "_", which is not
+  # allowed for scope names. We prefix the name with "private" in this
+  # case.
+ if insecure[0] != '_':
+ return insecure
+ return 'private' + insecure
+
+
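+# A minimal usage sketch: CamelCase class names become snake_case scope
+# names; names starting with "_" additionally get a "private" prefix.
+def _to_snake_case_example():
+  """Minimal sketch of converting a class name to a scope name."""
+  return to_snake_case('MyDenseLayer')  # -> 'my_dense_layer'
+
+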
+def is_all_none(iterable_or_element):
+ if not isinstance(iterable_or_element, (list, tuple)):
+ iterable = [iterable_or_element]
+ else:
+ iterable = iterable_or_element
+ # We cannot use Python's `any` because the iterable may return Tensors.
+ for element in iterable:
+ if element is not None:
+ return False
+ return True
diff --git a/tensorflow/python/keras/utils/generic_utils_test.py b/tensorflow/python/keras/utils/generic_utils_test.py
new file mode 100644
index 0000000000..87bc19eb37
--- /dev/null
+++ b/tensorflow/python/keras/utils/generic_utils_test.py
@@ -0,0 +1,75 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Keras generic Python utils."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python import keras
+from tensorflow.python.platform import test
+
+
+class HasArgTest(test.TestCase):
+
+ def test_has_arg(self):
+
+ def f_x(x):
+ return x
+
+ def f_x_args(x, *args):
+ _ = args
+ return x
+
+ def f_x_kwargs(x, **kwargs):
+ _ = kwargs
+ return x
+
+ self.assertTrue(keras.utils.generic_utils.has_arg(
+ f_x, 'x', accept_all=False))
+ self.assertFalse(keras.utils.generic_utils.has_arg(
+ f_x, 'y', accept_all=False))
+ self.assertTrue(keras.utils.generic_utils.has_arg(
+ f_x_args, 'x', accept_all=False))
+ self.assertFalse(keras.utils.generic_utils.has_arg(
+ f_x_args, 'y', accept_all=False))
+ self.assertTrue(keras.utils.generic_utils.has_arg(
+ f_x_kwargs, 'x', accept_all=False))
+ self.assertFalse(keras.utils.generic_utils.has_arg(
+ f_x_kwargs, 'y', accept_all=False))
+ self.assertTrue(keras.utils.generic_utils.has_arg(
+ f_x_kwargs, 'y', accept_all=True))
+
+
+class TestCustomObjectScope(test.TestCase):
+
+ def test_custom_object_scope(self):
+
+ def custom_fn():
+ pass
+
+ class CustomClass(object):
+ pass
+
+ with keras.utils.generic_utils.custom_object_scope(
+ {'CustomClass': CustomClass, 'custom_fn': custom_fn}):
+ act = keras.activations.get('custom_fn')
+ self.assertEqual(act, custom_fn)
+ cl = keras.regularizers.get('CustomClass')
+ self.assertEqual(cl.__class__, CustomClass)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/python/keras/utils/io_utils.py b/tensorflow/python/keras/utils/io_utils.py
new file mode 100644
index 0000000000..f82e3277de
--- /dev/null
+++ b/tensorflow/python/keras/utils/io_utils.py
@@ -0,0 +1,171 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# pylint: disable=g-import-not-at-top
+"""Utilities related to disk I/O."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from collections import defaultdict
+
+import numpy as np
+import six
+from tensorflow.python.util.tf_export import tf_export
+
+
+try:
+ import h5py
+except ImportError:
+ h5py = None
+
+
+@tf_export('keras.utils.HDF5Matrix')
+class HDF5Matrix(object):
+ """Representation of HDF5 dataset to be used instead of a Numpy array.
+
+ Example:
+
+ ```python
+ x_data = HDF5Matrix('input/file.hdf5', 'data')
+ model.predict(x_data)
+ ```
+
+ Providing `start` and `end` allows use of a slice of the dataset.
+
+ Optionally, a normalizer function (or lambda) can be given. This will
+ be called on every slice of data retrieved.
+
+ Arguments:
+    datapath: string, path to an HDF5 file
+ dataset: string, name of the HDF5 dataset in the file specified
+ in datapath
+ start: int, start of desired slice of the specified dataset
+ end: int, end of desired slice of the specified dataset
+ normalizer: function to be called on data when retrieved
+
+ Returns:
+ An array-like HDF5 dataset.
+ """
+ refs = defaultdict(int)
+
+ def __init__(self, datapath, dataset, start=0, end=None, normalizer=None):
+ if h5py is None:
+      raise ImportError('The use of HDF5Matrix requires '
+                        'HDF5 and h5py to be installed.')
+
+ if datapath not in list(self.refs.keys()):
+ f = h5py.File(datapath)
+ self.refs[datapath] = f
+ else:
+ f = self.refs[datapath]
+ self.data = f[dataset]
+ self.start = start
+ if end is None:
+ self.end = self.data.shape[0]
+ else:
+ self.end = end
+ self.normalizer = normalizer
+
+ def __len__(self):
+ return self.end - self.start
+
+ def __getitem__(self, key):
+ if isinstance(key, slice):
+ start, stop = key.start, key.stop
+ if start is None:
+ start = 0
+ if stop is None:
+ stop = self.shape[0]
+ if stop + self.start <= self.end:
+ idx = slice(start + self.start, stop + self.start)
+ else:
+ raise IndexError
+ elif isinstance(key, (int, np.integer)):
+ if key + self.start < self.end:
+ idx = key + self.start
+ else:
+ raise IndexError
+ elif isinstance(key, np.ndarray):
+ if np.max(key) + self.start < self.end:
+ idx = (self.start + key).tolist()
+ else:
+ raise IndexError
+ elif isinstance(key, list):
+ if max(key) + self.start < self.end:
+ idx = [x + self.start for x in key]
+ else:
+ raise IndexError
+ else:
+ raise IndexError
+ if self.normalizer is not None:
+ return self.normalizer(self.data[idx])
+ else:
+ return self.data[idx]
+
+ @property
+ def shape(self):
+ """Gets a numpy-style shape tuple giving the dataset dimensions.
+
+ Returns:
+ A numpy-style shape tuple.
+ """
+ return (self.end - self.start,) + self.data.shape[1:]
+
+ @property
+ def dtype(self):
+ """Gets the datatype of the dataset.
+
+ Returns:
+ A numpy dtype string.
+ """
+ return self.data.dtype
+
+ @property
+ def ndim(self):
+ """Gets the number of dimensions (rank) of the dataset.
+
+ Returns:
+ An integer denoting the number of dimensions (rank) of the dataset.
+ """
+ return self.data.ndim
+
+ @property
+ def size(self):
+ """Gets the total dataset size (number of elements).
+
+ Returns:
+ An integer denoting the number of elements in the dataset.
+ """
+ return np.prod(self.shape)
+
+
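+# A minimal usage sketch (hypothetical helper; the path argument and the
+# 'my_data' dataset name are assumptions): reads a slice of an existing
+# HDF5 dataset and rescales values on retrieval.
+def _hdf5_matrix_example(h5_path):
+  """Minimal sketch of slicing and normalizing an HDF5 dataset."""
+  x_train = HDF5Matrix(
+      h5_path, 'my_data', start=0, end=150, normalizer=lambda x: x / 255.0)
+  return x_train.shape, x_train[0]
+
+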
+def ask_to_proceed_with_overwrite(filepath):
+ """Produces a prompt asking about overwriting a file.
+
+ Arguments:
+ filepath: the path to the file to be overwritten.
+
+ Returns:
+ True if we can proceed with overwrite, False otherwise.
+ """
+ overwrite = six.moves.input('[WARNING] %s already exists - overwrite? '
+ '[y/n]' % (filepath)).strip().lower()
+ while overwrite not in ('y', 'n'):
+ overwrite = six.moves.input('Enter "y" (overwrite) or "n" '
+ '(cancel).').strip().lower()
+ if overwrite == 'n':
+ return False
+ print('[TIP] Next time specify overwrite=True!')
+ return True
diff --git a/tensorflow/python/keras/utils/io_utils_test.py b/tensorflow/python/keras/utils/io_utils_test.py
new file mode 100644
index 0000000000..3895dca68e
--- /dev/null
+++ b/tensorflow/python/keras/utils/io_utils_test.py
@@ -0,0 +1,100 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for io_utils."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import shutil
+
+import numpy as np
+
+from tensorflow.python import keras
+from tensorflow.python.platform import test
+
+try:
+ import h5py # pylint:disable=g-import-not-at-top
+except ImportError:
+ h5py = None
+
+
+def create_dataset(h5_path='test.h5'):
+ x = np.random.randn(200, 10).astype('float32')
+ y = np.random.randint(0, 2, size=(200, 1))
+ f = h5py.File(h5_path, 'w')
+ # Creating dataset to store features
+ x_dset = f.create_dataset('my_data', (200, 10), dtype='f')
+ x_dset[:] = x
+ # Creating dataset to store labels
+ y_dset = f.create_dataset('my_labels', (200, 1), dtype='i')
+ y_dset[:] = y
+ f.close()
+
+
+class TestIOUtils(test.TestCase):
+
+ def test_HDF5Matrix(self):
+ if h5py is None:
+ return
+
+ temp_dir = self.get_temp_dir()
+ self.addCleanup(shutil.rmtree, temp_dir)
+
+ h5_path = os.path.join(temp_dir, 'test.h5')
+ create_dataset(h5_path)
+
+ # Instantiating HDF5Matrix for the training set,
+ # which is a slice of the first 150 elements
+ x_train = keras.utils.io_utils.HDF5Matrix(
+ h5_path, 'my_data', start=0, end=150)
+ y_train = keras.utils.io_utils.HDF5Matrix(
+ h5_path, 'my_labels', start=0, end=150)
+
+ # Likewise for the test set
+ x_test = keras.utils.io_utils.HDF5Matrix(
+ h5_path, 'my_data', start=150, end=200)
+ y_test = keras.utils.io_utils.HDF5Matrix(
+ h5_path, 'my_labels', start=150, end=200)
+
+    # HDF5Matrix behaves more or less like a Numpy matrix
+    # with regard to indexing.
+ self.assertEqual(y_train.shape, (150, 1))
+ # But they do not support negative indices, so don't try print(x_train[-1])
+
+ self.assertEqual(y_train.dtype, np.dtype('i'))
+ self.assertEqual(y_train.ndim, 2)
+ self.assertEqual(y_train.size, 150)
+
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(64, input_shape=(10,), activation='relu'))
+ model.add(keras.layers.Dense(1, activation='sigmoid'))
+ model.compile(loss='binary_crossentropy', optimizer='sgd')
+
+ # Note: you have to use shuffle='batch' or False with HDF5Matrix
+ model.fit(x_train, y_train, batch_size=32, shuffle='batch', verbose=False)
+    # Test that evaluation and prediction
+    # don't crash and return reasonable results.
+ out_pred = model.predict(x_test, batch_size=32, verbose=False)
+ out_eval = model.evaluate(x_test, y_test, batch_size=32, verbose=False)
+
+ self.assertEqual(out_pred.shape, (50, 1))
+ self.assertEqual(out_eval.shape, ())
+ self.assertGreater(out_eval, 0)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/python/keras/utils/layer_utils.py b/tensorflow/python/keras/utils/layer_utils.py
new file mode 100644
index 0000000000..bd61f8e9cc
--- /dev/null
+++ b/tensorflow/python/keras/utils/layer_utils.py
@@ -0,0 +1,266 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# pylint: disable=protected-access
+"""Utilities related to layer/model functionality.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.keras import backend as K
+from tensorflow.python.keras.utils.conv_utils import convert_kernel
+from tensorflow.python.util.tf_export import tf_export
+
+
+def count_params(weights):
+ """Count the total number of scalars composing the weights.
+
+ Arguments:
+ weights: An iterable containing the weights on which to compute params
+
+ Returns:
+ The total number of scalars composing the weights
+ """
+ return int(np.sum([np.prod(p.get_shape().as_list()) for p in set(weights)]))
+
+
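+# A minimal usage sketch (hypothetical helper): two weight tensors of
+# shapes (2, 3) and (3,) contain 6 + 3 = 9 scalars in total.
+def _count_params_example():
+  """Minimal sketch of counting scalars across a list of weights."""
+  w = K.variable(np.zeros((2, 3)))
+  b = K.variable(np.zeros((3,)))
+  return count_params([w, b])  # == 9
+
+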
+def print_summary(model, line_length=None, positions=None, print_fn=None):
+ """Prints a summary of a model.
+
+ Arguments:
+ model: Keras model instance.
+ line_length: Total length of printed lines
+ (e.g. set this to adapt the display to different
+ terminal window sizes).
+ positions: Relative or absolute positions of log elements in each line.
+ If not provided, defaults to `[.33, .55, .67, 1.]`.
+ print_fn: Print function to use.
+ It will be called on each line of the summary.
+ You can set it to a custom function
+ in order to capture the string summary.
+ It defaults to `print` (prints to stdout).
+ """
+ if print_fn is None:
+ print_fn = print
+
+ if model.__class__.__name__ == 'Sequential':
+ sequential_like = True
+ elif not model._is_graph_network:
+ # We treat subclassed models as a simple sequence of layers, for logging
+ # purposes.
+ sequential_like = True
+ else:
+ sequential_like = True
+ nodes_by_depth = model._nodes_by_depth.values()
+ nodes = []
+ for v in nodes_by_depth:
+ if (len(v) > 1) or (len(v) == 1 and len(v[0].inbound_layers) > 1):
+ # if the model has multiple nodes
+ # or if the nodes have multiple inbound_layers
+ # the model is no longer sequential
+ sequential_like = False
+ break
+ nodes += v
+ if sequential_like:
+ # search for shared layers
+ for layer in model.layers:
+ flag = False
+ for node in layer._inbound_nodes:
+ if node in nodes:
+ if flag:
+ sequential_like = False
+ break
+ else:
+ flag = True
+ if not sequential_like:
+ break
+
+ if sequential_like:
+ line_length = line_length or 65
+ positions = positions or [.45, .85, 1.]
+ if positions[-1] <= 1:
+ positions = [int(line_length * p) for p in positions]
+ # header names for the different log elements
+ to_display = ['Layer (type)', 'Output Shape', 'Param #']
+ else:
+ line_length = line_length or 98
+ positions = positions or [.33, .55, .67, 1.]
+ if positions[-1] <= 1:
+ positions = [int(line_length * p) for p in positions]
+ # header names for the different log elements
+ to_display = ['Layer (type)', 'Output Shape', 'Param #', 'Connected to']
+ relevant_nodes = []
+ for v in model._nodes_by_depth.values():
+ relevant_nodes += v
+
+ def print_row(fields, positions):
+ line = ''
+ for i in range(len(fields)):
+ if i > 0:
+ line = line[:-1] + ' '
+ line += str(fields[i])
+ line = line[:positions[i]]
+ line += ' ' * (positions[i] - len(line))
+ print_fn(line)
+
+ print_fn('_' * line_length)
+ print_row(to_display, positions)
+ print_fn('=' * line_length)
+
+ def print_layer_summary(layer):
+ """Prints a summary for a single layer.
+
+ Arguments:
+ layer: target layer.
+ """
+ try:
+ output_shape = layer.output_shape
+ except AttributeError:
+ output_shape = 'multiple'
+ except RuntimeError: # output_shape unknown in Eager mode.
+ output_shape = '?'
+ name = layer.name
+ cls_name = layer.__class__.__name__
+ fields = [name + ' (' + cls_name + ')', output_shape, layer.count_params()]
+ print_row(fields, positions)
+
+ def print_layer_summary_with_connections(layer):
+ """Prints a summary for a single layer (including topological connections).
+
+ Arguments:
+ layer: target layer.
+ """
+ try:
+ output_shape = layer.output_shape
+ except AttributeError:
+ output_shape = 'multiple'
+ connections = []
+ for node in layer._inbound_nodes:
+ if relevant_nodes and node not in relevant_nodes:
+ # node is not part of the current network
+ continue
+ for i in range(len(node.inbound_layers)):
+ inbound_layer = node.inbound_layers[i].name
+ inbound_node_index = node.node_indices[i]
+ inbound_tensor_index = node.tensor_indices[i]
+ connections.append(inbound_layer + '[' + str(inbound_node_index) +
+ '][' + str(inbound_tensor_index) + ']')
+
+ name = layer.name
+ cls_name = layer.__class__.__name__
+ if not connections:
+ first_connection = ''
+ else:
+ first_connection = connections[0]
+ fields = [
+ name + ' (' + cls_name + ')', output_shape,
+ layer.count_params(), first_connection
+ ]
+ print_row(fields, positions)
+ if len(connections) > 1:
+ for i in range(1, len(connections)):
+ fields = ['', '', '', connections[i]]
+ print_row(fields, positions)
+
+ layers = model.layers
+ for i in range(len(layers)):
+ if sequential_like:
+ print_layer_summary(layers[i])
+ else:
+ print_layer_summary_with_connections(layers[i])
+ if i == len(layers) - 1:
+ print_fn('=' * line_length)
+ else:
+ print_fn('_' * line_length)
+
+ model._check_trainable_weights_consistency()
+ if hasattr(model, '_collected_trainable_weights'):
+ trainable_count = count_params(model._collected_trainable_weights)
+ else:
+ trainable_count = count_params(model.trainable_weights)
+
+ non_trainable_count = count_params(model.non_trainable_weights)
+
+ print_fn('Total params: {:,}'.format(trainable_count + non_trainable_count))
+ print_fn('Trainable params: {:,}'.format(trainable_count))
+ print_fn('Non-trainable params: {:,}'.format(non_trainable_count))
+ print_fn('_' * line_length)
+
+
+@tf_export('keras.utils.convert_all_kernels_in_model')
+def convert_all_kernels_in_model(model):
+ """Converts all convolution kernels in a model from Theano to TensorFlow.
+
+ Also works from TensorFlow to Theano.
+
+ Arguments:
+ model: target model for the conversion.
+ """
+ # Note: SeparableConvolution not included
+ # since only supported by TF.
+ conv_classes = {
+ 'Conv1D',
+ 'Conv2D',
+ 'Conv3D',
+ 'Conv2DTranspose',
+ }
+ to_assign = []
+ for layer in model.layers:
+ if layer.__class__.__name__ in conv_classes:
+ original_kernel = K.get_value(layer.kernel)
+ converted_kernel = convert_kernel(original_kernel)
+ to_assign.append((layer.kernel, converted_kernel))
+ K.batch_set_value(to_assign)
+
+
+def convert_dense_weights_data_format(dense,
+ previous_feature_map_shape,
+ target_data_format='channels_first'):
+ """Utility useful when changing a convnet's `data_format`.
+
+ When porting the weights of a convnet from one data format to the other,
+ if the convnet includes a `Flatten` layer
+ (applied to the last convolutional feature map)
+ followed by a `Dense` layer, the weights of that `Dense` layer
+ should be updated to reflect the new dimension ordering.
+
+ Arguments:
+ dense: The target `Dense` layer.
+ previous_feature_map_shape: A shape tuple of 3 integers,
+ e.g. `(512, 7, 7)`. The shape of the convolutional
+ feature map right before the `Flatten` layer that
+ came before the target `Dense` layer.
+ target_data_format: One of "channels_last", "channels_first".
+      Set it to "channels_last" if converting a "channels_first"
+      model to "channels_last", or vice versa.
+ """
+ assert target_data_format in {'channels_last', 'channels_first'}
+ kernel, bias = dense.get_weights()
+ for i in range(kernel.shape[1]):
+ if target_data_format == 'channels_first':
+ c, h, w = previous_feature_map_shape
+ original_fm_shape = (h, w, c)
+ ki = kernel[:, i].reshape(original_fm_shape)
+ ki = np.transpose(ki, (2, 0, 1)) # last -> first
+ else:
+ h, w, c = previous_feature_map_shape
+ original_fm_shape = (c, h, w)
+ ki = kernel[:, i].reshape(original_fm_shape)
+ ki = np.transpose(ki, (1, 2, 0)) # first -> last
+ kernel[:, i] = np.reshape(ki, (np.prod(previous_feature_map_shape),))
+ dense.set_weights([kernel, bias])
diff --git a/tensorflow/python/keras/utils/multi_gpu_utils.py b/tensorflow/python/keras/utils/multi_gpu_utils.py
new file mode 100644
index 0000000000..e5442f04e3
--- /dev/null
+++ b/tensorflow/python/keras/utils/multi_gpu_utils.py
@@ -0,0 +1,252 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities for multi-gpu training."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import ops
+from tensorflow.python.keras import backend as K
+from tensorflow.python.keras.engine.training import Model
+from tensorflow.python.ops import array_ops
+from tensorflow.python.util.tf_export import tf_export
+
+
+def _get_available_devices():
+ return [x.name for x in K.get_session().list_devices()]
+
+
+def _normalize_device_name(name):
+ name = '/' + name.lower().split('device:')[1]
+ return name
+
+
+@tf_export('keras.utils.multi_gpu_model')
+def multi_gpu_model(model, gpus, cpu_merge=True, cpu_relocation=False):
+ """Replicates a model on different GPUs.
+
+ Specifically, this function implements single-machine
+ multi-GPU data parallelism. It works in the following way:
+
+ - Divide the model's input(s) into multiple sub-batches.
+ - Apply a model copy on each sub-batch. Every model copy
+ is executed on a dedicated GPU.
+ - Concatenate the results (on CPU) into one big batch.
+
+ E.g. if your `batch_size` is 64 and you use `gpus=2`,
+ then we will divide the input into 2 sub-batches of 32 samples,
+ process each sub-batch on one GPU, then return the full
+ batch of 64 processed samples.
+
+ This induces quasi-linear speedup on up to 8 GPUs.
+
+ This function is only available with the TensorFlow backend
+ for the time being.
+
+ Arguments:
+ model: A Keras model instance. To avoid OOM errors,
+ this model could have been built on CPU, for instance
+ (see usage example below).
+    gpus: Integer >= 2 or list of integers, number of GPUs or
+        list of GPU IDs on which to create model replicas.
+    cpu_merge: A boolean value indicating whether to force
+        merging model weights under the scope of the CPU.
+    cpu_relocation: A boolean value indicating whether to
+        create the model's weights under the scope of the CPU.
+        If the model is not defined under any preceding device
+        scope, you can still rescue it by activating this option.
+
+ Returns:
+ A Keras `Model` instance which can be used just like the initial
+ `model` argument, but which distributes its workload on multiple GPUs.
+
+ Example 1: Training models with weights merge on CPU
+
+ ```python
+ import tensorflow as tf
+ from keras.applications import Xception
+ from keras.utils import multi_gpu_model
+ import numpy as np
+
+ num_samples = 1000
+ height = 224
+ width = 224
+ num_classes = 1000
+
+ # Instantiate the base model (or "template" model).
+  # We recommend doing this under a CPU device scope,
+ # so that the model's weights are hosted on CPU memory.
+ # Otherwise they may end up hosted on a GPU, which would
+ # complicate weight sharing.
+ with tf.device('/cpu:0'):
+ model = Xception(weights=None,
+ input_shape=(height, width, 3),
+ classes=num_classes)
+
+ # Replicates the model on 8 GPUs.
+ # This assumes that your machine has 8 available GPUs.
+ parallel_model = multi_gpu_model(model, gpus=8)
+ parallel_model.compile(loss='categorical_crossentropy',
+ optimizer='rmsprop')
+
+ # Generate dummy data.
+ x = np.random.random((num_samples, height, width, 3))
+ y = np.random.random((num_samples, num_classes))
+
+ # This `fit` call will be distributed on 8 GPUs.
+ # Since the batch size is 256, each GPU will process 32 samples.
+ parallel_model.fit(x, y, epochs=20, batch_size=256)
+
+ # Save model via the template model (which shares the same weights):
+ model.save('my_model.h5')
+ ```
+
+ Example 2: Training models with weights merge on CPU using cpu_relocation
+
+ ```python
+ ..
+  # No need to change the device scope for model definition:
+ model = Xception(weights=None, ..)
+
+ try:
+ model = multi_gpu_model(model, cpu_relocation=True)
+ print("Training using multiple GPUs..")
+ except:
+ print("Training using single GPU or CPU..")
+
+ model.compile(..)
+ ..
+ ```
+
+  Example 3: Training models with weights merge on GPU (recommended for NVLink)
+
+ ```python
+ ..
+  # No need to change the device scope for model definition:
+ model = Xception(weights=None, ..)
+
+ try:
+ model = multi_gpu_model(model, cpu_merge=False)
+ print("Training using multiple GPUs..")
+ except:
+ print("Training using single GPU or CPU..")
+ model.compile(..)
+ ..
+ ```
+
+ Raises:
+ ValueError: if the `gpus` argument does not match available devices.
+ """
+ # pylint: disable=g-import-not-at-top
+ from tensorflow.python.keras.layers.core import Lambda
+ from tensorflow.python.keras.layers.merge import concatenate
+
+ if isinstance(gpus, (list, tuple)):
+ if len(gpus) <= 1:
+ raise ValueError('For multi-gpu usage to be effective, '
+ 'call `multi_gpu_model` with `len(gpus) >= 2`. '
+ 'Received: `gpus=%s`' % gpus)
+ num_gpus = len(gpus)
+ target_gpu_ids = gpus
+ else:
+ if gpus <= 1:
+ raise ValueError('For multi-gpu usage to be effective, '
+ 'call `multi_gpu_model` with `gpus >= 2`. '
+ 'Received: `gpus=%s`' % gpus)
+ num_gpus = gpus
+ target_gpu_ids = range(num_gpus)
+
+ target_devices = ['/cpu:0'] + ['/gpu:%d' % i for i in target_gpu_ids]
+ available_devices = _get_available_devices()
+ available_devices = [
+ _normalize_device_name(name) for name in available_devices
+ ]
+ for device in target_devices:
+ if device not in available_devices:
+ raise ValueError('To call `multi_gpu_model` with `gpus=%s`, '
+ 'we expect the following devices to be available: %s. '
+ 'However this machine only has: %s. '
+ 'Try reducing `gpus`.' % (gpus, target_devices,
+ available_devices))
+
+ def get_slice(data, i, parts):
+ """Slice an array into `parts` slices and return slice `i`.
+
+ Arguments:
+ data: array to slice.
+ i: index of slice to return.
+ parts: number of slices to make.
+
+ Returns:
+ Slice `i` of `data`.
+ """
+ shape = array_ops.shape(data)
+ batch_size = shape[:1]
+ input_shape = shape[1:]
+ step = batch_size // parts
+ if i == num_gpus - 1:
+ size = batch_size - step * i
+ else:
+ size = step
+ size = array_ops.concat([size, input_shape], axis=0)
+ stride = array_ops.concat([step, input_shape * 0], axis=0)
+ start = stride * i
+ return array_ops.slice(data, start, size)
+
+ # Relocate the model definition under CPU device scope if needed
+ if cpu_relocation:
+ from tensorflow.python.keras.models import clone_model # pylint: disable=g-import-not-at-top
+ with ops.device('/cpu:0'):
+ model = clone_model(model)
+
+ all_outputs = []
+ for i in range(len(model.outputs)):
+ all_outputs.append([])
+
+ # Place a copy of the model on each GPU,
+ # each getting a slice of the inputs.
+ for i, gpu_id in enumerate(target_gpu_ids):
+ with ops.device('/gpu:%d' % gpu_id):
+ with ops.name_scope('replica_%d' % gpu_id):
+ inputs = []
+ # Retrieve a slice of the input.
+ for x in model.inputs:
+ input_shape = tuple(x.get_shape().as_list())[1:]
+ slice_i = Lambda(
+ get_slice,
+ output_shape=input_shape,
+ arguments={
+ 'i': i,
+ 'parts': num_gpus
+ })(
+ x)
+ inputs.append(slice_i)
+
+ # Apply model on slice
+ # (creating a model replica on the target device).
+ outputs = model(inputs)
+ if not isinstance(outputs, list):
+ outputs = [outputs]
+
+ # Save the outputs for merging back together later.
+ for o in range(len(outputs)):
+ all_outputs[o].append(outputs[o])
+
+ # Merge outputs under expected scope.
+ with ops.device('/cpu:0' if cpu_merge else '/gpu:%d' % target_gpu_ids[0]):
+ merged = []
+ for name, outputs in zip(model.output_names, all_outputs):
+ merged.append(concatenate(outputs, axis=0, name=name))
+ return Model(model.inputs, merged)
diff --git a/tensorflow/python/keras/utils/multi_gpu_utils_test.py b/tensorflow/python/keras/utils/multi_gpu_utils_test.py
new file mode 100644
index 0000000000..77792d14f5
--- /dev/null
+++ b/tensorflow/python/keras/utils/multi_gpu_utils_test.py
@@ -0,0 +1,185 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for multi-gpu training utilities."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python import data
+from tensorflow.python import keras
+from tensorflow.python.platform import test
+
+
+def check_if_compatible_devices(gpus=2):
+ available_devices = [
+ keras.utils.multi_gpu_utils._normalize_device_name(name)
+ for name in keras.utils.multi_gpu_utils._get_available_devices()
+ ]
+ if '/gpu:%d' % (gpus - 1) not in available_devices:
+ return False
+ return True
+
+
+class TestMultiGPUModel(test.TestCase):
+
+ def test_multi_gpu_test_simple_model(self):
+ gpus = 2
+ num_samples = 1000
+ input_dim = 10
+ output_dim = 1
+ hidden_dim = 10
+ epochs = 2
+ target_gpu_id = [0, 1]
+
+ if not check_if_compatible_devices(gpus=gpus):
+ return
+
+ with self.test_session():
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(hidden_dim,
+ input_shape=(input_dim,)))
+ model.add(keras.layers.Dense(output_dim))
+
+ x = np.random.random((num_samples, input_dim))
+ y = np.random.random((num_samples, output_dim))
+
+ parallel_model = keras.utils.multi_gpu_model(model, gpus=gpus)
+ parallel_model.compile(loss='mse', optimizer='rmsprop')
+ parallel_model.fit(x, y, epochs=epochs)
+ parallel_model = keras.utils.multi_gpu_model(model, gpus=target_gpu_id)
+ parallel_model.compile(loss='mse', optimizer='rmsprop')
+ parallel_model.fit(x, y, epochs=epochs)
+
+ def test_multi_gpu_test_multi_io_model(self):
+ gpus = 2
+ num_samples = 1000
+ input_dim_a = 10
+ input_dim_b = 5
+ output_dim_a = 1
+ output_dim_b = 2
+ hidden_dim = 10
+ epochs = 2
+ target_gpu_id = [0, 1]
+
+ if not check_if_compatible_devices(gpus=gpus):
+ return
+
+ with self.test_session():
+ input_a = keras.Input((input_dim_a,))
+ input_b = keras.Input((input_dim_b,))
+ a = keras.layers.Dense(hidden_dim)(input_a)
+ b = keras.layers.Dense(hidden_dim)(input_b)
+ c = keras.layers.concatenate([a, b])
+ output_a = keras.layers.Dense(output_dim_a)(c)
+ output_b = keras.layers.Dense(output_dim_b)(c)
+ model = keras.models.Model([input_a, input_b], [output_a, output_b])
+
+ a_x = np.random.random((num_samples, input_dim_a))
+ b_x = np.random.random((num_samples, input_dim_b))
+ a_y = np.random.random((num_samples, output_dim_a))
+ b_y = np.random.random((num_samples, output_dim_b))
+
+ parallel_model = keras.utils.multi_gpu_model(model, gpus=gpus)
+ parallel_model.compile(loss='mse', optimizer='rmsprop')
+ parallel_model.fit([a_x, b_x], [a_y, b_y], epochs=epochs)
+
+ parallel_model = keras.utils.multi_gpu_model(model, gpus=target_gpu_id)
+ parallel_model.compile(loss='mse', optimizer='rmsprop')
+ parallel_model.fit([a_x, b_x], [a_y, b_y], epochs=epochs)
+
+ def test_multi_gpu_test_invalid_devices(self):
+ if not check_if_compatible_devices(gpus=2):
+ return
+
+ with self.test_session():
+ input_shape = (1000, 10)
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(10,
+ activation='relu',
+ input_shape=input_shape[1:]))
+ model.add(keras.layers.Dense(1, activation='sigmoid'))
+ model.compile(loss='mse', optimizer='rmsprop')
+
+ x = np.random.random(input_shape)
+ y = np.random.random((input_shape[0], 1))
+ with self.assertRaises(ValueError):
+ parallel_model = keras.utils.multi_gpu_model(
+ model, gpus=len(keras.backend._get_available_gpus()) + 1)
+ parallel_model.fit(x, y, epochs=2)
+
+ with self.assertRaises(ValueError):
+ parallel_model = keras.utils.multi_gpu_model(
+ model, gpus=[0, 2, 4, 6, 8])
+ parallel_model.fit(x, y, epochs=2)
+
+ with self.assertRaises(ValueError):
+ parallel_model = keras.utils.multi_gpu_model(model, gpus=1)
+ parallel_model.fit(x, y, epochs=2)
+
+ with self.assertRaises(ValueError):
+ parallel_model = keras.utils.multi_gpu_model(model, gpus=[0])
+ parallel_model.fit(x, y, epochs=2)
+
+ def test_nested_model_with_tensor_input(self):
+ gpus = 2
+ input_dim = 10
+ shape = (input_dim,)
+ num_samples = 16
+ num_classes = 10
+
+ if not check_if_compatible_devices(gpus=gpus):
+ return
+
+ with self.test_session():
+ input_shape = (num_samples,) + shape
+ x_train = np.random.randint(0, 255, input_shape)
+ y_train = np.random.randint(0, num_classes, (input_shape[0],))
+ keras.backend.set_learning_phase(True)
+
+ y_train = keras.utils.to_categorical(y_train, num_classes)
+
+ x_train = x_train.astype('float32')
+ y_train = y_train.astype('float32')
+
+ dataset = data.Dataset.from_tensor_slices((x_train, y_train))
+ dataset = dataset.repeat()
+ dataset = dataset.batch(4)
+ iterator = dataset.make_one_shot_iterator()
+
+ inputs, targets = iterator.get_next()
+
+ input_tensor = keras.layers.Input(tensor=inputs)
+
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(3,
+ input_shape=(input_dim,)))
+ model.add(keras.layers.Dense(num_classes))
+
+ output = model(input_tensor)
+ outer_model = keras.Model(input_tensor, output)
+ parallel_model = keras.utils.multi_gpu_model(outer_model, gpus=gpus)
+
+ parallel_model.compile(
+ loss='categorical_crossentropy',
+ optimizer=keras.optimizers.RMSprop(lr=0.0001, decay=1e-6),
+ metrics=['accuracy'],
+ target_tensors=[targets])
+ parallel_model.fit(epochs=1, steps_per_epoch=3)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/python/keras/utils/np_utils.py b/tensorflow/python/keras/utils/np_utils.py
new file mode 100644
index 0000000000..9d9c72b162
--- /dev/null
+++ b/tensorflow/python/keras/utils/np_utils.py
@@ -0,0 +1,67 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Numpy-related utilities."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+from tensorflow.python.util.tf_export import tf_export
+
+
+@tf_export('keras.utils.to_categorical')
+def to_categorical(y, num_classes=None):
+ """Converts a class vector (integers) to binary class matrix.
+
+ E.g. for use with categorical_crossentropy.
+
+ Arguments:
+ y: class vector to be converted into a matrix
+ (integers from 0 to num_classes).
+ num_classes: total number of classes.
+
+ Returns:
+ A binary matrix representation of the input.
+ """
+ y = np.array(y, dtype='int')
+ input_shape = y.shape
+ if input_shape and input_shape[-1] == 1 and len(input_shape) > 1:
+ input_shape = tuple(input_shape[:-1])
+ y = y.ravel()
+ if not num_classes:
+ num_classes = np.max(y) + 1
+ n = y.shape[0]
+ categorical = np.zeros((n, num_classes), dtype=np.float32)
+ categorical[np.arange(n), y] = 1
+ output_shape = input_shape + (num_classes,)
+ categorical = np.reshape(categorical, output_shape)
+ return categorical
+
+
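+# A minimal usage sketch: three integer labels over four classes become a
+# (3, 4) one-hot matrix with a single 1.0 per row.
+def _to_categorical_example():
+  """Minimal sketch of one-hot encoding a label vector."""
+  labels = [0, 2, 3]
+  # [[1., 0., 0., 0.],
+  #  [0., 0., 1., 0.],
+  #  [0., 0., 0., 1.]]
+  return to_categorical(labels, num_classes=4)
+
+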
+@tf_export('keras.utils.normalize')
+def normalize(x, axis=-1, order=2):
+ """Normalizes a Numpy array.
+
+ Arguments:
+ x: Numpy array to normalize.
+ axis: axis along which to normalize.
+ order: Normalization order (e.g. 2 for L2 norm).
+
+ Returns:
+ A normalized copy of the array.
+ """
+ l2 = np.atleast_1d(np.linalg.norm(x, order, axis))
+ l2[l2 == 0] = 1
+ return x / np.expand_dims(l2, axis)
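+
+
+# A minimal usage sketch: each row is scaled to unit L2 norm; an all-zero
+# row is left unchanged because its norm is clamped to 1.
+def _normalize_example():
+  """Minimal sketch of row-wise L2 normalization."""
+  x = np.array([[3.0, 4.0], [0.0, 0.0]])
+  return normalize(x, axis=-1, order=2)  # [[0.6, 0.8], [0., 0.]]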
diff --git a/tensorflow/python/keras/utils/np_utils_test.py b/tensorflow/python/keras/utils/np_utils_test.py
new file mode 100644
index 0000000000..d77e76ff3e
--- /dev/null
+++ b/tensorflow/python/keras/utils/np_utils_test.py
@@ -0,0 +1,53 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for np_utils."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python import keras
+from tensorflow.python.platform import test
+
+
+class TestNPUtils(test.TestCase):
+
+ def test_to_categorical(self):
+ num_classes = 5
+ shapes = [(1,), (3,), (4, 3), (5, 4, 3), (3, 1), (3, 2, 1)]
+ expected_shapes = [(1, num_classes),
+ (3, num_classes),
+ (4, 3, num_classes),
+ (5, 4, 3, num_classes),
+ (3, num_classes)]
+ labels = [np.random.randint(0, num_classes, shape) for shape in shapes]
+ one_hots = [
+ keras.utils.to_categorical(label, num_classes) for label in labels]
+ for label, one_hot, expected_shape in zip(labels,
+ one_hots,
+ expected_shapes):
+ # Check shape
+ self.assertEqual(one_hot.shape, expected_shape)
+ # Make sure there is only one 1 in a row
+ self.assertTrue(np.all(one_hot.sum(axis=-1) == 1))
+ # Get original labels back from one hots
+ self.assertTrue(np.all(
+ np.argmax(one_hot, -1).reshape(label.shape) == label))
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/python/keras/utils/tf_utils.py b/tensorflow/python/keras/utils/tf_utils.py
new file mode 100644
index 0000000000..162e5b2cd6
--- /dev/null
+++ b/tensorflow/python/keras/utils/tf_utils.py
@@ -0,0 +1,154 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""TensorFlow-related utilities."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import smart_cond as smart_module
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.util import nest
+
+
+def smart_cond(pred, true_fn=None, false_fn=None, name=None):
+ """Return either `true_fn()` if predicate `pred` is true else `false_fn()`.
+
+ If `pred` is a bool or has a constant value, we return either `true_fn()`
+ or `false_fn()`, otherwise we use `tf.cond` to dynamically route to both.
+
+ Arguments:
+ pred: A scalar determining whether to return the result of `true_fn` or
+ `false_fn`.
+ true_fn: The callable to be performed if pred is true.
+ false_fn: The callable to be performed if pred is false.
+ name: Optional name prefix when using `tf.cond`.
+
+ Returns:
+ Tensors returned by the call to either `true_fn` or `false_fn`.
+
+ Raises:
+ TypeError: If `true_fn` or `false_fn` is not callable.
+ """
+ if isinstance(pred, variables.Variable):
+ return control_flow_ops.cond(
+ pred, true_fn=true_fn, false_fn=false_fn, name=name)
+ return smart_module.smart_cond(
+ pred, true_fn=true_fn, false_fn=false_fn, name=name)
+
+
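+# A minimal usage sketch: with a statically-known Python bool the selected
+# branch runs immediately and no `tf.cond` op is created.
+def _smart_cond_example():
+  """Minimal sketch of `smart_cond` with a static predicate."""
+  return smart_cond(
+      True,
+      true_fn=lambda: 'took_true_branch',
+      false_fn=lambda: 'took_false_branch')
+
+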
+def constant_value(pred):
+ """Return the bool value for `pred`, or None if `pred` had a dynamic value.
+
+ Arguments:
+ pred: A scalar, either a Python bool or a TensorFlow boolean variable
+ or tensor, or the Python integer 1 or 0.
+
+ Returns:
+ True or False if `pred` has a constant boolean value, None otherwise.
+
+ Raises:
+ TypeError: If `pred` is not a Variable, Tensor or bool, or Python
+ integer 1 or 0.
+ """
+ # Allow integer booleans.
+ if isinstance(pred, int):
+ if pred == 1:
+ pred = True
+ elif pred == 0:
+ pred = False
+
+ if isinstance(pred, variables.Variable):
+ return None
+ return smart_module.smart_constant_value(pred)
+
+
+def is_tensor_or_tensor_list(v):
+ v = nest.flatten(v)
+ if v and isinstance(v[0], ops.Tensor):
+ return True
+ else:
+ return False
+
+
+def get_reachable_from_inputs(inputs, targets=None):
+ """Returns the set of tensors/ops reachable from `inputs`.
+
+ Stops if all targets have been found (target is optional).
+
+ Only valid in Symbolic mode, not Eager mode.
+
+ Args:
+ inputs: List of tensors.
+ targets: List of tensors.
+
+ Returns:
+ A set of tensors reachable from the inputs (includes the inputs themselves).
+ """
+ reachable = set(inputs)
+ if targets:
+ targets = set(targets)
+ queue = inputs[:]
+
+ while queue:
+ x = queue.pop()
+ if isinstance(x, ops.Operation):
+ outputs = x.outputs[:] or []
+ outputs += x._control_outputs # pylint: disable=protected-access
+ elif isinstance(x, ops.Tensor):
+ outputs = x.consumers()
+ elif isinstance(x, variables.Variable):
+ outputs = [x.op]
+ else:
+ raise TypeError('Expected Operation, Variable, or Tensor, got ' + str(x))
+
+ for y in outputs:
+ if y not in reachable:
+ reachable.add(y)
+ queue.insert(0, y)
+
+ if targets and targets.issubset(reachable):
+ return reachable
+ return reachable
+
+
+def shape_type_conversion(fn):
+ """Decorator that handles tuple/TensorShape conversion.
+
+ Used in `compute_output_shape` and `build`.
+
+ Arguments:
+ fn: function to wrap.
+
+ Returns:
+ Wrapped function.
+ """
+
+ def wrapper(instance, input_shape):
+ if input_shape is not None:
+ if isinstance(input_shape, list):
+ input_shape = [
+ tuple(tensor_shape.TensorShape(x).as_list()) for x in input_shape]
+ else:
+ input_shape = tuple(tensor_shape.TensorShape(input_shape).as_list())
+ output_shape = fn(instance, input_shape)
+ if output_shape is not None:
+ if isinstance(output_shape, list):
+ return [tensor_shape.TensorShape(x) for x in output_shape]
+ return tensor_shape.TensorShape(output_shape)
+
+ return wrapper
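+
+
+# A minimal usage sketch (hypothetical shape function): the decorator hands
+# the wrapped function a plain tuple and converts its tuple result back
+# into a `TensorShape`.
+def _shape_type_conversion_example():
+  """Minimal sketch of wrapping a shape-inference function."""
+
+  @shape_type_conversion
+  def compute_output_shape(instance, input_shape):
+    del instance  # unused in this sketch
+    return input_shape + (10,)  # plain tuple arithmetic inside the wrapper
+
+  # Returns TensorShape([None, 32, 10]).
+  return compute_output_shape(None, (None, 32))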
diff --git a/tensorflow/python/keras/utils/vis_utils.py b/tensorflow/python/keras/utils/vis_utils.py
new file mode 100644
index 0000000000..8007df4622
--- /dev/null
+++ b/tensorflow/python/keras/utils/vis_utils.py
@@ -0,0 +1,155 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# pylint: disable=protected-access
+# pylint: disable=g-import-not-at-top
+"""Utilities related to model visualization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+from tensorflow.python.util.tf_export import tf_export
+
+
+try:
+ # pydot-ng is a fork of pydot that is better maintained.
+ import pydot_ng as pydot
+except ImportError:
+ # pydotplus is an improved version of pydot
+ try:
+ import pydotplus as pydot
+ except ImportError:
+ # Fall back on pydot if necessary.
+ try:
+ import pydot
+ except ImportError:
+ pydot = None
+
+
+def _check_pydot():
+ try:
+ # Attempt to create an image of a blank graph
+ # to check the pydot/graphviz installation.
+ pydot.Dot.create(pydot.Dot())
+ except Exception:
+ # pydot raises a generic Exception here,
+ # so no specific class can be caught.
+    raise ImportError('Failed to import pydot. You must install pydot'
+                      ' and graphviz for `model_to_dot` to work.')
+
+
+def model_to_dot(model, show_shapes=False, show_layer_names=True, rankdir='TB'):
+ """Convert a Keras model to dot format.
+
+ Arguments:
+ model: A Keras model instance.
+ show_shapes: whether to display shape information.
+ show_layer_names: whether to display layer names.
+ rankdir: `rankdir` argument passed to PyDot,
+ a string specifying the format of the plot:
+ 'TB' creates a vertical plot;
+ 'LR' creates a horizontal plot.
+
+ Returns:
+ A `pydot.Dot` instance representing the Keras model.
+ """
+ from tensorflow.python.keras.layers.wrappers import Wrapper
+ from tensorflow.python.keras.models import Sequential
+
+ _check_pydot()
+ dot = pydot.Dot()
+ dot.set('rankdir', rankdir)
+ dot.set('concentrate', True)
+ dot.set_node_defaults(shape='record')
+
+ if isinstance(model, Sequential):
+ if not model.built:
+ model.build()
+ model = model.model
+ layers = model.layers
+
+ # Create graph nodes.
+ for layer in layers:
+ layer_id = str(id(layer))
+
+ # Append a wrapped layer's label to node's label, if it exists.
+ layer_name = layer.name
+ class_name = layer.__class__.__name__
+ if isinstance(layer, Wrapper):
+ layer_name = '{}({})'.format(layer_name, layer.layer.name)
+ child_class_name = layer.layer.__class__.__name__
+ class_name = '{}({})'.format(class_name, child_class_name)
+
+ # Create node's label.
+ if show_layer_names:
+ label = '{}: {}'.format(layer_name, class_name)
+ else:
+ label = class_name
+
+ # Rebuild the label as a table including input/output shapes.
+ if show_shapes:
+ try:
+ outputlabels = str(layer.output_shape)
+ except AttributeError:
+ outputlabels = 'multiple'
+ if hasattr(layer, 'input_shape'):
+ inputlabels = str(layer.input_shape)
+ elif hasattr(layer, 'input_shapes'):
+ inputlabels = ', '.join([str(ishape) for ishape in layer.input_shapes])
+ else:
+ inputlabels = 'multiple'
+ label = '%s\n|{input:|output:}|{{%s}|{%s}}' % (label, inputlabels,
+ outputlabels)
+ node = pydot.Node(layer_id, label=label)
+ dot.add_node(node)
+
+ # Connect nodes with edges.
+ for layer in layers:
+ layer_id = str(id(layer))
+ for i, node in enumerate(layer._inbound_nodes):
+ node_key = layer.name + '_ib-' + str(i)
+ if node_key in model._network_nodes: # pylint: disable=protected-access
+ for inbound_layer in node.inbound_layers:
+ inbound_layer_id = str(id(inbound_layer))
+ layer_id = str(id(layer))
+ dot.add_edge(pydot.Edge(inbound_layer_id, layer_id))
+ return dot
+
+
+@tf_export('keras.utils.plot_model')
+def plot_model(model,
+ to_file='model.png',
+ show_shapes=False,
+ show_layer_names=True,
+ rankdir='TB'):
+ """Converts a Keras model to dot format and save to a file.
+
+ Arguments:
+ model: A Keras model instance
+ to_file: File name of the plot image.
+ show_shapes: whether to display shape information.
+ show_layer_names: whether to display layer names.
+ rankdir: `rankdir` argument passed to PyDot,
+ a string specifying the format of the plot:
+ 'TB' creates a vertical plot;
+ 'LR' creates a horizontal plot.
+ """
+ dot = model_to_dot(model, show_shapes, show_layer_names, rankdir)
+ _, extension = os.path.splitext(to_file)
+ if not extension:
+ extension = 'png'
+ else:
+ extension = extension[1:]
+ dot.write(to_file, format=extension)
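+
+
+# A minimal usage sketch (hypothetical helper; the model argument and the
+# output file name are assumptions; pydot and graphviz must be installed).
+def _plot_model_example(model):
+  """Minimal sketch of writing a model graph with shapes to disk."""
+  plot_model(model, to_file='my_model.png', show_shapes=True)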