author    | 2018-05-17 21:36:39 -0700
committer | 2018-05-17 21:40:10 -0700
commit    | 609b2ce3fe8ebecf4031670b8c2186468369b0ba (patch)
tree      | 59d5eb7308ffc67a4602f9b028cdd45450f56777 /tensorflow/python/keras/utils
parent    | aca0458707fa63626c78acfeae2ade9ee78c54d1 (diff)
Move Keras code out of _impl folder and remove API files.
PiperOrigin-RevId: 197097430
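The practical effect of the move is visible in the `__init__.py` hunk further down: every utility that was previously re-exported through the private `_impl` package is now imported from the public `tensorflow.python.keras.utils` path directly. A minimal before/after sketch using one representative symbol (the full list of affected imports is in the diff itself):

```python
# Before this commit: symbols were re-exported via the private _impl package.
from tensorflow.python.keras._impl.keras.utils.data_utils import get_file

# After this commit: the same symbol is imported from the public package path.
from tensorflow.python.keras.utils.data_utils import get_file
```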
Diffstat (limited to 'tensorflow/python/keras/utils')
-rw-r--r-- | tensorflow/python/keras/utils/__init__.py | 32
-rw-r--r-- | tensorflow/python/keras/utils/conv_utils.py | 201
-rw-r--r-- | tensorflow/python/keras/utils/data_utils.py | 824
-rw-r--r-- | tensorflow/python/keras/utils/data_utils_test.py | 311
-rw-r--r-- | tensorflow/python/keras/utils/generic_utils.py | 561
-rw-r--r-- | tensorflow/python/keras/utils/generic_utils_test.py | 75
-rw-r--r-- | tensorflow/python/keras/utils/io_utils.py | 171
-rw-r--r-- | tensorflow/python/keras/utils/io_utils_test.py | 100
-rw-r--r-- | tensorflow/python/keras/utils/layer_utils.py | 266
-rw-r--r-- | tensorflow/python/keras/utils/multi_gpu_utils.py | 252
-rw-r--r-- | tensorflow/python/keras/utils/multi_gpu_utils_test.py | 185
-rw-r--r-- | tensorflow/python/keras/utils/np_utils.py | 67
-rw-r--r-- | tensorflow/python/keras/utils/np_utils_test.py | 53
-rw-r--r-- | tensorflow/python/keras/utils/tf_utils.py | 154
-rw-r--r-- | tensorflow/python/keras/utils/vis_utils.py | 155
15 files changed, 3391 insertions(+), 16 deletions(-)
diff --git a/tensorflow/python/keras/utils/__init__.py b/tensorflow/python/keras/utils/__init__.py index 2f74cf031d..7b5eecc153 100644 --- a/tensorflow/python/keras/utils/__init__.py +++ b/tensorflow/python/keras/utils/__init__.py @@ -18,22 +18,22 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.utils.data_utils import GeneratorEnqueuer -from tensorflow.python.keras._impl.keras.utils.data_utils import get_file -from tensorflow.python.keras._impl.keras.utils.data_utils import Sequence -from tensorflow.python.keras._impl.keras.utils.data_utils import SequenceEnqueuer -from tensorflow.python.keras._impl.keras.utils.generic_utils import custom_object_scope -from tensorflow.python.keras._impl.keras.utils.generic_utils import CustomObjectScope -from tensorflow.python.keras._impl.keras.utils.generic_utils import deserialize_keras_object -from tensorflow.python.keras._impl.keras.utils.generic_utils import get_custom_objects -from tensorflow.python.keras._impl.keras.utils.generic_utils import Progbar -from tensorflow.python.keras._impl.keras.utils.generic_utils import serialize_keras_object -from tensorflow.python.keras._impl.keras.utils.io_utils import HDF5Matrix -from tensorflow.python.keras._impl.keras.utils.layer_utils import convert_all_kernels_in_model -from tensorflow.python.keras._impl.keras.utils.multi_gpu_utils import multi_gpu_model -from tensorflow.python.keras._impl.keras.utils.np_utils import normalize -from tensorflow.python.keras._impl.keras.utils.np_utils import to_categorical -from tensorflow.python.keras._impl.keras.utils.vis_utils import plot_model +from tensorflow.python.keras.utils.data_utils import GeneratorEnqueuer +from tensorflow.python.keras.utils.data_utils import get_file +from tensorflow.python.keras.utils.data_utils import Sequence +from tensorflow.python.keras.utils.data_utils import SequenceEnqueuer +from tensorflow.python.keras.utils.generic_utils import custom_object_scope +from tensorflow.python.keras.utils.generic_utils import CustomObjectScope +from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object +from tensorflow.python.keras.utils.generic_utils import get_custom_objects +from tensorflow.python.keras.utils.generic_utils import Progbar +from tensorflow.python.keras.utils.generic_utils import serialize_keras_object +from tensorflow.python.keras.utils.io_utils import HDF5Matrix +from tensorflow.python.keras.utils.layer_utils import convert_all_kernels_in_model +from tensorflow.python.keras.utils.multi_gpu_utils import multi_gpu_model +from tensorflow.python.keras.utils.np_utils import normalize +from tensorflow.python.keras.utils.np_utils import to_categorical +from tensorflow.python.keras.utils.vis_utils import plot_model del absolute_import del division diff --git a/tensorflow/python/keras/utils/conv_utils.py b/tensorflow/python/keras/utils/conv_utils.py new file mode 100644 index 0000000000..5419e7ae05 --- /dev/null +++ b/tensorflow/python/keras/utils/conv_utils.py @@ -0,0 +1,201 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities used by convolution layers. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from six.moves import range # pylint: disable=redefined-builtin + +from tensorflow.python.keras import backend + + +def convert_data_format(data_format, ndim): + if data_format == 'channels_last': + if ndim == 3: + return 'NWC' + elif ndim == 4: + return 'NHWC' + elif ndim == 5: + return 'NDHWC' + else: + raise ValueError('Input rank not supported:', ndim) + elif data_format == 'channels_first': + if ndim == 3: + return 'NCW' + elif ndim == 4: + return 'NCHW' + elif ndim == 5: + return 'NCDHW' + else: + raise ValueError('Input rank not supported:', ndim) + else: + raise ValueError('Invalid data_format:', data_format) + + +def normalize_tuple(value, n, name): + """Transforms a single integer or iterable of integers into an integer tuple. + + Arguments: + value: The value to validate and convert. Could an int, or any iterable + of ints. + n: The size of the tuple to be returned. + name: The name of the argument being validated, e.g. "strides" or + "kernel_size". This is only used to format error messages. + + Returns: + A tuple of n integers. + + Raises: + ValueError: If something else than an int/long or iterable thereof was + passed. + """ + if isinstance(value, int): + return (value,) * n + else: + try: + value_tuple = tuple(value) + except TypeError: + raise ValueError('The `' + name + '` argument must be a tuple of ' + + str(n) + ' integers. Received: ' + str(value)) + if len(value_tuple) != n: + raise ValueError('The `' + name + '` argument must be a tuple of ' + + str(n) + ' integers. Received: ' + str(value)) + for single_value in value_tuple: + try: + int(single_value) + except (ValueError, TypeError): + raise ValueError('The `' + name + '` argument must be a tuple of ' + + str(n) + ' integers. Received: ' + str(value) + ' ' + 'including element ' + str(single_value) + ' of type' + + ' ' + str(type(single_value))) + return value_tuple + + +def conv_output_length(input_length, filter_size, padding, stride, dilation=1): + """Determines output length of a convolution given input length. + + Arguments: + input_length: integer. + filter_size: integer. + padding: one of "same", "valid", "full". + stride: integer. + dilation: dilation rate, integer. + + Returns: + The output length (integer). + """ + if input_length is None: + return None + assert padding in {'same', 'valid', 'full'} + dilated_filter_size = filter_size + (filter_size - 1) * (dilation - 1) + if padding == 'same': + output_length = input_length + elif padding == 'valid': + output_length = input_length - dilated_filter_size + 1 + elif padding == 'full': + output_length = input_length + dilated_filter_size - 1 + return (output_length + stride - 1) // stride + + +def conv_input_length(output_length, filter_size, padding, stride): + """Determines input length of a convolution given output length. + + Arguments: + output_length: integer. + filter_size: integer. 
+ padding: one of "same", "valid", "full". + stride: integer. + + Returns: + The input length (integer). + """ + if output_length is None: + return None + assert padding in {'same', 'valid', 'full'} + if padding == 'same': + pad = filter_size // 2 + elif padding == 'valid': + pad = 0 + elif padding == 'full': + pad = filter_size - 1 + return (output_length - 1) * stride - 2 * pad + filter_size + + +def deconv_output_length(input_length, filter_size, padding, stride): + """Determines output length of a transposed convolution given input length. + + Arguments: + input_length: integer. + filter_size: integer. + padding: one of "same", "valid", "full". + stride: integer. + + Returns: + The output length (integer). + """ + if input_length is None: + return None + input_length *= stride + if padding == 'valid': + input_length += max(filter_size - stride, 0) + elif padding == 'full': + input_length -= (stride + filter_size - 2) + return input_length + + +def normalize_data_format(value): + if value is None: + value = backend.image_data_format() + data_format = value.lower() + if data_format not in {'channels_first', 'channels_last'}: + raise ValueError('The `data_format` argument must be one of ' + '"channels_first", "channels_last". Received: ' + + str(value)) + return data_format + + +def normalize_padding(value): + padding = value.lower() + if padding not in {'valid', 'same', 'causal'}: + raise ValueError('The `padding` argument must be one of ' + '"valid", "same" (or "causal", only for `Conv1D). ' + 'Received: ' + str(padding)) + return padding + + +def convert_kernel(kernel): + """Converts a Numpy kernel matrix from Theano format to TensorFlow format. + + Also works reciprocally, since the transformation is its own inverse. + + Arguments: + kernel: Numpy array (3D, 4D or 5D). + + Returns: + The converted kernel. + + Raises: + ValueError: in case of invalid kernel shape or invalid data_format. + """ + kernel = np.asarray(kernel) + if not 3 <= kernel.ndim <= 5: + raise ValueError('Invalid kernel shape:', kernel.shape) + slices = [slice(None, None, -1) for _ in range(kernel.ndim)] + no_flip = (slice(None, None), slice(None, None)) + slices[-2:] = no_flip + return np.copy(kernel[slices]) diff --git a/tensorflow/python/keras/utils/data_utils.py b/tensorflow/python/keras/utils/data_utils.py new file mode 100644 index 0000000000..a1f89d9d43 --- /dev/null +++ b/tensorflow/python/keras/utils/data_utils.py @@ -0,0 +1,824 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +# pylint: disable=g-import-not-at-top +"""Utilities for file download and caching.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from abc import abstractmethod +from contextlib import closing +import hashlib +import multiprocessing +from multiprocessing.pool import ThreadPool +import os +import random +import shutil +import sys +import tarfile +import threading +import time +import traceback +import zipfile + +import numpy as np +import six +from six.moves.urllib.error import HTTPError +from six.moves.urllib.error import URLError +from six.moves.urllib.request import urlopen + +from tensorflow.python.keras.utils.generic_utils import Progbar +from tensorflow.python.util.tf_export import tf_export + + +try: + import queue +except ImportError: + import Queue as queue + + +if sys.version_info[0] == 2: + + def urlretrieve(url, filename, reporthook=None, data=None): + """Replacement for `urlretrive` for Python 2. + + Under Python 2, `urlretrieve` relies on `FancyURLopener` from legacy + `urllib` module, known to have issues with proxy management. + + Arguments: + url: url to retrieve. + filename: where to store the retrieved data locally. + reporthook: a hook function that will be called once + on establishment of the network connection and once + after each block read thereafter. + The hook will be passed three arguments; + a count of blocks transferred so far, + a block size in bytes, and the total size of the file. + data: `data` argument passed to `urlopen`. + """ + + def chunk_read(response, chunk_size=8192, reporthook=None): + content_type = response.info().get('Content-Length') + total_size = -1 + if content_type is not None: + total_size = int(content_type.strip()) + count = 0 + while True: + chunk = response.read(chunk_size) + count += 1 + if reporthook is not None: + reporthook(count, chunk_size, total_size) + if chunk: + yield chunk + else: + break + + response = urlopen(url, data) + with open(filename, 'wb') as fd: + for chunk in chunk_read(response, reporthook=reporthook): + fd.write(chunk) +else: + from six.moves.urllib.request import urlretrieve + + +def _extract_archive(file_path, path='.', archive_format='auto'): + """Extracts an archive if it matches tar, tar.gz, tar.bz, or zip formats. + + Arguments: + file_path: path to the archive file + path: path to extract the archive file + archive_format: Archive format to try for extracting the file. + Options are 'auto', 'tar', 'zip', and None. + 'tar' includes tar, tar.gz, and tar.bz files. + The default 'auto' is ['tar', 'zip']. + None or an empty list will return no matches found. + + Returns: + True if a match was found and an archive extraction was completed, + False otherwise. 
+ """ + if archive_format is None: + return False + if archive_format is 'auto': + archive_format = ['tar', 'zip'] + if isinstance(archive_format, six.string_types): + archive_format = [archive_format] + + for archive_type in archive_format: + if archive_type is 'tar': + open_fn = tarfile.open + is_match_fn = tarfile.is_tarfile + if archive_type is 'zip': + open_fn = zipfile.ZipFile + is_match_fn = zipfile.is_zipfile + + if is_match_fn(file_path): + with open_fn(file_path) as archive: + try: + archive.extractall(path) + except (tarfile.TarError, RuntimeError, KeyboardInterrupt): + if os.path.exists(path): + if os.path.isfile(path): + os.remove(path) + else: + shutil.rmtree(path) + raise + return True + return False + + +@tf_export('keras.utils.get_file') +def get_file(fname, + origin, + untar=False, + md5_hash=None, + file_hash=None, + cache_subdir='datasets', + hash_algorithm='auto', + extract=False, + archive_format='auto', + cache_dir=None): + """Downloads a file from a URL if it not already in the cache. + + By default the file at the url `origin` is downloaded to the + cache_dir `~/.keras`, placed in the cache_subdir `datasets`, + and given the filename `fname`. The final location of a file + `example.txt` would therefore be `~/.keras/datasets/example.txt`. + + Files in tar, tar.gz, tar.bz, and zip formats can also be extracted. + Passing a hash will verify the file after download. The command line + programs `shasum` and `sha256sum` can compute the hash. + + Arguments: + fname: Name of the file. If an absolute path `/path/to/file.txt` is + specified the file will be saved at that location. + origin: Original URL of the file. + untar: Deprecated in favor of 'extract'. + boolean, whether the file should be decompressed + md5_hash: Deprecated in favor of 'file_hash'. + md5 hash of the file for verification + file_hash: The expected hash string of the file after download. + The sha256 and md5 hash algorithms are both supported. + cache_subdir: Subdirectory under the Keras cache dir where the file is + saved. If an absolute path `/path/to/folder` is + specified the file will be saved at that location. + hash_algorithm: Select the hash algorithm to verify the file. + options are 'md5', 'sha256', and 'auto'. + The default 'auto' detects the hash algorithm in use. + extract: True tries extracting the file as an Archive, like tar or zip. + archive_format: Archive format to try for extracting the file. + Options are 'auto', 'tar', 'zip', and None. + 'tar' includes tar, tar.gz, and tar.bz files. + The default 'auto' is ['tar', 'zip']. + None or an empty list will return no matches found. + cache_dir: Location to store cached files, when None it + defaults to the [Keras + Directory](/faq/#where-is-the-keras-configuration-filed-stored). + + Returns: + Path to the downloaded file + """ + if cache_dir is None: + cache_dir = os.path.join(os.path.expanduser('~'), '.keras') + if md5_hash is not None and file_hash is None: + file_hash = md5_hash + hash_algorithm = 'md5' + datadir_base = os.path.expanduser(cache_dir) + if not os.access(datadir_base, os.W_OK): + datadir_base = os.path.join('/tmp', '.keras') + datadir = os.path.join(datadir_base, cache_subdir) + if not os.path.exists(datadir): + os.makedirs(datadir) + + if untar: + untar_fpath = os.path.join(datadir, fname) + fpath = untar_fpath + '.tar.gz' + else: + fpath = os.path.join(datadir, fname) + + download = False + if os.path.exists(fpath): + # File found; verify integrity if a hash was provided. 
+ if file_hash is not None: + if not validate_file(fpath, file_hash, algorithm=hash_algorithm): + print('A local file was found, but it seems to be ' + 'incomplete or outdated because the ' + hash_algorithm + + ' file hash does not match the original value of ' + file_hash + + ' so we will re-download the data.') + download = True + else: + download = True + + if download: + print('Downloading data from', origin) + + class ProgressTracker(object): + # Maintain progbar for the lifetime of download. + # This design was chosen for Python 2.7 compatibility. + progbar = None + + def dl_progress(count, block_size, total_size): + if ProgressTracker.progbar is None: + if total_size is -1: + total_size = None + ProgressTracker.progbar = Progbar(total_size) + else: + ProgressTracker.progbar.update(count * block_size) + + error_msg = 'URL fetch failure on {}: {} -- {}' + try: + try: + urlretrieve(origin, fpath, dl_progress) + except URLError as e: + raise Exception(error_msg.format(origin, e.errno, e.reason)) + except HTTPError as e: + raise Exception(error_msg.format(origin, e.code, e.msg)) + except (Exception, KeyboardInterrupt) as e: + if os.path.exists(fpath): + os.remove(fpath) + raise + ProgressTracker.progbar = None + + if untar: + if not os.path.exists(untar_fpath): + _extract_archive(fpath, datadir, archive_format='tar') + return untar_fpath + + if extract: + _extract_archive(fpath, datadir, archive_format) + + return fpath + + +def _hash_file(fpath, algorithm='sha256', chunk_size=65535): + """Calculates a file sha256 or md5 hash. + + Example: + + ```python + >>> from keras.data_utils import _hash_file + >>> _hash_file('/path/to/file.zip') + 'e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855' + ``` + + Arguments: + fpath: path to the file being validated + algorithm: hash algorithm, one of 'auto', 'sha256', or 'md5'. + The default 'auto' detects the hash algorithm in use. + chunk_size: Bytes to read at a time, important for large files. + + Returns: + The file hash + """ + if (algorithm is 'sha256') or (algorithm is 'auto' and len(hash) is 64): + hasher = hashlib.sha256() + else: + hasher = hashlib.md5() + + with open(fpath, 'rb') as fpath_file: + for chunk in iter(lambda: fpath_file.read(chunk_size), b''): + hasher.update(chunk) + + return hasher.hexdigest() + + +def validate_file(fpath, file_hash, algorithm='auto', chunk_size=65535): + """Validates a file against a sha256 or md5 hash. + + Arguments: + fpath: path to the file being validated + file_hash: The expected hash string of the file. + The sha256 and md5 hash algorithms are both supported. + algorithm: Hash algorithm, one of 'auto', 'sha256', or 'md5'. + The default 'auto' detects the hash algorithm in use. + chunk_size: Bytes to read at a time, important for large files. + + Returns: + Whether the file is valid + """ + if ((algorithm is 'sha256') or + (algorithm is 'auto' and len(file_hash) is 64)): + hasher = 'sha256' + else: + hasher = 'md5' + + if str(_hash_file(fpath, hasher, chunk_size)) == str(file_hash): + return True + else: + return False + + +@tf_export('keras.utils.Sequence') +class Sequence(object): + """Base object for fitting to a sequence of data, such as a dataset. + + Every `Sequence` must implements the `__getitem__` and the `__len__` methods. + If you want to modify your dataset between epochs you may implement + `on_epoch_end`. + The method `__getitem__` should return a complete batch. + + # Notes + + `Sequence` are a safer way to do multiprocessing. 
This structure guarantees + that the network will only train once + on each sample per epoch which is not the case with generators. + + Examples: + + ```python + from skimage.io import imread + from skimage.transform import resize + import numpy as np + import math + + # Here, `x_set` is list of path to the images + # and `y_set` are the associated classes. + + class CIFAR10Sequence(Sequence): + + def __init__(self, x_set, y_set, batch_size): + self.x, self.y = x_set, y_set + self.batch_size = batch_size + + def __len__(self): + return math.ceil(len(self.x) / self.batch_size) + + def __getitem__(self, idx): + batch_x = self.x[idx * self.batch_size:(idx + 1) * + self.batch_size] + batch_y = self.y[idx * self.batch_size:(idx + 1) * + self.batch_size] + + return np.array([ + resize(imread(file_name), (200, 200)) + for file_name in batch_x]), np.array(batch_y) + ``` + """ + + @abstractmethod + def __getitem__(self, index): + """Gets batch at position `index`. + + Arguments: + index: position of the batch in the Sequence. + + Returns: + A batch + """ + raise NotImplementedError + + @abstractmethod + def __len__(self): + """Number of batch in the Sequence. + + Returns: + The number of batches in the Sequence. + """ + raise NotImplementedError + + def on_epoch_end(self): + """Method called at the end of every epoch. + """ + pass + + def __iter__(self): + """Creates an infinite generator that iterate over the Sequence. + + Yields: + Sequence items. + """ + while True: + for item in (self[i] for i in range(len(self))): + yield item + + +# Global variables to be shared across processes +_SHARED_SEQUENCES = {} +# We use a Value to provide unique id to different processes. +_SEQUENCE_COUNTER = None + + +def init_pool(seqs): + global _SHARED_SEQUENCES + _SHARED_SEQUENCES = seqs + + +def get_index(uid, i): + """Get the value from the Sequence `uid` at index `i`. + + To allow multiple Sequences to be used at the same time, we use `uid` to + get a specific one. A single Sequence would cause the validation to + overwrite the training Sequence. + + Arguments: + uid: int, Sequence identifier + i: index + + Returns: + The value at index `i`. + """ + return _SHARED_SEQUENCES[uid][i] + + +@tf_export('keras.utils.SequenceEnqueuer') +class SequenceEnqueuer(object): + """Base class to enqueue inputs. + + The task of an Enqueuer is to use parallelism to speed up preprocessing. + This is done with processes or threads. + + Examples: + + ```python + enqueuer = SequenceEnqueuer(...) + enqueuer.start() + datas = enqueuer.get() + for data in datas: + # Use the inputs; training, evaluating, predicting. + # ... stop sometime. + enqueuer.close() + ``` + + The `enqueuer.get()` should be an infinite stream of datas. + + """ + + @abstractmethod + def is_running(self): + raise NotImplementedError + + @abstractmethod + def start(self, workers=1, max_queue_size=10): + """Starts the handler's workers. + + Arguments: + workers: number of worker threads + max_queue_size: queue size + (when full, threads could block on `put()`). + """ + raise NotImplementedError + + @abstractmethod + def stop(self, timeout=None): + """Stop running threads and wait for them to exit, if necessary. + + Should be called by the same thread which called start(). + + Arguments: + timeout: maximum time to wait on thread.join() + """ + raise NotImplementedError + + @abstractmethod + def get(self): + """Creates a generator to extract data from the queue. + + Skip the data if it is `None`. 
+ + Returns: + Generator yielding tuples `(inputs, targets)` + or `(inputs, targets, sample_weights)`. + """ + raise NotImplementedError + + +class OrderedEnqueuer(SequenceEnqueuer): + """Builds a Enqueuer from a Sequence. + + Used in `fit_generator`, `evaluate_generator`, `predict_generator`. + + Arguments: + sequence: A `keras.utils.data_utils.Sequence` object. + use_multiprocessing: use multiprocessing if True, otherwise threading + shuffle: whether to shuffle the data at the beginning of each epoch + """ + + def __init__(self, sequence, use_multiprocessing=False, shuffle=False): + self.sequence = sequence + self.use_multiprocessing = use_multiprocessing + + global _SEQUENCE_COUNTER + if _SEQUENCE_COUNTER is None: + try: + _SEQUENCE_COUNTER = multiprocessing.Value('i', 0) + except OSError: + # In this case the OS does not allow us to use + # multiprocessing. We resort to an int + # for enqueuer indexing. + _SEQUENCE_COUNTER = 0 + + if isinstance(_SEQUENCE_COUNTER, int): + self.uid = _SEQUENCE_COUNTER + _SEQUENCE_COUNTER += 1 + else: + # Doing Multiprocessing.Value += x is not process-safe. + with _SEQUENCE_COUNTER.get_lock(): + self.uid = _SEQUENCE_COUNTER.value + _SEQUENCE_COUNTER.value += 1 + + self.shuffle = shuffle + self.workers = 0 + self.executor_fn = None + self.queue = None + self.run_thread = None + self.stop_signal = None + + def is_running(self): + return self.stop_signal is not None and not self.stop_signal.is_set() + + def start(self, workers=1, max_queue_size=10): + """Start the handler's workers. + + Arguments: + workers: number of worker threads + max_queue_size: queue size + (when full, workers could block on `put()`) + """ + if self.use_multiprocessing: + self.executor_fn = lambda seqs: multiprocessing.Pool( # pylint: disable=g-long-lambda + workers, initializer=init_pool, initargs=(seqs,)) + else: + # We do not need the init since it's threads. + self.executor_fn = lambda _: ThreadPool(workers) + self.workers = workers + self.queue = queue.Queue(max_queue_size) + self.stop_signal = threading.Event() + self.run_thread = threading.Thread(target=self._run) + self.run_thread.daemon = True + self.run_thread.start() + + def _wait_queue(self): + """Wait for the queue to be empty.""" + while True: + time.sleep(0.1) + if self.queue.unfinished_tasks == 0 or self.stop_signal.is_set(): + return + + def _run(self): + """Submits request to the executor and queue the `Future` objects.""" + sequence = list(range(len(self.sequence))) + self._send_sequence() # Share the initial sequence + while True: + if self.shuffle: + random.shuffle(sequence) + + with closing(self.executor_fn(_SHARED_SEQUENCES)) as executor: + for i in sequence: + if self.stop_signal.is_set(): + return + self.queue.put( + executor.apply_async(get_index, (self.uid, i)), block=True) + + # Done with the current epoch, waiting for the final batches + self._wait_queue() + + if self.stop_signal.is_set(): + # We're done + return + + # Call the internal on epoch end. + self.sequence.on_epoch_end() + self._send_sequence() # Update the pool + + def get(self): + """Creates a generator to extract data from the queue. + + Skip the data if it is `None`. + + Yields: + The next element in the queue, i.e. a tuple + `(inputs, targets)` or + `(inputs, targets, sample_weights)`. 
+ """ + try: + while self.is_running(): + inputs = self.queue.get(block=True).get() + self.queue.task_done() + if inputs is not None: + yield inputs + except Exception as e: # pylint: disable=broad-except + self.stop() + six.raise_from(StopIteration(e), e) + + def _send_sequence(self): + """Send current Sequence to all workers.""" + # For new processes that may spawn + _SHARED_SEQUENCES[self.uid] = self.sequence + + def stop(self, timeout=None): + """Stops running threads and wait for them to exit, if necessary. + + Should be called by the same thread which called `start()`. + + Arguments: + timeout: maximum time to wait on `thread.join()` + """ + self.stop_signal.set() + with self.queue.mutex: + self.queue.queue.clear() + self.queue.unfinished_tasks = 0 + self.queue.not_full.notify() + self.run_thread.join(timeout) + _SHARED_SEQUENCES[self.uid] = None + + +@tf_export('keras.utils.GeneratorEnqueuer') +class GeneratorEnqueuer(SequenceEnqueuer): + """Builds a queue out of a data generator. + + The provided generator can be finite in which case the class will throw + a `StopIteration` exception. + + Used in `fit_generator`, `evaluate_generator`, `predict_generator`. + + Arguments: + generator: a generator function which yields data + use_multiprocessing: use multiprocessing if True, otherwise threading + wait_time: time to sleep in-between calls to `put()` + random_seed: Initial seed for workers, + will be incremented by one for each worker. + """ + + def __init__(self, + generator, + use_multiprocessing=False, + wait_time=0.05, + seed=None): + self.wait_time = wait_time + self._generator = generator + if os.name is 'nt' and use_multiprocessing is True: + # On Windows, avoid **SYSTEMATIC** error in `multiprocessing`: + # `TypeError: can't pickle generator objects` + # => Suggest multithreading instead of multiprocessing on Windows + raise ValueError('Using a generator with `use_multiprocessing=True`' + ' is not supported on Windows (no marshalling of' + ' generators across process boundaries). Instead,' + ' use single thread/process or multithreading.') + else: + self._use_multiprocessing = use_multiprocessing + self._threads = [] + self._stop_event = None + self._manager = None + self.queue = None + self.seed = seed + + def _data_generator_task(self): + if self._use_multiprocessing is False: + while not self._stop_event.is_set(): + with self.genlock: + try: + if (self.queue is not None and + self.queue.qsize() < self.max_queue_size): + # On all OSes, avoid **SYSTEMATIC** error + # in multithreading mode: + # `ValueError: generator already executing` + # => Serialize calls to + # infinite iterator/generator's next() function + generator_output = next(self._generator) + self.queue.put((True, generator_output)) + else: + time.sleep(self.wait_time) + except StopIteration: + break + except Exception as e: # pylint: disable=broad-except + # Can't pickle tracebacks. + # As a compromise, print the traceback and pickle None instead. + if not hasattr(e, '__traceback__'): + setattr(e, '__traceback__', sys.exc_info()[2]) + self.queue.put((False, e)) + self._stop_event.set() + break + else: + while not self._stop_event.is_set(): + try: + if (self.queue is not None and + self.queue.qsize() < self.max_queue_size): + generator_output = next(self._generator) + self.queue.put((True, generator_output)) + else: + time.sleep(self.wait_time) + except StopIteration: + break + except Exception as e: # pylint: disable=broad-except + # Can't pickle tracebacks. 
+ # As a compromise, print the traceback and pickle None instead. + traceback.print_exc() + setattr(e, '__traceback__', None) + self.queue.put((False, e)) + self._stop_event.set() + break + + def start(self, workers=1, max_queue_size=10): + """Kicks off threads which add data from the generator into the queue. + + Arguments: + workers: number of worker threads + max_queue_size: queue size + (when full, threads could block on `put()`) + """ + try: + self.max_queue_size = max_queue_size + if self._use_multiprocessing: + self._manager = multiprocessing.Manager() + self.queue = self._manager.Queue(maxsize=max_queue_size) + self._stop_event = multiprocessing.Event() + else: + # On all OSes, avoid **SYSTEMATIC** error in multithreading mode: + # `ValueError: generator already executing` + # => Serialize calls to infinite iterator/generator's next() function + self.genlock = threading.Lock() + self.queue = queue.Queue(maxsize=max_queue_size) + self._stop_event = threading.Event() + + for _ in range(workers): + if self._use_multiprocessing: + # Reset random seed else all children processes + # share the same seed + np.random.seed(self.seed) + thread = multiprocessing.Process(target=self._data_generator_task) + thread.daemon = True + if self.seed is not None: + self.seed += 1 + else: + thread = threading.Thread(target=self._data_generator_task) + self._threads.append(thread) + thread.start() + except: + self.stop() + raise + + def is_running(self): + return self._stop_event is not None and not self._stop_event.is_set() + + def stop(self, timeout=None): + """Stops running threads and wait for them to exit, if necessary. + + Should be called by the same thread which called `start()`. + + Arguments: + timeout: maximum time to wait on `thread.join()`. + """ + if self.is_running(): + self._stop_event.set() + + for thread in self._threads: + if self._use_multiprocessing: + if thread.is_alive(): + thread.terminate() + else: + # The thread.is_alive() test is subject to a race condition: + # the thread could terminate right after the test and before the + # join, rendering this test meaningless -> Call thread.join() + # always, which is ok no matter what the status of the thread. + thread.join(timeout) + + if self._manager: + self._manager.shutdown() + + self._threads = [] + self._stop_event = None + self.queue = None + + def get(self): + """Creates a generator to extract data from the queue. + + Skip the data if it is `None`. + + Yields: + The next element in the queue, i.e. a tuple + `(inputs, targets)` or + `(inputs, targets, sample_weights)`. + """ + while self.is_running(): + if not self.queue.empty(): + success, value = self.queue.get() + # Rethrow any exceptions found in the queue + if not success: + six.reraise(value.__class__, value, value.__traceback__) + # Yield regular values + if value is not None: + yield value + else: + all_finished = all([not thread.is_alive() for thread in self._threads]) + if all_finished and self.queue.empty(): + raise StopIteration() + else: + time.sleep(self.wait_time) + + # Make sure to rethrow the first exception in the queue, if any + while not self.queue.empty(): + success, value = self.queue.get() + if not success: + six.reraise(value.__class__, value, value.__traceback__) diff --git a/tensorflow/python/keras/utils/data_utils_test.py b/tensorflow/python/keras/utils/data_utils_test.py new file mode 100644 index 0000000000..395df7e0e7 --- /dev/null +++ b/tensorflow/python/keras/utils/data_utils_test.py @@ -0,0 +1,311 @@ +# Copyright 2016 The TensorFlow Authors. 
All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for data_utils.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from itertools import cycle +import os +import tarfile +import threading +import unittest +import zipfile + +import numpy as np +from six.moves.urllib.parse import urljoin +from six.moves.urllib.request import pathname2url + +from tensorflow.python import keras +from tensorflow.python.platform import test + + +class TestGetFileAndValidateIt(test.TestCase): + + def test_get_file_and_validate_it(self): + """Tests get_file from a url, plus extraction and validation. + """ + dest_dir = self.get_temp_dir() + orig_dir = self.get_temp_dir() + + text_file_path = os.path.join(orig_dir, 'test.txt') + zip_file_path = os.path.join(orig_dir, 'test.zip') + tar_file_path = os.path.join(orig_dir, 'test.tar.gz') + + with open(text_file_path, 'w') as text_file: + text_file.write('Float like a butterfly, sting like a bee.') + + with tarfile.open(tar_file_path, 'w:gz') as tar_file: + tar_file.add(text_file_path) + + with zipfile.ZipFile(zip_file_path, 'w') as zip_file: + zip_file.write(text_file_path) + + origin = urljoin('file://', pathname2url(os.path.abspath(tar_file_path))) + + path = keras.utils.data_utils.get_file('test.txt', origin, + untar=True, cache_subdir=dest_dir) + filepath = path + '.tar.gz' + hashval_sha256 = keras.utils.data_utils._hash_file(filepath) + hashval_md5 = keras.utils.data_utils._hash_file(filepath, algorithm='md5') + path = keras.utils.data_utils.get_file( + 'test.txt', origin, md5_hash=hashval_md5, + untar=True, cache_subdir=dest_dir) + path = keras.utils.data_utils.get_file( + filepath, origin, file_hash=hashval_sha256, + extract=True, cache_subdir=dest_dir) + self.assertTrue(os.path.exists(filepath)) + self.assertTrue(keras.utils.data_utils.validate_file(filepath, + hashval_sha256)) + self.assertTrue(keras.utils.data_utils.validate_file(filepath, hashval_md5)) + os.remove(filepath) + + origin = urljoin('file://', pathname2url(os.path.abspath(zip_file_path))) + + hashval_sha256 = keras.utils.data_utils._hash_file(zip_file_path) + hashval_md5 = keras.utils.data_utils._hash_file(zip_file_path, + algorithm='md5') + path = keras.utils.data_utils.get_file( + 'test', origin, md5_hash=hashval_md5, + extract=True, cache_subdir=dest_dir) + path = keras.utils.data_utils.get_file( + 'test', origin, file_hash=hashval_sha256, + extract=True, cache_subdir=dest_dir) + self.assertTrue(os.path.exists(path)) + self.assertTrue(keras.utils.data_utils.validate_file(path, hashval_sha256)) + self.assertTrue(keras.utils.data_utils.validate_file(path, hashval_md5)) + + +class ThreadsafeIter(object): + + def __init__(self, it): + self.it = it + self.lock = threading.Lock() + + def __iter__(self): + return self + + def __next__(self): + return self.next() + + def next(self): + with self.lock: + return next(self.it) + + 
+def threadsafe_generator(f): + + def g(*a, **kw): + return ThreadsafeIter(f(*a, **kw)) + + return g + + +class TestSequence(keras.utils.data_utils.Sequence): + + def __init__(self, shape, value=1.): + self.shape = shape + self.inner = value + + def __getitem__(self, item): + return np.ones(self.shape, dtype=np.uint32) * item * self.inner + + def __len__(self): + return 100 + + def on_epoch_end(self): + self.inner *= 5.0 + + +class FaultSequence(keras.utils.data_utils.Sequence): + + def __getitem__(self, item): + raise IndexError(item, 'item is not present') + + def __len__(self): + return 100 + + +@threadsafe_generator +def create_generator_from_sequence_threads(ds): + for i in cycle(range(len(ds))): + yield ds[i] + + +def create_generator_from_sequence_pcs(ds): + for i in cycle(range(len(ds))): + yield ds[i] + + +class TestEnqueuers(test.TestCase): + + def test_generator_enqueuer_threads(self): + enqueuer = keras.utils.data_utils.GeneratorEnqueuer( + create_generator_from_sequence_threads(TestSequence([3, 200, 200, 3])), + use_multiprocessing=False) + enqueuer.start(3, 10) + gen_output = enqueuer.get() + acc = [] + for _ in range(100): + acc.append(int(next(gen_output)[0, 0, 0, 0])) + + self.assertEqual(len(set(acc) - set(range(100))), 0) + enqueuer.stop() + + @unittest.skipIf( + os.name == 'nt', + 'use_multiprocessing=True does not work on windows properly.') + def test_generator_enqueuer_processes(self): + enqueuer = keras.utils.data_utils.GeneratorEnqueuer( + create_generator_from_sequence_pcs(TestSequence([3, 200, 200, 3])), + use_multiprocessing=True) + enqueuer.start(3, 10) + gen_output = enqueuer.get() + acc = [] + for _ in range(100): + acc.append(int(next(gen_output)[0, 0, 0, 0])) + self.assertNotEqual(acc, list(range(100))) + enqueuer.stop() + + def test_generator_enqueuer_fail_threads(self): + enqueuer = keras.utils.data_utils.GeneratorEnqueuer( + create_generator_from_sequence_threads(FaultSequence()), + use_multiprocessing=False) + enqueuer.start(3, 10) + gen_output = enqueuer.get() + with self.assertRaises(IndexError): + next(gen_output) + + @unittest.skipIf( + os.name == 'nt', + 'use_multiprocessing=True does not work on windows properly.') + def test_generator_enqueuer_fail_processes(self): + enqueuer = keras.utils.data_utils.GeneratorEnqueuer( + create_generator_from_sequence_pcs(FaultSequence()), + use_multiprocessing=True) + enqueuer.start(3, 10) + gen_output = enqueuer.get() + with self.assertRaises(IndexError): + next(gen_output) + + def test_ordered_enqueuer_threads(self): + enqueuer = keras.utils.data_utils.OrderedEnqueuer( + TestSequence([3, 200, 200, 3]), use_multiprocessing=False) + enqueuer.start(3, 10) + gen_output = enqueuer.get() + acc = [] + for _ in range(100): + acc.append(next(gen_output)[0, 0, 0, 0]) + self.assertEqual(acc, list(range(100))) + enqueuer.stop() + + def test_ordered_enqueuer_processes(self): + enqueuer = keras.utils.data_utils.OrderedEnqueuer( + TestSequence([3, 200, 200, 3]), use_multiprocessing=True) + enqueuer.start(3, 10) + gen_output = enqueuer.get() + acc = [] + for _ in range(100): + acc.append(next(gen_output)[0, 0, 0, 0]) + self.assertEqual(acc, list(range(100))) + enqueuer.stop() + + def test_ordered_enqueuer_fail_threads(self): + enqueuer = keras.utils.data_utils.OrderedEnqueuer( + FaultSequence(), use_multiprocessing=False) + enqueuer.start(3, 10) + gen_output = enqueuer.get() + with self.assertRaises(StopIteration): + next(gen_output) + + def test_ordered_enqueuer_fail_processes(self): + enqueuer = 
keras.utils.data_utils.OrderedEnqueuer( + FaultSequence(), use_multiprocessing=True) + enqueuer.start(3, 10) + gen_output = enqueuer.get() + with self.assertRaises(StopIteration): + next(gen_output) + + def test_on_epoch_end_processes(self): + enqueuer = keras.utils.data_utils.OrderedEnqueuer( + TestSequence([3, 200, 200, 3]), use_multiprocessing=True) + enqueuer.start(3, 10) + gen_output = enqueuer.get() + acc = [] + for _ in range(200): + acc.append(next(gen_output)[0, 0, 0, 0]) + # Check that order was keep in GeneratorEnqueuer with processes + self.assertEqual(acc[100:], list([k * 5 for k in range(100)])) + enqueuer.stop() + + def test_context_switch(self): + enqueuer = keras.utils.data_utils.OrderedEnqueuer( + TestSequence([3, 200, 200, 3]), use_multiprocessing=True) + enqueuer2 = keras.utils.data_utils.OrderedEnqueuer( + TestSequence([3, 200, 200, 3], value=15), use_multiprocessing=True) + enqueuer.start(3, 10) + enqueuer2.start(3, 10) + gen_output = enqueuer.get() + gen_output2 = enqueuer2.get() + acc = [] + for _ in range(100): + acc.append(next(gen_output)[0, 0, 0, 0]) + self.assertEqual(acc[-1], 99) + # One epoch is completed so enqueuer will switch the Sequence + + acc = [] + for _ in range(100): + acc.append(next(gen_output2)[0, 0, 0, 0]) + self.assertEqual(acc[-1], 99 * 15) + # One epoch has been completed so enqueuer2 will switch + + # Be sure that both Sequence were updated + self.assertEqual(next(gen_output)[0, 0, 0, 0], 0) + self.assertEqual(next(gen_output)[0, 0, 0, 0], 5) + self.assertEqual(next(gen_output2)[0, 0, 0, 0], 0) + self.assertEqual(next(gen_output2)[0, 0, 0, 0], 15 * 5) + + # Tear down everything + enqueuer.stop() + enqueuer2.stop() + + def test_on_epoch_end_threads(self): + enqueuer = keras.utils.data_utils.OrderedEnqueuer( + TestSequence([3, 200, 200, 3]), use_multiprocessing=False) + enqueuer.start(3, 10) + gen_output = enqueuer.get() + acc = [] + for _ in range(100): + acc.append(next(gen_output)[0, 0, 0, 0]) + acc = [] + for _ in range(100): + acc.append(next(gen_output)[0, 0, 0, 0]) + # Check that order was keep in GeneratorEnqueuer with processes + self.assertEqual(acc, list([k * 5 for k in range(100)])) + enqueuer.stop() + + +if __name__ == '__main__': + # Bazel sets these environment variables to very long paths. + # Tempfile uses them to create long paths, and in turn multiprocessing + # library tries to create sockets named after paths. Delete whatever bazel + # writes to these to avoid tests failing due to socket addresses being too + # long. + for var in ('TMPDIR', 'TMP', 'TEMP'): + if var in os.environ: + del os.environ[var] + + test.main() diff --git a/tensorflow/python/keras/utils/generic_utils.py b/tensorflow/python/keras/utils/generic_utils.py new file mode 100644 index 0000000000..a69893955f --- /dev/null +++ b/tensorflow/python/keras/utils/generic_utils.py @@ -0,0 +1,561 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Python utilities required by Keras.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import binascii +import codecs +import marshal +import os +import re +import sys +import time +import types as python_types + +import numpy as np +import six + +from tensorflow.python.util import nest +from tensorflow.python.util import tf_decorator +from tensorflow.python.util import tf_inspect +from tensorflow.python.util.tf_export import tf_export + +_GLOBAL_CUSTOM_OBJECTS = {} + + +@tf_export('keras.utils.CustomObjectScope') +class CustomObjectScope(object): + """Provides a scope that changes to `_GLOBAL_CUSTOM_OBJECTS` cannot escape. + + Code within a `with` statement will be able to access custom objects + by name. Changes to global custom objects persist + within the enclosing `with` statement. At end of the `with` statement, + global custom objects are reverted to state + at beginning of the `with` statement. + + Example: + + Consider a custom object `MyObject` (e.g. a class): + + ```python + with CustomObjectScope({'MyObject':MyObject}): + layer = Dense(..., kernel_regularizer='MyObject') + # save, load, etc. will recognize custom object by name + ``` + """ + + def __init__(self, *args): + self.custom_objects = args + self.backup = None + + def __enter__(self): + self.backup = _GLOBAL_CUSTOM_OBJECTS.copy() + for objects in self.custom_objects: + _GLOBAL_CUSTOM_OBJECTS.update(objects) + return self + + def __exit__(self, *args, **kwargs): + _GLOBAL_CUSTOM_OBJECTS.clear() + _GLOBAL_CUSTOM_OBJECTS.update(self.backup) + + +@tf_export('keras.utils.custom_object_scope') +def custom_object_scope(*args): + """Provides a scope that changes to `_GLOBAL_CUSTOM_OBJECTS` cannot escape. + + Convenience wrapper for `CustomObjectScope`. + Code within a `with` statement will be able to access custom objects + by name. Changes to global custom objects persist + within the enclosing `with` statement. At end of the `with` statement, + global custom objects are reverted to state + at beginning of the `with` statement. + + Example: + + Consider a custom object `MyObject` + + ```python + with custom_object_scope({'MyObject':MyObject}): + layer = Dense(..., kernel_regularizer='MyObject') + # save, load, etc. will recognize custom object by name + ``` + + Arguments: + *args: Variable length list of dictionaries of name, + class pairs to add to custom objects. + + Returns: + Object of type `CustomObjectScope`. + """ + return CustomObjectScope(*args) + + +@tf_export('keras.utils.get_custom_objects') +def get_custom_objects(): + """Retrieves a live reference to the global dictionary of custom objects. + + Updating and clearing custom objects using `custom_object_scope` + is preferred, but `get_custom_objects` can + be used to directly access `_GLOBAL_CUSTOM_OBJECTS`. + + Example: + + ```python + get_custom_objects().clear() + get_custom_objects()['MyObject'] = MyObject + ``` + + Returns: + Global dictionary of names to classes (`_GLOBAL_CUSTOM_OBJECTS`). 
+ """ + return _GLOBAL_CUSTOM_OBJECTS + + +@tf_export('keras.utils.serialize_keras_object') +def serialize_keras_object(instance): + _, instance = tf_decorator.unwrap(instance) + if instance is None: + return None + if hasattr(instance, 'get_config'): + return { + 'class_name': instance.__class__.__name__, + 'config': instance.get_config() + } + if hasattr(instance, '__name__'): + return instance.__name__ + else: + raise ValueError('Cannot serialize', instance) + + +@tf_export('keras.utils.deserialize_keras_object') +def deserialize_keras_object(identifier, + module_objects=None, + custom_objects=None, + printable_module_name='object'): + if isinstance(identifier, dict): + # In this case we are dealing with a Keras config dictionary. + config = identifier + if 'class_name' not in config or 'config' not in config: + raise ValueError('Improper config format: ' + str(config)) + class_name = config['class_name'] + if custom_objects and class_name in custom_objects: + cls = custom_objects[class_name] + elif class_name in _GLOBAL_CUSTOM_OBJECTS: + cls = _GLOBAL_CUSTOM_OBJECTS[class_name] + else: + module_objects = module_objects or {} + cls = module_objects.get(class_name) + if cls is None: + raise ValueError('Unknown ' + printable_module_name + ': ' + class_name) + if hasattr(cls, 'from_config'): + arg_spec = tf_inspect.getargspec(cls.from_config) + custom_objects = custom_objects or {} + + if 'custom_objects' in arg_spec.args: + return cls.from_config( + config['config'], + custom_objects=dict( + list(_GLOBAL_CUSTOM_OBJECTS.items()) + + list(custom_objects.items()))) + with CustomObjectScope(custom_objects): + return cls.from_config(config['config']) + else: + # Then `cls` may be a function returning a class. + # in this case by convention `config` holds + # the kwargs of the function. + custom_objects = custom_objects or {} + with CustomObjectScope(custom_objects): + return cls(**config['config']) + elif isinstance(identifier, six.string_types): + function_name = identifier + if custom_objects and function_name in custom_objects: + fn = custom_objects.get(function_name) + elif function_name in _GLOBAL_CUSTOM_OBJECTS: + fn = _GLOBAL_CUSTOM_OBJECTS[function_name] + else: + fn = module_objects.get(function_name) + if fn is None: + raise ValueError('Unknown ' + printable_module_name + ':' + + function_name) + return fn + else: + raise ValueError('Could not interpret serialized ' + printable_module_name + + ': ' + identifier) + + +def func_dump(func): + """Serializes a user defined function. + + Arguments: + func: the function to serialize. + + Returns: + A tuple `(code, defaults, closure)`. + """ + if os.name == 'nt': + raw_code = marshal.dumps(func.__code__).replace(b'\\', b'/') + code = codecs.encode(raw_code, 'base64').decode('ascii') + else: + raw_code = marshal.dumps(func.__code__) + code = codecs.encode(raw_code, 'base64').decode('ascii') + defaults = func.__defaults__ + if func.__closure__: + closure = tuple(c.cell_contents for c in func.__closure__) + else: + closure = None + return code, defaults, closure + + +def func_load(code, defaults=None, closure=None, globs=None): + """Deserializes a user defined function. + + Arguments: + code: bytecode of the function. + defaults: defaults of the function. + closure: closure of the function. + globs: dictionary of global objects. + + Returns: + A function object. 
+ """ + if isinstance(code, (tuple, list)): # unpack previous dump + code, defaults, closure = code + if isinstance(defaults, list): + defaults = tuple(defaults) + + def ensure_value_to_cell(value): + """Ensures that a value is converted to a python cell object. + + Arguments: + value: Any value that needs to be casted to the cell type + + Returns: + A value wrapped as a cell object (see function "func_load") + """ + def dummy_fn(): + # pylint: disable=pointless-statement + value # just access it so it gets captured in .__closure__ + + cell_value = dummy_fn.__closure__[0] + if not isinstance(value, type(cell_value)): + return cell_value + else: + return value + + if closure is not None: + closure = tuple(ensure_value_to_cell(_) for _ in closure) + try: + raw_code = codecs.decode(code.encode('ascii'), 'base64') + except (UnicodeEncodeError, binascii.Error): + raw_code = code.encode('raw_unicode_escape') + code = marshal.loads(raw_code) + if globs is None: + globs = globals() + return python_types.FunctionType( + code, globs, name=code.co_name, argdefs=defaults, closure=closure) + + +def has_arg(fn, name, accept_all=False): + """Checks if a callable accepts a given keyword argument. + + Arguments: + fn: Callable to inspect. + name: Check if `fn` can be called with `name` as a keyword argument. + accept_all: What to return if there is no parameter called `name` + but the function accepts a `**kwargs` argument. + + Returns: + bool, whether `fn` accepts a `name` keyword argument. + """ + arg_spec = tf_inspect.getargspec(fn) + if accept_all and arg_spec.keywords is not None: + return True + return name in arg_spec.args + + +@tf_export('keras.utils.Progbar') +class Progbar(object): + """Displays a progress bar. + + Arguments: + target: Total number of steps expected, None if unknown. + width: Progress bar width on screen. + verbose: Verbosity mode, 0 (silent), 1 (verbose), 2 (semi-verbose) + stateful_metrics: Iterable of string names of metrics that + should *not* be averaged over time. Metrics in this list + will be displayed as-is. All others will be averaged + by the progbar before display. + interval: Minimum visual progress update interval (in seconds). + """ + + def __init__(self, target, width=30, verbose=1, interval=0.05, + stateful_metrics=None): + self.target = target + self.width = width + self.verbose = verbose + self.interval = interval + if stateful_metrics: + self.stateful_metrics = set(stateful_metrics) + else: + self.stateful_metrics = set() + + self._dynamic_display = ((hasattr(sys.stdout, 'isatty') and + sys.stdout.isatty()) or + 'ipykernel' in sys.modules or + 'posix' in sys.modules) + self._total_width = 0 + self._seen_so_far = 0 + # We use a dict + list to avoid garbage collection + # issues found in OrderedDict + self._values = {} + self._values_order = [] + self._start = time.time() + self._last_update = 0 + + def update(self, current, values=None): + """Updates the progress bar. + + Arguments: + current: Index of current step. + values: List of tuples: + `(name, value_for_last_step)`. + If `name` is in `stateful_metrics`, + `value_for_last_step` will be displayed as-is. + Else, an average of the metric over time will be displayed. 
+ """ + values = values or [] + for k, v in values: + if k not in self._values_order: + self._values_order.append(k) + if k not in self.stateful_metrics: + if k not in self._values: + self._values[k] = [v * (current - self._seen_so_far), + current - self._seen_so_far] + else: + self._values[k][0] += v * (current - self._seen_so_far) + self._values[k][1] += (current - self._seen_so_far) + else: + # Stateful metrics output a numeric value. This representation + # means "take an average from a single value" but keeps the + # numeric formatting. + self._values[k] = [v, 1] + self._seen_so_far = current + + now = time.time() + info = ' - %.0fs' % (now - self._start) + if self.verbose == 1: + if (now - self._last_update < self.interval and + self.target is not None and current < self.target): + return + + prev_total_width = self._total_width + if self._dynamic_display: + sys.stdout.write('\b' * prev_total_width) + sys.stdout.write('\r') + else: + sys.stdout.write('\n') + + if self.target is not None: + numdigits = int(np.floor(np.log10(self.target))) + 1 + barstr = '%%%dd/%d [' % (numdigits, self.target) + bar = barstr % current + prog = float(current) / self.target + prog_width = int(self.width * prog) + if prog_width > 0: + bar += ('=' * (prog_width - 1)) + if current < self.target: + bar += '>' + else: + bar += '=' + bar += ('.' * (self.width - prog_width)) + bar += ']' + else: + bar = '%7d/Unknown' % current + + self._total_width = len(bar) + sys.stdout.write(bar) + + if current: + time_per_unit = (now - self._start) / current + else: + time_per_unit = 0 + if self.target is not None and current < self.target: + eta = time_per_unit * (self.target - current) + if eta > 3600: + eta_format = '%d:%02d:%02d' % (eta // 3600, + (eta % 3600) // 60, + eta % 60) + elif eta > 60: + eta_format = '%d:%02d' % (eta // 60, eta % 60) + else: + eta_format = '%ds' % eta + + info = ' - ETA: %s' % eta_format + else: + if time_per_unit >= 1: + info += ' %.0fs/step' % time_per_unit + elif time_per_unit >= 1e-3: + info += ' %.0fms/step' % (time_per_unit * 1e3) + else: + info += ' %.0fus/step' % (time_per_unit * 1e6) + + for k in self._values_order: + info += ' - %s:' % k + if isinstance(self._values[k], list): + avg = np.mean(self._values[k][0] / max(1, self._values[k][1])) + if abs(avg) > 1e-3: + info += ' %.4f' % avg + else: + info += ' %.4e' % avg + else: + info += ' %s' % self._values[k] + + self._total_width += len(info) + if prev_total_width > self._total_width: + info += (' ' * (prev_total_width - self._total_width)) + + if self.target is not None and current >= self.target: + info += '\n' + + sys.stdout.write(info) + sys.stdout.flush() + + elif self.verbose == 2: + if self.target is None or current >= self.target: + for k in self._values_order: + info += ' - %s:' % k + avg = np.mean(self._values[k][0] / max(1, self._values[k][1])) + if avg > 1e-3: + info += ' %.4f' % avg + else: + info += ' %.4e' % avg + info += '\n' + + sys.stdout.write(info) + sys.stdout.flush() + + self._last_update = now + + def add(self, n, values=None): + self.update(self._seen_so_far + n, values) + + +def make_batches(size, batch_size): + """Returns a list of batch indices (tuples of indices). + + Arguments: + size: Integer, total size of the data to slice into batches. + batch_size: Integer, batch size. + + Returns: + A list of tuples of array indices. 
+ """ + num_batches = int(np.ceil(size / float(batch_size))) + return [(i * batch_size, min(size, (i + 1) * batch_size)) + for i in range(0, num_batches)] + + +def slice_arrays(arrays, start=None, stop=None): + """Slice an array or list of arrays. + + This takes an array-like, or a list of + array-likes, and outputs: + - arrays[start:stop] if `arrays` is an array-like + - [x[start:stop] for x in arrays] if `arrays` is a list + + Can also work on list/array of indices: `slice_arrays(x, indices)` + + Arguments: + arrays: Single array or list of arrays. + start: can be an integer index (start index) + or a list/array of indices + stop: integer (stop index); should be None if + `start` was a list. + + Returns: + A slice of the array(s). + + Raises: + ValueError: If the value of start is a list and stop is not None. + """ + if arrays is None: + return [None] + if isinstance(start, list) and stop is not None: + raise ValueError('The stop argument has to be None if the value of start ' + 'is a list.') + elif isinstance(arrays, list): + if hasattr(start, '__len__'): + # hdf5 datasets only support list objects as indices + if hasattr(start, 'shape'): + start = start.tolist() + return [None if x is None else x[start] for x in arrays] + else: + return [None if x is None else x[start:stop] for x in arrays] + else: + if hasattr(start, '__len__'): + if hasattr(start, 'shape'): + start = start.tolist() + return arrays[start] + elif hasattr(start, '__getitem__'): + return arrays[start:stop] + else: + return [None] + + +def to_list(x): + """Normalizes a list/tensor into a list. + + If a tensor is passed, we return + a list of size 1 containing the tensor. + + Arguments: + x: target object to be normalized. + + Returns: + A list. + """ + if isinstance(x, list): + return x + return [x] + + +def object_list_uid(object_list): + """Creates a single string from object ids.""" + object_list = nest.flatten(object_list) + return ', '.join([str(abs(id(x))) for x in object_list]) + + +def to_snake_case(name): + intermediate = re.sub('(.)([A-Z][a-z0-9]+)', r'\1_\2', name) + insecure = re.sub('([a-z])([A-Z])', r'\1_\2', intermediate).lower() + # If the class is private the name starts with "_" which is not secure + # for creating scopes. We prefix the name with "private" in this case. + if insecure[0] != '_': + return insecure + return 'private' + insecure + + +def is_all_none(iterable_or_element): + if not isinstance(iterable_or_element, (list, tuple)): + iterable = [iterable_or_element] + else: + iterable = iterable_or_element + # We cannot use Python's `any` because the iterable may return Tensors. + for element in iterable: + if element is not None: + return False + return True diff --git a/tensorflow/python/keras/utils/generic_utils_test.py b/tensorflow/python/keras/utils/generic_utils_test.py new file mode 100644 index 0000000000..87bc19eb37 --- /dev/null +++ b/tensorflow/python/keras/utils/generic_utils_test.py @@ -0,0 +1,75 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
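The generic utilities above (`Progbar`, `make_batches`, `slice_arrays`, `to_snake_case`, `has_arg`) are self-contained, so a short smoke test shows how they behave. This is only an illustrative sketch against this revision's module path; the array contents are arbitrary.

```python
import numpy as np
from tensorflow.python.keras.utils import generic_utils

# make_batches: split 10 samples into windows of 3 -> [(0, 3), (3, 6), (6, 9), (9, 10)].
batches = generic_utils.make_batches(10, 3)

# slice_arrays: apply the same (start, stop) window to a list of arrays.
data = [np.arange(10), np.arange(10) * 2]
first = generic_utils.slice_arrays(data, *batches[0])  # two arrays of length 3

# has_arg / to_snake_case behave as documented above.
assert generic_utils.has_arg(generic_utils.slice_arrays, 'stop')
assert generic_utils.to_snake_case('BatchNormalization') == 'batch_normalization'

# Progbar: 100 steps with a running average of a 'loss' metric.
progbar = generic_utils.Progbar(target=100)
for step in range(100):
    progbar.update(step + 1, values=[('loss', 1.0 / (step + 1))])
```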
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Keras generic Python utils.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python import keras +from tensorflow.python.platform import test + + +class HasArgTest(test.TestCase): + + def test_has_arg(self): + + def f_x(x): + return x + + def f_x_args(x, *args): + _ = args + return x + + def f_x_kwargs(x, **kwargs): + _ = kwargs + return x + + self.assertTrue(keras.utils.generic_utils.has_arg( + f_x, 'x', accept_all=False)) + self.assertFalse(keras.utils.generic_utils.has_arg( + f_x, 'y', accept_all=False)) + self.assertTrue(keras.utils.generic_utils.has_arg( + f_x_args, 'x', accept_all=False)) + self.assertFalse(keras.utils.generic_utils.has_arg( + f_x_args, 'y', accept_all=False)) + self.assertTrue(keras.utils.generic_utils.has_arg( + f_x_kwargs, 'x', accept_all=False)) + self.assertFalse(keras.utils.generic_utils.has_arg( + f_x_kwargs, 'y', accept_all=False)) + self.assertTrue(keras.utils.generic_utils.has_arg( + f_x_kwargs, 'y', accept_all=True)) + + +class TestCustomObjectScope(test.TestCase): + + def test_custom_object_scope(self): + + def custom_fn(): + pass + + class CustomClass(object): + pass + + with keras.utils.generic_utils.custom_object_scope( + {'CustomClass': CustomClass, 'custom_fn': custom_fn}): + act = keras.activations.get('custom_fn') + self.assertEqual(act, custom_fn) + cl = keras.regularizers.get('CustomClass') + self.assertEqual(cl.__class__, CustomClass) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/python/keras/utils/io_utils.py b/tensorflow/python/keras/utils/io_utils.py new file mode 100644 index 0000000000..f82e3277de --- /dev/null +++ b/tensorflow/python/keras/utils/io_utils.py @@ -0,0 +1,171 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# pylint: disable=g-import-not-at-top +"""Utilities related to disk I/O.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from collections import defaultdict + +import numpy as np +import six +from tensorflow.python.util.tf_export import tf_export + + +try: + import h5py +except ImportError: + h5py = None + + +@tf_export('keras.utils.HDF5Matrix') +class HDF5Matrix(object): + """Representation of HDF5 dataset to be used instead of a Numpy array. + + Example: + + ```python + x_data = HDF5Matrix('input/file.hdf5', 'data') + model.predict(x_data) + ``` + + Providing `start` and `end` allows use of a slice of the dataset. + + Optionally, a normalizer function (or lambda) can be given. This will + be called on every slice of data retrieved. 
+ + Arguments: + datapath: string, path to a HDF5 file + dataset: string, name of the HDF5 dataset in the file specified + in datapath + start: int, start of desired slice of the specified dataset + end: int, end of desired slice of the specified dataset + normalizer: function to be called on data when retrieved + + Returns: + An array-like HDF5 dataset. + """ + refs = defaultdict(int) + + def __init__(self, datapath, dataset, start=0, end=None, normalizer=None): + if h5py is None: + raise ImportError('The use of HDF5Matrix requires ' + 'HDF5 and h5py installed.') + + if datapath not in list(self.refs.keys()): + f = h5py.File(datapath) + self.refs[datapath] = f + else: + f = self.refs[datapath] + self.data = f[dataset] + self.start = start + if end is None: + self.end = self.data.shape[0] + else: + self.end = end + self.normalizer = normalizer + + def __len__(self): + return self.end - self.start + + def __getitem__(self, key): + if isinstance(key, slice): + start, stop = key.start, key.stop + if start is None: + start = 0 + if stop is None: + stop = self.shape[0] + if stop + self.start <= self.end: + idx = slice(start + self.start, stop + self.start) + else: + raise IndexError + elif isinstance(key, (int, np.integer)): + if key + self.start < self.end: + idx = key + self.start + else: + raise IndexError + elif isinstance(key, np.ndarray): + if np.max(key) + self.start < self.end: + idx = (self.start + key).tolist() + else: + raise IndexError + elif isinstance(key, list): + if max(key) + self.start < self.end: + idx = [x + self.start for x in key] + else: + raise IndexError + else: + raise IndexError + if self.normalizer is not None: + return self.normalizer(self.data[idx]) + else: + return self.data[idx] + + @property + def shape(self): + """Gets a numpy-style shape tuple giving the dataset dimensions. + + Returns: + A numpy-style shape tuple. + """ + return (self.end - self.start,) + self.data.shape[1:] + + @property + def dtype(self): + """Gets the datatype of the dataset. + + Returns: + A numpy dtype string. + """ + return self.data.dtype + + @property + def ndim(self): + """Gets the number of dimensions (rank) of the dataset. + + Returns: + An integer denoting the number of dimensions (rank) of the dataset. + """ + return self.data.ndim + + @property + def size(self): + """Gets the total dataset size (number of elements). + + Returns: + An integer denoting the number of elements in the dataset. + """ + return np.prod(self.shape) + + +def ask_to_proceed_with_overwrite(filepath): + """Produces a prompt asking about overwriting a file. + + Arguments: + filepath: the path to the file to be overwritten. + + Returns: + True if we can proceed with overwrite, False otherwise. + """ + overwrite = six.moves.input('[WARNING] %s already exists - overwrite? ' + '[y/n]' % (filepath)).strip().lower() + while overwrite not in ('y', 'n'): + overwrite = six.moves.input('Enter "y" (overwrite) or "n" ' + '(cancel).').strip().lower() + if overwrite == 'n': + return False + print('[TIP] Next time specify overwrite=True!') + return True diff --git a/tensorflow/python/keras/utils/io_utils_test.py b/tensorflow/python/keras/utils/io_utils_test.py new file mode 100644 index 0000000000..3895dca68e --- /dev/null +++ b/tensorflow/python/keras/utils/io_utils_test.py @@ -0,0 +1,100 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
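Beyond the plain slicing shown in the docstring, the `normalizer` argument is applied to every chunk as it is retrieved. A minimal sketch, assuming h5py is installed; the file and dataset names are placeholders.

```python
import numpy as np
import h5py
from tensorflow.python.keras.utils.io_utils import HDF5Matrix

# Write a small dataset to disk first (placeholder names).
with h5py.File('demo.h5', 'w') as f:
    f.create_dataset('features', data=np.random.rand(100, 8).astype('float32'))

# Wrap the first 80 rows and standardize each chunk as it is read.
x = HDF5Matrix('demo.h5', 'features', start=0, end=80,
               normalizer=lambda chunk: (chunk - chunk.mean()) / (chunk.std() + 1e-7))
print(x.shape, x.dtype, x.ndim, x.size)  # (80, 8) float32 2 640
print(x[0:4].shape)                      # (4, 8), normalizer already applied
```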
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for io_utils.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import shutil + +import numpy as np + +from tensorflow.python import keras +from tensorflow.python.platform import test + +try: + import h5py # pylint:disable=g-import-not-at-top +except ImportError: + h5py = None + + +def create_dataset(h5_path='test.h5'): + x = np.random.randn(200, 10).astype('float32') + y = np.random.randint(0, 2, size=(200, 1)) + f = h5py.File(h5_path, 'w') + # Creating dataset to store features + x_dset = f.create_dataset('my_data', (200, 10), dtype='f') + x_dset[:] = x + # Creating dataset to store labels + y_dset = f.create_dataset('my_labels', (200, 1), dtype='i') + y_dset[:] = y + f.close() + + +class TestIOUtils(test.TestCase): + + def test_HDF5Matrix(self): + if h5py is None: + return + + temp_dir = self.get_temp_dir() + self.addCleanup(shutil.rmtree, temp_dir) + + h5_path = os.path.join(temp_dir, 'test.h5') + create_dataset(h5_path) + + # Instantiating HDF5Matrix for the training set, + # which is a slice of the first 150 elements + x_train = keras.utils.io_utils.HDF5Matrix( + h5_path, 'my_data', start=0, end=150) + y_train = keras.utils.io_utils.HDF5Matrix( + h5_path, 'my_labels', start=0, end=150) + + # Likewise for the test set + x_test = keras.utils.io_utils.HDF5Matrix( + h5_path, 'my_data', start=150, end=200) + y_test = keras.utils.io_utils.HDF5Matrix( + h5_path, 'my_labels', start=150, end=200) + + # HDF5Matrix behave more or less like Numpy matrices + # with regard to indexing + self.assertEqual(y_train.shape, (150, 1)) + # But they do not support negative indices, so don't try print(x_train[-1]) + + self.assertEqual(y_train.dtype, np.dtype('i')) + self.assertEqual(y_train.ndim, 2) + self.assertEqual(y_train.size, 150) + + model = keras.models.Sequential() + model.add(keras.layers.Dense(64, input_shape=(10,), activation='relu')) + model.add(keras.layers.Dense(1, activation='sigmoid')) + model.compile(loss='binary_crossentropy', optimizer='sgd') + + # Note: you have to use shuffle='batch' or False with HDF5Matrix + model.fit(x_train, y_train, batch_size=32, shuffle='batch', verbose=False) + # test that evalutation and prediction + # don't crash and return reasonable results + out_pred = model.predict(x_test, batch_size=32, verbose=False) + out_eval = model.evaluate(x_test, y_test, batch_size=32, verbose=False) + + self.assertEqual(out_pred.shape, (50, 1)) + self.assertEqual(out_eval.shape, ()) + self.assertGreater(out_eval, 0) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/python/keras/utils/layer_utils.py b/tensorflow/python/keras/utils/layer_utils.py new file mode 100644 index 0000000000..bd61f8e9cc --- /dev/null +++ b/tensorflow/python/keras/utils/layer_utils.py @@ -0,0 +1,266 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# pylint: disable=protected-access +"""Utilities related to layer/model functionality. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.keras import backend as K +from tensorflow.python.keras.utils.conv_utils import convert_kernel +from tensorflow.python.util.tf_export import tf_export + + +def count_params(weights): + """Count the total number of scalars composing the weights. + + Arguments: + weights: An iterable containing the weights on which to compute params + + Returns: + The total number of scalars composing the weights + """ + return int(np.sum([np.prod(p.get_shape().as_list()) for p in set(weights)])) + + +def print_summary(model, line_length=None, positions=None, print_fn=None): + """Prints a summary of a model. + + Arguments: + model: Keras model instance. + line_length: Total length of printed lines + (e.g. set this to adapt the display to different + terminal window sizes). + positions: Relative or absolute positions of log elements in each line. + If not provided, defaults to `[.33, .55, .67, 1.]`. + print_fn: Print function to use. + It will be called on each line of the summary. + You can set it to a custom function + in order to capture the string summary. + It defaults to `print` (prints to stdout). + """ + if print_fn is None: + print_fn = print + + if model.__class__.__name__ == 'Sequential': + sequential_like = True + elif not model._is_graph_network: + # We treat subclassed models as a simple sequence of layers, for logging + # purposes. + sequential_like = True + else: + sequential_like = True + nodes_by_depth = model._nodes_by_depth.values() + nodes = [] + for v in nodes_by_depth: + if (len(v) > 1) or (len(v) == 1 and len(v[0].inbound_layers) > 1): + # if the model has multiple nodes + # or if the nodes have multiple inbound_layers + # the model is no longer sequential + sequential_like = False + break + nodes += v + if sequential_like: + # search for shared layers + for layer in model.layers: + flag = False + for node in layer._inbound_nodes: + if node in nodes: + if flag: + sequential_like = False + break + else: + flag = True + if not sequential_like: + break + + if sequential_like: + line_length = line_length or 65 + positions = positions or [.45, .85, 1.] + if positions[-1] <= 1: + positions = [int(line_length * p) for p in positions] + # header names for the different log elements + to_display = ['Layer (type)', 'Output Shape', 'Param #'] + else: + line_length = line_length or 98 + positions = positions or [.33, .55, .67, 1.] 
+ if positions[-1] <= 1: + positions = [int(line_length * p) for p in positions] + # header names for the different log elements + to_display = ['Layer (type)', 'Output Shape', 'Param #', 'Connected to'] + relevant_nodes = [] + for v in model._nodes_by_depth.values(): + relevant_nodes += v + + def print_row(fields, positions): + line = '' + for i in range(len(fields)): + if i > 0: + line = line[:-1] + ' ' + line += str(fields[i]) + line = line[:positions[i]] + line += ' ' * (positions[i] - len(line)) + print_fn(line) + + print_fn('_' * line_length) + print_row(to_display, positions) + print_fn('=' * line_length) + + def print_layer_summary(layer): + """Prints a summary for a single layer. + + Arguments: + layer: target layer. + """ + try: + output_shape = layer.output_shape + except AttributeError: + output_shape = 'multiple' + except RuntimeError: # output_shape unknown in Eager mode. + output_shape = '?' + name = layer.name + cls_name = layer.__class__.__name__ + fields = [name + ' (' + cls_name + ')', output_shape, layer.count_params()] + print_row(fields, positions) + + def print_layer_summary_with_connections(layer): + """Prints a summary for a single layer (including topological connections). + + Arguments: + layer: target layer. + """ + try: + output_shape = layer.output_shape + except AttributeError: + output_shape = 'multiple' + connections = [] + for node in layer._inbound_nodes: + if relevant_nodes and node not in relevant_nodes: + # node is not part of the current network + continue + for i in range(len(node.inbound_layers)): + inbound_layer = node.inbound_layers[i].name + inbound_node_index = node.node_indices[i] + inbound_tensor_index = node.tensor_indices[i] + connections.append(inbound_layer + '[' + str(inbound_node_index) + + '][' + str(inbound_tensor_index) + ']') + + name = layer.name + cls_name = layer.__class__.__name__ + if not connections: + first_connection = '' + else: + first_connection = connections[0] + fields = [ + name + ' (' + cls_name + ')', output_shape, + layer.count_params(), first_connection + ] + print_row(fields, positions) + if len(connections) > 1: + for i in range(1, len(connections)): + fields = ['', '', '', connections[i]] + print_row(fields, positions) + + layers = model.layers + for i in range(len(layers)): + if sequential_like: + print_layer_summary(layers[i]) + else: + print_layer_summary_with_connections(layers[i]) + if i == len(layers) - 1: + print_fn('=' * line_length) + else: + print_fn('_' * line_length) + + model._check_trainable_weights_consistency() + if hasattr(model, '_collected_trainable_weights'): + trainable_count = count_params(model._collected_trainable_weights) + else: + trainable_count = count_params(model.trainable_weights) + + non_trainable_count = count_params(model.non_trainable_weights) + + print_fn('Total params: {:,}'.format(trainable_count + non_trainable_count)) + print_fn('Trainable params: {:,}'.format(trainable_count)) + print_fn('Non-trainable params: {:,}'.format(non_trainable_count)) + print_fn('_' * line_length) + + +@tf_export('keras.utils.convert_all_kernels_in_model') +def convert_all_kernels_in_model(model): + """Converts all convolution kernels in a model from Theano to TensorFlow. + + Also works from TensorFlow to Theano. + + Arguments: + model: target model for the conversion. + """ + # Note: SeparableConvolution not included + # since only supported by TF. 
+ conv_classes = { + 'Conv1D', + 'Conv2D', + 'Conv3D', + 'Conv2DTranspose', + } + to_assign = [] + for layer in model.layers: + if layer.__class__.__name__ in conv_classes: + original_kernel = K.get_value(layer.kernel) + converted_kernel = convert_kernel(original_kernel) + to_assign.append((layer.kernel, converted_kernel)) + K.batch_set_value(to_assign) + + +def convert_dense_weights_data_format(dense, + previous_feature_map_shape, + target_data_format='channels_first'): + """Utility useful when changing a convnet's `data_format`. + + When porting the weights of a convnet from one data format to the other, + if the convnet includes a `Flatten` layer + (applied to the last convolutional feature map) + followed by a `Dense` layer, the weights of that `Dense` layer + should be updated to reflect the new dimension ordering. + + Arguments: + dense: The target `Dense` layer. + previous_feature_map_shape: A shape tuple of 3 integers, + e.g. `(512, 7, 7)`. The shape of the convolutional + feature map right before the `Flatten` layer that + came before the target `Dense` layer. + target_data_format: One of "channels_last", "channels_first". + Set it "channels_last" + if converting a "channels_first" model to "channels_last", + or reciprocally. + """ + assert target_data_format in {'channels_last', 'channels_first'} + kernel, bias = dense.get_weights() + for i in range(kernel.shape[1]): + if target_data_format == 'channels_first': + c, h, w = previous_feature_map_shape + original_fm_shape = (h, w, c) + ki = kernel[:, i].reshape(original_fm_shape) + ki = np.transpose(ki, (2, 0, 1)) # last -> first + else: + h, w, c = previous_feature_map_shape + original_fm_shape = (c, h, w) + ki = kernel[:, i].reshape(original_fm_shape) + ki = np.transpose(ki, (1, 2, 0)) # first -> last + kernel[:, i] = np.reshape(ki, (np.prod(previous_feature_map_shape),)) + dense.set_weights([kernel, bias]) diff --git a/tensorflow/python/keras/utils/multi_gpu_utils.py b/tensorflow/python/keras/utils/multi_gpu_utils.py new file mode 100644 index 0000000000..e5442f04e3 --- /dev/null +++ b/tensorflow/python/keras/utils/multi_gpu_utils.py @@ -0,0 +1,252 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
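Because `print_summary` accepts an arbitrary `print_fn`, the summary can be captured as text rather than written to stdout. A small sketch (the layer sizes are arbitrary):

```python
from tensorflow.python import keras
from tensorflow.python.keras.utils import layer_utils

model = keras.models.Sequential()
model.add(keras.layers.Dense(8, input_shape=(4,), activation='relu'))
model.add(keras.layers.Dense(1))

# Collect every summary line instead of printing it.
lines = []
layer_utils.print_summary(model, print_fn=lines.append)
summary_text = '\n'.join(lines)

# count_params sums the scalar elements across a collection of weights:
# (4*8 + 8) + (8*1 + 1) = 49 trainable parameters here.
trainable = layer_utils.count_params(model.trainable_weights)
```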
+# ============================================================================== +"""Utilities for multi-gpu training.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import ops +from tensorflow.python.keras import backend as K +from tensorflow.python.keras.engine.training import Model +from tensorflow.python.ops import array_ops +from tensorflow.python.util.tf_export import tf_export + + +def _get_available_devices(): + return [x.name for x in K.get_session().list_devices()] + + +def _normalize_device_name(name): + name = '/' + name.lower().split('device:')[1] + return name + + +@tf_export('keras.utils.multi_gpu_model') +def multi_gpu_model(model, gpus, cpu_merge=True, cpu_relocation=False): + """Replicates a model on different GPUs. + + Specifically, this function implements single-machine + multi-GPU data parallelism. It works in the following way: + + - Divide the model's input(s) into multiple sub-batches. + - Apply a model copy on each sub-batch. Every model copy + is executed on a dedicated GPU. + - Concatenate the results (on CPU) into one big batch. + + E.g. if your `batch_size` is 64 and you use `gpus=2`, + then we will divide the input into 2 sub-batches of 32 samples, + process each sub-batch on one GPU, then return the full + batch of 64 processed samples. + + This induces quasi-linear speedup on up to 8 GPUs. + + This function is only available with the TensorFlow backend + for the time being. + + Arguments: + model: A Keras model instance. To avoid OOM errors, + this model could have been built on CPU, for instance + (see usage example below). + gpus: Integer >= 2, number of on GPUs on which to create + model replicas. + cpu_merge: A boolean value to identify whether to force + merging model weights under the scope of the CPU or not. + cpu_relocation: A boolean value to identify whether to + create the model's weights under the scope of the CPU. + If the model is not defined under any preceding device + scope, you can still rescue it by activating this option. + + Returns: + A Keras `Model` instance which can be used just like the initial + `model` argument, but which distributes its workload on multiple GPUs. + + Example 1: Training models with weights merge on CPU + + ```python + import tensorflow as tf + from keras.applications import Xception + from keras.utils import multi_gpu_model + import numpy as np + + num_samples = 1000 + height = 224 + width = 224 + num_classes = 1000 + + # Instantiate the base model (or "template" model). + # We recommend doing this with under a CPU device scope, + # so that the model's weights are hosted on CPU memory. + # Otherwise they may end up hosted on a GPU, which would + # complicate weight sharing. + with tf.device('/cpu:0'): + model = Xception(weights=None, + input_shape=(height, width, 3), + classes=num_classes) + + # Replicates the model on 8 GPUs. + # This assumes that your machine has 8 available GPUs. + parallel_model = multi_gpu_model(model, gpus=8) + parallel_model.compile(loss='categorical_crossentropy', + optimizer='rmsprop') + + # Generate dummy data. + x = np.random.random((num_samples, height, width, 3)) + y = np.random.random((num_samples, num_classes)) + + # This `fit` call will be distributed on 8 GPUs. + # Since the batch size is 256, each GPU will process 32 samples. 
+ parallel_model.fit(x, y, epochs=20, batch_size=256) + + # Save model via the template model (which shares the same weights): + model.save('my_model.h5') + ``` + + Example 2: Training models with weights merge on CPU using cpu_relocation + + ```python + .. + # Not needed to change the device scope for model definition: + model = Xception(weights=None, ..) + + try: + model = multi_gpu_model(model, cpu_relocation=True) + print("Training using multiple GPUs..") + except: + print("Training using single GPU or CPU..") + + model.compile(..) + .. + ``` + + Example 3: Training models with weights merge on GPU (recommended for NV-link) + + ```python + .. + # Not needed to change the device scope for model definition: + model = Xception(weights=None, ..) + + try: + model = multi_gpu_model(model, cpu_merge=False) + print("Training using multiple GPUs..") + except: + print("Training using single GPU or CPU..") + model.compile(..) + .. + ``` + + Raises: + ValueError: if the `gpus` argument does not match available devices. + """ + # pylint: disable=g-import-not-at-top + from tensorflow.python.keras.layers.core import Lambda + from tensorflow.python.keras.layers.merge import concatenate + + if isinstance(gpus, (list, tuple)): + if len(gpus) <= 1: + raise ValueError('For multi-gpu usage to be effective, ' + 'call `multi_gpu_model` with `len(gpus) >= 2`. ' + 'Received: `gpus=%s`' % gpus) + num_gpus = len(gpus) + target_gpu_ids = gpus + else: + if gpus <= 1: + raise ValueError('For multi-gpu usage to be effective, ' + 'call `multi_gpu_model` with `gpus >= 2`. ' + 'Received: `gpus=%s`' % gpus) + num_gpus = gpus + target_gpu_ids = range(num_gpus) + + target_devices = ['/cpu:0'] + ['/gpu:%d' % i for i in target_gpu_ids] + available_devices = _get_available_devices() + available_devices = [ + _normalize_device_name(name) for name in available_devices + ] + for device in target_devices: + if device not in available_devices: + raise ValueError('To call `multi_gpu_model` with `gpus=%s`, ' + 'we expect the following devices to be available: %s. ' + 'However this machine only has: %s. ' + 'Try reducing `gpus`.' % (gpus, target_devices, + available_devices)) + + def get_slice(data, i, parts): + """Slice an array into `parts` slices and return slice `i`. + + Arguments: + data: array to slice. + i: index of slice to return. + parts: number of slices to make. + + Returns: + Slice `i` of `data`. + """ + shape = array_ops.shape(data) + batch_size = shape[:1] + input_shape = shape[1:] + step = batch_size // parts + if i == num_gpus - 1: + size = batch_size - step * i + else: + size = step + size = array_ops.concat([size, input_shape], axis=0) + stride = array_ops.concat([step, input_shape * 0], axis=0) + start = stride * i + return array_ops.slice(data, start, size) + + # Relocate the model definition under CPU device scope if needed + if cpu_relocation: + from tensorflow.python.keras.models import clone_model # pylint: disable=g-import-not-at-top + with ops.device('/cpu:0'): + model = clone_model(model) + + all_outputs = [] + for i in range(len(model.outputs)): + all_outputs.append([]) + + # Place a copy of the model on each GPU, + # each getting a slice of the inputs. + for i, gpu_id in enumerate(target_gpu_ids): + with ops.device('/gpu:%d' % gpu_id): + with ops.name_scope('replica_%d' % gpu_id): + inputs = [] + # Retrieve a slice of the input. 
+ for x in model.inputs: + input_shape = tuple(x.get_shape().as_list())[1:] + slice_i = Lambda( + get_slice, + output_shape=input_shape, + arguments={ + 'i': i, + 'parts': num_gpus + })( + x) + inputs.append(slice_i) + + # Apply model on slice + # (creating a model replica on the target device). + outputs = model(inputs) + if not isinstance(outputs, list): + outputs = [outputs] + + # Save the outputs for merging back together later. + for o in range(len(outputs)): + all_outputs[o].append(outputs[o]) + + # Merge outputs under expected scope. + with ops.device('/cpu:0' if cpu_merge else '/gpu:%d' % target_gpu_ids[0]): + merged = [] + for name, outputs in zip(model.output_names, all_outputs): + merged.append(concatenate(outputs, axis=0, name=name)) + return Model(model.inputs, merged) diff --git a/tensorflow/python/keras/utils/multi_gpu_utils_test.py b/tensorflow/python/keras/utils/multi_gpu_utils_test.py new file mode 100644 index 0000000000..77792d14f5 --- /dev/null +++ b/tensorflow/python/keras/utils/multi_gpu_utils_test.py @@ -0,0 +1,185 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for multi-gpu training utilities.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python import data +from tensorflow.python import keras +from tensorflow.python.platform import test + + +def check_if_compatible_devices(gpus=2): + available_devices = [ + keras.utils.multi_gpu_utils._normalize_device_name(name) + for name in keras.utils.multi_gpu_utils._get_available_devices() + ] + if '/gpu:%d' % (gpus - 1) not in available_devices: + return False + return True + + +class TestMultiGPUModel(test.TestCase): + + def test_multi_gpu_test_simple_model(self): + gpus = 2 + num_samples = 1000 + input_dim = 10 + output_dim = 1 + hidden_dim = 10 + epochs = 2 + target_gpu_id = [0, 1] + + if not check_if_compatible_devices(gpus=gpus): + return + + with self.test_session(): + model = keras.models.Sequential() + model.add(keras.layers.Dense(hidden_dim, + input_shape=(input_dim,))) + model.add(keras.layers.Dense(output_dim)) + + x = np.random.random((num_samples, input_dim)) + y = np.random.random((num_samples, output_dim)) + + parallel_model = keras.utils.multi_gpu_model(model, gpus=gpus) + parallel_model.compile(loss='mse', optimizer='rmsprop') + parallel_model.fit(x, y, epochs=epochs) + parallel_model = keras.utils.multi_gpu_model(model, gpus=target_gpu_id) + parallel_model.compile(loss='mse', optimizer='rmsprop') + parallel_model.fit(x, y, epochs=epochs) + + def test_multi_gpu_test_multi_io_model(self): + gpus = 2 + num_samples = 1000 + input_dim_a = 10 + input_dim_b = 5 + output_dim_a = 1 + output_dim_b = 2 + hidden_dim = 10 + epochs = 2 + target_gpu_id = [0, 1] + + if not check_if_compatible_devices(gpus=gpus): + return + + with self.test_session(): + input_a = 
keras.Input((input_dim_a,)) + input_b = keras.Input((input_dim_b,)) + a = keras.layers.Dense(hidden_dim)(input_a) + b = keras.layers.Dense(hidden_dim)(input_b) + c = keras.layers.concatenate([a, b]) + output_a = keras.layers.Dense(output_dim_a)(c) + output_b = keras.layers.Dense(output_dim_b)(c) + model = keras.models.Model([input_a, input_b], [output_a, output_b]) + + a_x = np.random.random((num_samples, input_dim_a)) + b_x = np.random.random((num_samples, input_dim_b)) + a_y = np.random.random((num_samples, output_dim_a)) + b_y = np.random.random((num_samples, output_dim_b)) + + parallel_model = keras.utils.multi_gpu_model(model, gpus=gpus) + parallel_model.compile(loss='mse', optimizer='rmsprop') + parallel_model.fit([a_x, b_x], [a_y, b_y], epochs=epochs) + + parallel_model = keras.utils.multi_gpu_model(model, gpus=target_gpu_id) + parallel_model.compile(loss='mse', optimizer='rmsprop') + parallel_model.fit([a_x, b_x], [a_y, b_y], epochs=epochs) + + def test_multi_gpu_test_invalid_devices(self): + if not check_if_compatible_devices(gpus=2): + return + + with self.test_session(): + input_shape = (1000, 10) + model = keras.models.Sequential() + model.add(keras.layers.Dense(10, + activation='relu', + input_shape=input_shape[1:])) + model.add(keras.layers.Dense(1, activation='sigmoid')) + model.compile(loss='mse', optimizer='rmsprop') + + x = np.random.random(input_shape) + y = np.random.random((input_shape[0], 1)) + with self.assertRaises(ValueError): + parallel_model = keras.utils.multi_gpu_model( + model, gpus=len(keras.backend._get_available_gpus()) + 1) + parallel_model.fit(x, y, epochs=2) + + with self.assertRaises(ValueError): + parallel_model = keras.utils.multi_gpu_model( + model, gpus=[0, 2, 4, 6, 8]) + parallel_model.fit(x, y, epochs=2) + + with self.assertRaises(ValueError): + parallel_model = keras.utils.multi_gpu_model(model, gpus=1) + parallel_model.fit(x, y, epochs=2) + + with self.assertRaises(ValueError): + parallel_model = keras.utils.multi_gpu_model(model, gpus=[0]) + parallel_model.fit(x, y, epochs=2) + + def test_nested_model_with_tensor_input(self): + gpus = 2 + input_dim = 10 + shape = (input_dim,) + num_samples = 16 + num_classes = 10 + + if not check_if_compatible_devices(gpus=gpus): + return + + with self.test_session(): + input_shape = (num_samples,) + shape + x_train = np.random.randint(0, 255, input_shape) + y_train = np.random.randint(0, num_classes, (input_shape[0],)) + keras.backend.set_learning_phase(True) + + y_train = keras.utils.to_categorical(y_train, num_classes) + + x_train = x_train.astype('float32') + y_train = y_train.astype('float32') + + dataset = data.Dataset.from_tensor_slices((x_train, y_train)) + dataset = dataset.repeat() + dataset = dataset.batch(4) + iterator = dataset.make_one_shot_iterator() + + inputs, targets = iterator.get_next() + + input_tensor = keras.layers.Input(tensor=inputs) + + model = keras.models.Sequential() + model.add(keras.layers.Dense(3, + input_shape=(input_dim,))) + model.add(keras.layers.Dense(num_classes)) + + output = model(input_tensor) + outer_model = keras.Model(input_tensor, output) + parallel_model = keras.utils.multi_gpu_model(outer_model, gpus=gpus) + + parallel_model.compile( + loss='categorical_crossentropy', + optimizer=keras.optimizers.RMSprop(lr=0.0001, decay=1e-6), + metrics=['accuracy'], + target_tensors=[targets]) + parallel_model.fit(epochs=1, steps_per_epoch=3) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/python/keras/utils/np_utils.py 
b/tensorflow/python/keras/utils/np_utils.py new file mode 100644 index 0000000000..9d9c72b162 --- /dev/null +++ b/tensorflow/python/keras/utils/np_utils.py @@ -0,0 +1,67 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Numpy-related utilities.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from tensorflow.python.util.tf_export import tf_export + + +@tf_export('keras.utils.to_categorical') +def to_categorical(y, num_classes=None): + """Converts a class vector (integers) to binary class matrix. + + E.g. for use with categorical_crossentropy. + + Arguments: + y: class vector to be converted into a matrix + (integers from 0 to num_classes). + num_classes: total number of classes. + + Returns: + A binary matrix representation of the input. + """ + y = np.array(y, dtype='int') + input_shape = y.shape + if input_shape and input_shape[-1] == 1 and len(input_shape) > 1: + input_shape = tuple(input_shape[:-1]) + y = y.ravel() + if not num_classes: + num_classes = np.max(y) + 1 + n = y.shape[0] + categorical = np.zeros((n, num_classes), dtype=np.float32) + categorical[np.arange(n), y] = 1 + output_shape = input_shape + (num_classes,) + categorical = np.reshape(categorical, output_shape) + return categorical + + +@tf_export('keras.utils.normalize') +def normalize(x, axis=-1, order=2): + """Normalizes a Numpy array. + + Arguments: + x: Numpy array to normalize. + axis: axis along which to normalize. + order: Normalization order (e.g. 2 for L2 norm). + + Returns: + A normalized copy of the array. + """ + l2 = np.atleast_1d(np.linalg.norm(x, order, axis)) + l2[l2 == 0] = 1 + return x / np.expand_dims(l2, axis) diff --git a/tensorflow/python/keras/utils/np_utils_test.py b/tensorflow/python/keras/utils/np_utils_test.py new file mode 100644 index 0000000000..d77e76ff3e --- /dev/null +++ b/tensorflow/python/keras/utils/np_utils_test.py @@ -0,0 +1,53 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
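As a quick illustration of the two numpy helpers above (the values are arbitrary):

```python
import numpy as np
from tensorflow.python import keras

# Integer class labels -> one-hot rows.
labels = np.array([0, 2, 1, 2])
one_hot = keras.utils.to_categorical(labels, num_classes=3)
# [[1. 0. 0.]
#  [0. 0. 1.]
#  [0. 1. 0.]
#  [0. 0. 1.]]

# L2-normalize along the last axis; zero-norm rows are left untouched
# because zero norms are clamped to 1 before dividing.
x = np.array([[3.0, 4.0], [0.0, 0.0]])
normalized = keras.utils.normalize(x, axis=-1, order=2)  # first row -> [0.6, 0.8]
```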
+# ============================================================================== +"""Tests for np_utils.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python import keras +from tensorflow.python.platform import test + + +class TestNPUtils(test.TestCase): + + def test_to_categorical(self): + num_classes = 5 + shapes = [(1,), (3,), (4, 3), (5, 4, 3), (3, 1), (3, 2, 1)] + expected_shapes = [(1, num_classes), + (3, num_classes), + (4, 3, num_classes), + (5, 4, 3, num_classes), + (3, num_classes)] + labels = [np.random.randint(0, num_classes, shape) for shape in shapes] + one_hots = [ + keras.utils.to_categorical(label, num_classes) for label in labels] + for label, one_hot, expected_shape in zip(labels, + one_hots, + expected_shapes): + # Check shape + self.assertEqual(one_hot.shape, expected_shape) + # Make sure there is only one 1 in a row + self.assertTrue(np.all(one_hot.sum(axis=-1) == 1)) + # Get original labels back from one hots + self.assertTrue(np.all( + np.argmax(one_hot, -1).reshape(label.shape) == label)) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/python/keras/utils/tf_utils.py b/tensorflow/python/keras/utils/tf_utils.py new file mode 100644 index 0000000000..162e5b2cd6 --- /dev/null +++ b/tensorflow/python/keras/utils/tf_utils.py @@ -0,0 +1,154 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""TensorFlow-related utilities.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import ops +from tensorflow.python.framework import smart_cond as smart_module +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import variables +from tensorflow.python.util import nest + + +def smart_cond(pred, true_fn=None, false_fn=None, name=None): + """Return either `true_fn()` if predicate `pred` is true else `false_fn()`. + + If `pred` is a bool or has a constant value, we return either `true_fn()` + or `false_fn()`, otherwise we use `tf.cond` to dynamically route to both. + + Arguments: + pred: A scalar determining whether to return the result of `true_fn` or + `false_fn`. + true_fn: The callable to be performed if pred is true. + false_fn: The callable to be performed if pred is false. + name: Optional name prefix when using `tf.cond`. + + Returns: + Tensors returned by the call to either `true_fn` or `false_fn`. + + Raises: + TypeError: If `true_fn` or `false_fn` is not callable. 
+ """ + if isinstance(pred, variables.Variable): + return control_flow_ops.cond( + pred, true_fn=true_fn, false_fn=false_fn, name=name) + return smart_module.smart_cond( + pred, true_fn=true_fn, false_fn=false_fn, name=name) + + +def constant_value(pred): + """Return the bool value for `pred`, or None if `pred` had a dynamic value. + + Arguments: + pred: A scalar, either a Python bool or a TensorFlow boolean variable + or tensor, or the Python integer 1 or 0. + + Returns: + True or False if `pred` has a constant boolean value, None otherwise. + + Raises: + TypeError: If `pred` is not a Variable, Tensor or bool, or Python + integer 1 or 0. + """ + # Allow integer booleans. + if isinstance(pred, int): + if pred == 1: + pred = True + elif pred == 0: + pred = False + + if isinstance(pred, variables.Variable): + return None + return smart_module.smart_constant_value(pred) + + +def is_tensor_or_tensor_list(v): + v = nest.flatten(v) + if v and isinstance(v[0], ops.Tensor): + return True + else: + return False + + +def get_reachable_from_inputs(inputs, targets=None): + """Returns the set of tensors/ops reachable from `inputs`. + + Stops if all targets have been found (target is optional). + + Only valid in Symbolic mode, not Eager mode. + + Args: + inputs: List of tensors. + targets: List of tensors. + + Returns: + A set of tensors reachable from the inputs (includes the inputs themselves). + """ + reachable = set(inputs) + if targets: + targets = set(targets) + queue = inputs[:] + + while queue: + x = queue.pop() + if isinstance(x, ops.Operation): + outputs = x.outputs[:] or [] + outputs += x._control_outputs # pylint: disable=protected-access + elif isinstance(x, ops.Tensor): + outputs = x.consumers() + elif isinstance(x, variables.Variable): + outputs = [x.op] + else: + raise TypeError('Expected Operation, Variable, or Tensor, got ' + str(x)) + + for y in outputs: + if y not in reachable: + reachable.add(y) + queue.insert(0, y) + + if targets and targets.issubset(reachable): + return reachable + return reachable + + +def shape_type_conversion(fn): + """Decorator that handles tuple/TensorShape conversion. + + Used in `compute_output_shape` and `build`. + + Arguments: + fn: function to wrap. + + Returns: + Wrapped function. + """ + + def wrapper(instance, input_shape): + if input_shape is not None: + if isinstance(input_shape, list): + input_shape = [ + tuple(tensor_shape.TensorShape(x).as_list()) for x in input_shape] + else: + input_shape = tuple(tensor_shape.TensorShape(input_shape).as_list()) + output_shape = fn(instance, input_shape) + if output_shape is not None: + if isinstance(output_shape, list): + return [tensor_shape.TensorShape(x) for x in output_shape] + return tensor_shape.TensorShape(output_shape) + + return wrapper diff --git a/tensorflow/python/keras/utils/vis_utils.py b/tensorflow/python/keras/utils/vis_utils.py new file mode 100644 index 0000000000..8007df4622 --- /dev/null +++ b/tensorflow/python/keras/utils/vis_utils.py @@ -0,0 +1,155 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
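To make the static-dispatch behaviour of `smart_cond` and `constant_value` concrete, a small graph-mode sketch (the tensor values are arbitrary):

```python
from tensorflow.python.framework import constant_op
from tensorflow.python.keras.utils import tf_utils

two = constant_op.constant(2.0)

# With a Python bool predicate only the chosen branch is ever built.
result = tf_utils.smart_cond(True,
                             true_fn=lambda: two * two,
                             false_fn=lambda: two - two)

# constant_value reports the statically known truth value (None if dynamic).
print(tf_utils.constant_value(1))                           # True (integer 1 accepted)
print(tf_utils.constant_value(False))                       # False
print(tf_utils.constant_value(constant_op.constant(True)))  # True (constant tensor)
```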
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# pylint: disable=protected-access +# pylint: disable=g-import-not-at-top +"""Utilities related to model visualization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +from tensorflow.python.util.tf_export import tf_export + + +try: + # pydot-ng is a fork of pydot that is better maintained. + import pydot_ng as pydot +except ImportError: + # pydotplus is an improved version of pydot + try: + import pydotplus as pydot + except ImportError: + # Fall back on pydot if necessary. + try: + import pydot + except ImportError: + pydot = None + + +def _check_pydot(): + try: + # Attempt to create an image of a blank graph + # to check the pydot/graphviz installation. + pydot.Dot.create(pydot.Dot()) + except Exception: + # pydot raises a generic Exception here, + # so no specific class can be caught. + raise ImportError('Failed to import pydot. You must install pydot' + ' and graphviz for `pydotprint` to work.') + + +def model_to_dot(model, show_shapes=False, show_layer_names=True, rankdir='TB'): + """Convert a Keras model to dot format. + + Arguments: + model: A Keras model instance. + show_shapes: whether to display shape information. + show_layer_names: whether to display layer names. + rankdir: `rankdir` argument passed to PyDot, + a string specifying the format of the plot: + 'TB' creates a vertical plot; + 'LR' creates a horizontal plot. + + Returns: + A `pydot.Dot` instance representing the Keras model. + """ + from tensorflow.python.keras.layers.wrappers import Wrapper + from tensorflow.python.keras.models import Sequential + + _check_pydot() + dot = pydot.Dot() + dot.set('rankdir', rankdir) + dot.set('concentrate', True) + dot.set_node_defaults(shape='record') + + if isinstance(model, Sequential): + if not model.built: + model.build() + model = model.model + layers = model.layers + + # Create graph nodes. + for layer in layers: + layer_id = str(id(layer)) + + # Append a wrapped layer's label to node's label, if it exists. + layer_name = layer.name + class_name = layer.__class__.__name__ + if isinstance(layer, Wrapper): + layer_name = '{}({})'.format(layer_name, layer.layer.name) + child_class_name = layer.layer.__class__.__name__ + class_name = '{}({})'.format(class_name, child_class_name) + + # Create node's label. + if show_layer_names: + label = '{}: {}'.format(layer_name, class_name) + else: + label = class_name + + # Rebuild the label as a table including input/output shapes. + if show_shapes: + try: + outputlabels = str(layer.output_shape) + except AttributeError: + outputlabels = 'multiple' + if hasattr(layer, 'input_shape'): + inputlabels = str(layer.input_shape) + elif hasattr(layer, 'input_shapes'): + inputlabels = ', '.join([str(ishape) for ishape in layer.input_shapes]) + else: + inputlabels = 'multiple' + label = '%s\n|{input:|output:}|{{%s}|{%s}}' % (label, inputlabels, + outputlabels) + node = pydot.Node(layer_id, label=label) + dot.add_node(node) + + # Connect nodes with edges. 
+ for layer in layers: + layer_id = str(id(layer)) + for i, node in enumerate(layer._inbound_nodes): + node_key = layer.name + '_ib-' + str(i) + if node_key in model._network_nodes: # pylint: disable=protected-access + for inbound_layer in node.inbound_layers: + inbound_layer_id = str(id(inbound_layer)) + layer_id = str(id(layer)) + dot.add_edge(pydot.Edge(inbound_layer_id, layer_id)) + return dot + + +@tf_export('keras.utils.plot_model') +def plot_model(model, + to_file='model.png', + show_shapes=False, + show_layer_names=True, + rankdir='TB'): + """Converts a Keras model to dot format and save to a file. + + Arguments: + model: A Keras model instance + to_file: File name of the plot image. + show_shapes: whether to display shape information. + show_layer_names: whether to display layer names. + rankdir: `rankdir` argument passed to PyDot, + a string specifying the format of the plot: + 'TB' creates a vertical plot; + 'LR' creates a horizontal plot. + """ + dot = model_to_dot(model, show_shapes, show_layer_names, rankdir) + _, extension = os.path.splitext(to_file) + if not extension: + extension = 'png' + else: + extension = extension[1:] + dot.write(to_file, format=extension) |
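Finally, a typical call into the visualization helpers added above; this sketch assumes pydot (or pydot-ng / pydotplus) and graphviz are installed, and the output file name is arbitrary.

```python
from tensorflow.python import keras
from tensorflow.python.keras.utils import vis_utils

model = keras.models.Sequential()
model.add(keras.layers.Dense(16, input_shape=(8,), activation='relu'))
model.add(keras.layers.Dense(1, activation='sigmoid'))

# Write a PNG; the extension of `to_file` selects the output format.
vis_utils.plot_model(model, to_file='model.png', show_shapes=True)

# Or work with the pydot graph directly, e.g. for a horizontal layout.
dot = vis_utils.model_to_dot(model, show_shapes=True, rankdir='LR')
print(dot.to_string()[:200])
```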