author     Francois Chollet <fchollet@google.com>           2018-03-09 14:40:18 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-03-09 14:44:24 -0800
commit     8044288df687b07004624275295b93dca07b267b (patch)
tree       132c842e882a4e01fefb5f26e4ecadb072fd1d5b
parent     4ffc1043866d688023ed2942bb8b02e803c42891 (diff)
Part of the update of tf.keras to the Keras 2.1.5 API.
PiperOrigin-RevId: 188540513
-rw-r--r--  tensorflow/python/keras/_impl/keras/__init__.py                  |   2
-rw-r--r--  tensorflow/python/keras/_impl/keras/backend.py                   |   9
-rw-r--r--  tensorflow/python/keras/_impl/keras/engine/network.py            |  15
-rw-r--r--  tensorflow/python/keras/_impl/keras/engine/saving.py             | 243
-rw-r--r--  tensorflow/python/keras/_impl/keras/engine/saving_test.py        |  86
-rw-r--r--  tensorflow/python/keras/_impl/keras/engine/training.py           |  51
-rw-r--r--  tensorflow/python/keras/_impl/keras/engine/training_arrays.py    |  11
-rw-r--r--  tensorflow/python/keras/_impl/keras/engine/training_generator.py |  75
-rw-r--r--  tensorflow/python/keras/_impl/keras/engine/training_test.py      |  83
-rw-r--r--  tensorflow/python/keras/_impl/keras/optimizers.py                |  24
-rw-r--r--  tensorflow/python/keras/_impl/keras/utils/data_utils.py          |  23
11 files changed, 479 insertions, 143 deletions
diff --git a/tensorflow/python/keras/_impl/keras/__init__.py b/tensorflow/python/keras/_impl/keras/__init__.py
index b63907b2e6..53f5d31e9c 100644
--- a/tensorflow/python/keras/_impl/keras/__init__.py
+++ b/tensorflow/python/keras/_impl/keras/__init__.py
@@ -40,4 +40,4 @@ from tensorflow.python.keras._impl.keras.layers import Input
from tensorflow.python.keras._impl.keras.models import Model
from tensorflow.python.keras._impl.keras.models import Sequential
-__version__ = '2.1.4-tf'
+__version__ = '2.1.5-tf'
diff --git a/tensorflow/python/keras/_impl/keras/backend.py b/tensorflow/python/keras/_impl/keras/backend.py
index 688dc070e6..04866fbe0f 100644
--- a/tensorflow/python/keras/_impl/keras/backend.py
+++ b/tensorflow/python/keras/_impl/keras/backend.py
@@ -423,8 +423,9 @@ def get_session():
A TensorFlow session.
"""
global _SESSION
- if ops.get_default_session() is not None:
- session = ops.get_default_session()
+ default_session = ops.get_default_session()
+ if default_session is not None:
+ session = default_session
else:
if _SESSION is None:
if not os.environ.get('OMP_NUM_THREADS'):
@@ -495,7 +496,7 @@ def _is_current_explicit_device(device_type):
"""
device_type = device_type.upper()
if device_type not in ['CPU', 'GPU']:
- raise ValueError('device_type should be either "CPU" or "GPU".')
+ raise ValueError('`device_type` should be either "CPU" or "GPU".')
device = _get_current_tf_device()
return device is not None and device.device_type == device_type.upper()
@@ -3514,7 +3515,7 @@ def l2_normalize(x, axis=None):
Returns:
A tensor.
"""
- return nn.l2_normalize(x, dim=axis)
+ return nn.l2_normalize(x, axis=axis)
@tf_export('keras.backend.in_top_k')
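
The backend.py changes above cache the default session instead of calling `ops.get_default_session()` twice, quote `device_type` in the error message, and swap the deprecated `dim` keyword for `axis` in `nn.l2_normalize`. For reference, a minimal NumPy sketch of what `l2_normalize` computes, assuming the standard TensorFlow definition `x / sqrt(max(sum(x**2), epsilon))` (the helper name is illustrative, not part of the patch):

```
import numpy as np

def l2_normalize_ref(x, axis=None, epsilon=1e-12):
    # Divide by the L2 norm along `axis`, clamped below by sqrt(epsilon)
    # so all-zero rows do not cause a division by zero.
    square_sum = np.sum(np.square(x), axis=axis, keepdims=True)
    return x / np.sqrt(np.maximum(square_sum, epsilon))

print(l2_normalize_ref(np.array([[3.0, 4.0]]), axis=1))  # [[0.6 0.8]]
```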
diff --git a/tensorflow/python/keras/_impl/keras/engine/network.py b/tensorflow/python/keras/_impl/keras/engine/network.py
index bde16cdeb0..bf82390438 100644
--- a/tensorflow/python/keras/_impl/keras/engine/network.py
+++ b/tensorflow/python/keras/_impl/keras/engine/network.py
@@ -406,6 +406,7 @@ class Network(base_layer.Layer):
def get_layer(self, name=None, index=None):
"""Retrieves a layer based on either its name (unique) or index.
+ If `name` and `index` are both provided, `index` will take precedence.
Indices are based on order of horizontal graph traversal (bottom-up).
Arguments:
@@ -437,7 +438,7 @@ class Network(base_layer.Layer):
@property
def updates(self):
- """Retrieve the network's updates.
+ """Retrieves the network's updates.
Will only include updates that are either
unconditional, or conditional on inputs to this model
@@ -517,7 +518,7 @@ class Network(base_layer.Layer):
@property
def losses(self):
- """Retrieve the network's losses.
+ """Retrieves the network's losses.
Will only include losses that are either
unconditional, or conditional on inputs to this model
@@ -600,7 +601,7 @@ class Network(base_layer.Layer):
return specs
def call(self, inputs, training=None, mask=None):
- """Call the model on new inputs.
+ """Calls the model on new inputs.
In this case `call` just reapplies
all ops in the graph to the new inputs
@@ -1030,7 +1031,7 @@ class Network(base_layer.Layer):
layer(input_tensors, **kwargs)
def process_layer(layer_data):
- """Deserialize a layer, then call it on appropriate inputs.
+ """Deserializes a layer, then call it on appropriate inputs.
Arguments:
layer_data: layer config dict.
@@ -1087,7 +1088,7 @@ class Network(base_layer.Layer):
return cls(inputs=input_tensors, outputs=output_tensors, name=name)
def save(self, filepath, overwrite=True, include_optimizer=True):
- """Save the model to a single HDF5 file.
+ """Saves the model to a single HDF5 file.
The savefile includes:
- The model architecture, allowing to re-instantiate the model.
@@ -1193,7 +1194,7 @@ class Network(base_layer.Layer):
saving.load_weights_from_hdf5_group(f, self.layers)
def _updated_config(self):
- """Util hared between different serialization methods.
+ """Util shared between different serialization methods.
Returns:
Model config with Keras version information added.
@@ -1333,7 +1334,7 @@ def _make_node_key(layer_name, node_index):
def _map_graph_network(inputs, outputs):
- """Validate a network's topology and gather its layers and nodes.
+ """Validates a network's topology and gather its layers and nodes.
Arguments:
inputs: List of input tensors.
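
Most of the network.py hunks are docstring wording fixes; the one substantive documentation addition is that `get_layer` gives `index` precedence over `name` when both are supplied. A quick sketch of that behavior (layer names and shapes here are illustrative, not from the patch):

```
from tensorflow.python.keras._impl import keras

x = keras.Input(shape=(4,))
y = keras.layers.Dense(2, name='a')(x)
y = keras.layers.Dense(2, name='b')(y)
model = keras.Model(x, y)

# `index` wins over `name`; indices follow bottom-up graph traversal,
# so index 2 is the Dense layer named 'b', not 'a'.
print(model.get_layer(name='a', index=2).name)  # 'b'
```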
diff --git a/tensorflow/python/keras/_impl/keras/engine/saving.py b/tensorflow/python/keras/_impl/keras/engine/saving.py
index 52522e6935..2ad06ca4fd 100644
--- a/tensorflow/python/keras/_impl/keras/engine/saving.py
+++ b/tensorflow/python/keras/_impl/keras/engine/saving.py
@@ -35,6 +35,7 @@ from tensorflow.python.util.tf_export import tf_export
# pylint: disable=g-import-not-at-top
try:
import h5py
+ HDF5_OBJECT_HEADER_LIMIT = 64512
except ImportError:
h5py = None
@@ -47,7 +48,7 @@ except ImportError:
@tf_export('keras.models.save_model')
def save_model(model, filepath, overwrite=True, include_optimizer=True):
- """Save a model to a HDF5 file.
+ """Saves a model to a HDF5 file.
The saved model contains:
- the model's configuration (topology)
@@ -74,7 +75,7 @@ def save_model(model, filepath, overwrite=True, include_optimizer=True):
raise ImportError('`save_model` requires h5py.')
def get_json_type(obj):
- """Serialize any object to a JSON-serializable structure.
+ """Serializes any object to a JSON-serializable structure.
Arguments:
obj: the object to serialize
@@ -358,34 +359,6 @@ def model_from_json(json_string, custom_objects=None):
return deserialize(config, custom_objects=custom_objects)
-def save_weights_to_hdf5_group(f, layers):
- from tensorflow.python.keras._impl.keras import __version__ as keras_version # pylint: disable=g-import-not-at-top
-
- f.attrs['layer_names'] = [layer.name.encode('utf8') for layer in layers]
- f.attrs['backend'] = K.backend().encode('utf8')
- f.attrs['keras_version'] = str(keras_version).encode('utf8')
-
- for layer in layers:
- g = f.create_group(layer.name)
- symbolic_weights = layer.weights
- weight_values = K.batch_get_value(symbolic_weights)
- weight_names = []
- for i, (w, val) in enumerate(zip(symbolic_weights, weight_values)):
- if hasattr(w, 'name') and w.name:
- name = str(w.name)
- else:
- name = 'param_' + str(i)
- weight_names.append(name.encode('utf8'))
- g.attrs['weight_names'] = weight_names
- for name, val in zip(weight_names, weight_values):
- param_dset = g.create_dataset(name, val.shape, dtype=val.dtype)
- if not val.shape:
- # scalar
- param_dset[()] = val
- else:
- param_dset[:] = val
-
-
def preprocess_weights_for_loading(layer,
weights,
original_keras_version=None,
@@ -549,9 +522,140 @@ def preprocess_weights_for_loading(layer,
# split the bias into half and merge
weights[2] = bias[:units * 4] + bias[units * 4:]
+ return convert_rnn_weights(layer, weights)
+
+
+def convert_rnn_weights(layer, weights):
+ """Converts weights for RNN layers between native and CuDNN format.
+
+ Input kernels for each gate are transposed and converted between Fortran
+ and C layout, recurrent kernels are transposed. For LSTM biases are summed/
+ split in half, for GRU biases are reshaped.
+
+ Weights can be converted in both directions between `LSTM` and `CuDNNLSTM`
+ and between `CuDNNGRU` and `GRU(reset_after=True)`. Default `GRU` is not
+ compatible with `CuDNNGRU`.
+
+ For missing biases in `LSTM`/`GRU` (`use_bias=False`) no conversion is made.
+
+ Arguments:
+ layer: Target layer instance.
+ weights: List of source weights values (input kernels, recurrent
+ kernels, [biases]) (Numpy arrays).
+
+ Returns:
+ A list of converted weights values (Numpy arrays).
+
+ Raises:
+ ValueError: for incompatible GRU layer/weights or incompatible biases
+ """
+
+ def transform_kernels(kernels, func, n_gates):
+ """Transforms kernel for each gate separately using given function.
+
+ Arguments:
+ kernels: Stacked array of kernels for individual gates.
+ func: Function applied to kernel of each gate.
+ n_gates: Number of gates (4 for LSTM, 3 for GRU).
+ Returns:
+ Stacked array of transformed kernels.
+ """
+ return np.hstack([func(k) for k in np.hsplit(kernels, n_gates)])
+
+ def transpose_input(from_cudnn):
+ """Makes a function that transforms input kernels from/to CuDNN format.
+
+ It keeps the shape, but changes between the layout (Fortran/C). Eg.:
+
+ ```
+ Keras CuDNN
+ [[0, 1, 2], <---> [[0, 2, 4],
+ [3, 4, 5]] [1, 3, 5]]
+ ```
+
+ It can be passed to `transform_kernels()`.
+
+ Arguments:
+ from_cudnn: `True` if source weights are in CuDNN format, `False`
+ if they're in plain Keras format.
+ Returns:
+ Function that converts input kernel to the other format.
+ """
+ order = 'F' if from_cudnn else 'C'
+
+ def transform(kernel):
+ return kernel.T.reshape(kernel.shape, order=order)
+
+ return transform
+
+ target_class = layer.__class__.__name__
+
+ # convert the weights between CuDNNLSTM and LSTM
+ if target_class in ['LSTM', 'CuDNNLSTM'] and len(weights) == 3:
+ # determine if we're loading a CuDNNLSTM layer
+ # from the number of bias weights:
+ # CuDNNLSTM has (units * 8) weights; while LSTM has (units * 4)
+ # if there's no bias weight in the file, skip this conversion
+ units = weights[1].shape[0]
+ bias_shape = weights[2].shape
+ n_gates = 4
+
+ if bias_shape == (2 * units * n_gates,):
+ source = 'CuDNNLSTM'
+ elif bias_shape == (units * n_gates,):
+ source = 'LSTM'
+ else:
+ raise ValueError('Invalid bias shape: ' + str(bias_shape))
+
+ def convert_lstm_weights(weights, from_cudnn=True):
+ # Transpose (and reshape) input and recurrent kernels.
+ kernels = transform_kernels(weights[0], transpose_input(from_cudnn),
+ n_gates)
+ recurrent_kernels = transform_kernels(weights[1], lambda k: k.T, n_gates)
+ if from_cudnn: # Merge input and recurrent biases into a single set.
+ biases = np.sum(np.split(weights[2], 2, axis=0), axis=0)
+ else:
+ # Split single set of biases evenly to two sets.
+ biases = np.tile(0.5 * weights[2], 2)
+ return [kernels, recurrent_kernels, biases]
+
+ if source != target_class:
+ weights = convert_lstm_weights(weights, from_cudnn=source == 'CuDNNLSTM')
+
+ # TODO(fchollet): add feature after GRU is refactored:
+ # convert the weights between `CuDNNGRU` and `GRU(reset_after=True)`
return weights
+def save_weights_to_hdf5_group(f, layers):
+ from tensorflow.python.keras._impl.keras import __version__ as keras_version # pylint: disable=g-import-not-at-top
+
+ save_attributes_to_hdf5_group(
+ f, 'layer_names', [layer.name.encode('utf8') for layer in layers])
+ f.attrs['backend'] = K.backend().encode('utf8')
+ f.attrs['keras_version'] = str(keras_version).encode('utf8')
+
+ for layer in layers:
+ g = f.create_group(layer.name)
+ symbolic_weights = layer.weights
+ weight_values = K.batch_get_value(symbolic_weights)
+ weight_names = []
+ for i, (w, val) in enumerate(zip(symbolic_weights, weight_values)):
+ if hasattr(w, 'name') and w.name:
+ name = str(w.name)
+ else:
+ name = 'param_' + str(i)
+ weight_names.append(name.encode('utf8'))
+ save_attributes_to_hdf5_group(g, 'weight_names', weight_names)
+ for name, val in zip(weight_names, weight_values):
+ param_dset = g.create_dataset(name, val.shape, dtype=val.dtype)
+ if not val.shape:
+ # scalar
+ param_dset[()] = val
+ else:
+ param_dset[:] = val
+
+
def load_weights_from_hdf5_group(f, layers):
"""Implements topological (order-based) weight loading.
@@ -578,11 +682,11 @@ def load_weights_from_hdf5_group(f, layers):
if weights:
filtered_layers.append(layer)
- layer_names = [n.decode('utf8') for n in f.attrs['layer_names']]
+ layer_names = load_attributes_from_hdf5_group(f, 'layer_names')
filtered_layer_names = []
for name in layer_names:
g = f[name]
- weight_names = [n.decode('utf8') for n in g.attrs['weight_names']]
+ weight_names = load_attributes_from_hdf5_group(g, 'weight_names')
if weight_names:
filtered_layer_names.append(name)
layer_names = filtered_layer_names
@@ -597,7 +701,7 @@ def load_weights_from_hdf5_group(f, layers):
weight_value_tuples = []
for k, name in enumerate(layer_names):
g = f[name]
- weight_names = [n.decode('utf8') for n in g.attrs['weight_names']]
+ weight_names = load_attributes_from_hdf5_group(g, 'weight_names')
weight_values = [g[weight_name] for weight_name in weight_names]
layer = filtered_layers[k]
symbolic_weights = layer.weights
@@ -640,7 +744,7 @@ def load_weights_from_hdf5_group_by_name(f, layers):
original_backend = None
# New file format.
- layer_names = [n.decode('utf8') for n in f.attrs['layer_names']]
+ layer_names = load_attributes_from_hdf5_group(f, 'layer_names')
# Reverse index of layer name to list of layers with name.
index = {}
@@ -653,7 +757,7 @@ def load_weights_from_hdf5_group_by_name(f, layers):
weight_value_tuples = []
for k, name in enumerate(layer_names):
g = f[name]
- weight_names = [n.decode('utf8') for n in g.attrs['weight_names']]
+ weight_names = load_attributes_from_hdf5_group(g, 'weight_names')
weight_values = [g[weight_name] for weight_name in weight_names]
for layer in index.get(name, []):
@@ -669,3 +773,72 @@ def load_weights_from_hdf5_group_by_name(f, layers):
for i in range(len(weight_values)):
weight_value_tuples.append((symbolic_weights[i], weight_values[i]))
K.batch_set_value(weight_value_tuples)
+
+
+def save_attributes_to_hdf5_group(group, name, data):
+ """Saves attributes (data) of the specified name into the HDF5 group.
+
+ This method deals with an inherent problem of HDF5 files, which are not
+ able to store data larger than HDF5_OBJECT_HEADER_LIMIT bytes.
+
+ Arguments:
+ group: A pointer to a HDF5 group.
+ name: A name of the attributes to save.
+ data: Attributes data to store.
+
+ Raises:
+ RuntimeError: If any single attribute is too large to be saved.
+ """
+ # Check that no item in `data` is larger than `HDF5_OBJECT_HEADER_LIMIT`
+ # because in that case even chunking the array would not make the saving
+ # possible.
+ bad_attributes = [x for x in data if len(x) > HDF5_OBJECT_HEADER_LIMIT]
+
+ # Expecting this to never be true.
+ if bad_attributes:
+ raise RuntimeError('The following attributes cannot be saved to HDF5 '
+ 'file because they are larger than %d bytes: %s' %
+ (HDF5_OBJECT_HEADER_LIMIT,
+ ', '.join([x for x in bad_attributes])))
+
+ data_npy = np.asarray(data)
+
+ num_chunks = 1
+ chunked_data = np.array_split(data_npy, num_chunks)
+
+ # This will never loop forever thanks to the test above.
+ while any([x.nbytes > HDF5_OBJECT_HEADER_LIMIT for x in chunked_data]):
+ num_chunks += 1
+ chunked_data = np.array_split(data_npy, num_chunks)
+
+ if num_chunks > 1:
+ for chunk_id, chunk_data in enumerate(chunked_data):
+ group.attrs['%s%d' % (name, chunk_id)] = chunk_data
+ else:
+ group.attrs[name] = data
+
+
+def load_attributes_from_hdf5_group(group, name):
+ """Loads attributes of the specified name from the HDF5 group.
+
+ This method deals with an inherent problem
+ of HDF5 files, which are not able to store
+ data larger than HDF5_OBJECT_HEADER_LIMIT bytes.
+
+ Arguments:
+ group: A pointer to a HDF5 group.
+ name: A name of the attributes to load.
+
+ Returns:
+ data: Attributes data.
+ """
+ if name in group.attrs:
+ data = [n.decode('utf8') for n in group.attrs[name]]
+ else:
+ data = []
+ chunk_id = 0
+ while '%s%d' % (name, chunk_id) in group.attrs:
+ data.extend(
+ [n.decode('utf8') for n in group.attrs['%s%d' % (name, chunk_id)]])
+ chunk_id += 1
+ return data
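
The `save_attributes_to_hdf5_group`/`load_attributes_from_hdf5_group` pair added above works around the HDF5 object-header size limit by splitting an oversized attribute list across numbered attributes (`layer_names0`, `layer_names1`, ...). A self-contained sketch of the same round trip, assuming an in-memory h5py file (the file name, helper names, and test data below are illustrative, not part of the patch):

```
import h5py
import numpy as np

HEADER_LIMIT = 64512  # same per-attribute budget the patch uses

def save_chunked(group, name, data):
    arr = np.asarray(data)
    num_chunks = 1
    chunked = np.array_split(arr, num_chunks)
    # Grow the number of chunks until every chunk fits the header budget
    # (assumes no single item exceeds the budget, as the patch checks).
    while any(chunk.nbytes > HEADER_LIMIT for chunk in chunked):
        num_chunks += 1
        chunked = np.array_split(arr, num_chunks)
    if num_chunks > 1:
        for chunk_id, chunk in enumerate(chunked):
            group.attrs['%s%d' % (name, chunk_id)] = chunk
    else:
        group.attrs[name] = data

def load_chunked(group, name):
    if name in group.attrs:
        return [n.decode('utf8') for n in group.attrs[name]]
    data, chunk_id = [], 0
    while '%s%d' % (name, chunk_id) in group.attrs:
        data.extend(n.decode('utf8')
                    for n in group.attrs['%s%d' % (name, chunk_id)])
        chunk_id += 1
    return data

# A few thousand long names exceed 64512 bytes and force the chunked path.
names = [('layer_' + 'x' * 30 + str(i)).encode('utf8') for i in range(3000)]
with h5py.File('demo.h5', 'w', driver='core', backing_store=False) as f:
    save_chunked(f, 'layer_names', names)
    assert load_chunked(f, 'layer_names') == [n.decode('utf8') for n in names]
```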
diff --git a/tensorflow/python/keras/_impl/keras/engine/saving_test.py b/tensorflow/python/keras/_impl/keras/engine/saving_test.py
index bdb17641b0..4a18cc2e11 100644
--- a/tensorflow/python/keras/_impl/keras/engine/saving_test.py
+++ b/tensorflow/python/keras/_impl/keras/engine/saving_test.py
@@ -370,6 +370,92 @@ class TestWholeModelSaving(test.TestCase):
self.assertAllClose(mean, model.layers[1].arguments['mu'])
self.assertAllClose(std, model.layers[1].arguments['std'])
+ def test_saving_model_with_long_layer_names(self):
+ if h5py is None:
+ return # Skip test if models cannot be saved.
+
+ with self.test_session():
+ # This layer name will make the `layers_name` HDF5 attribute blow
+ # out of proportion. Note that it fits into the internal HDF5
+ # attribute memory limit on its own but because h5py converts
+ # the list of layer names into numpy array, which uses the same
+ amount of memory for every item, it increases the memory
+ # requirements substantially.
+ x = keras.Input(shape=(2,), name='input_' + ('x' * (2**15)))
+ f = x
+ for i in range(4):
+ f = keras.layers.Dense(2, name='dense_%d' % (i,))(f)
+ model = keras.Model(inputs=[x], outputs=[f])
+ model.compile(loss='mse', optimizer='adam', metrics=['acc'])
+
+ x = np.random.random((1, 2))
+ y = np.random.random((1, 2))
+ model.train_on_batch(x, y)
+ out = model.predict(x)
+
+ fd, fname = tempfile.mkstemp('.h5')
+ keras.models.save_model(model, fname)
+ model = keras.models.load_model(fname)
+
+ # Check that the HDF5 file contains a chunked array
+ # of layer names.
+ with h5py.File(fname, 'r') as h5file:
+ num_names_arrays = len([attr for attr in h5file['model_weights'].attrs
+ if attr.startswith('layer_names')])
+ # The chunking of the layer names array should have happened.
+ self.assertGreater(num_names_arrays, 0)
+ out2 = model.predict(x)
+ self.assertAllClose(out, out2, atol=1e-05)
+
+ # Cleanup
+ os.close(fd)
+ os.remove(fname)
+
+ def test_saving_model_with_long_weights_names(self):
+ if h5py is None:
+ return # Skip test if models cannot be saved.
+
+ with self.test_session():
+ x = keras.Input(shape=(2,), name='nested_model_input')
+ f = x
+ for i in range(4):
+ f = keras.layers.Dense(2, name='nested_model_dense_%d' % (i,))(f)
+ # This layer name will make the `weights_name`
+ # HDF5 attribute blow out of proportion.
+ f = keras.layers.Dense(2, name='nested_model_output' + ('x' * (2**15)))(f)
+ nested_model = keras.Model(inputs=[x], outputs=[f], name='nested_model')
+
+ x = keras.Input(shape=(2,), name='outer_model_input')
+ f = nested_model(x)
+ f = keras.layers.Dense(2, name='outer_model_output')(f)
+
+ model = keras.Model(inputs=[x], outputs=[f])
+ model.compile(loss='mse', optimizer='adam', metrics=['acc'])
+
+ x = np.random.random((1, 2))
+ y = np.random.random((1, 2))
+ model.train_on_batch(x, y)
+ out = model.predict(x)
+
+ fd, fname = tempfile.mkstemp('.h5')
+ keras.models.save_model(model, fname)
+ model = keras.models.load_model(fname)
+
+ # Check that the HDF5 file contains a chunked array
+ # of weight names.
+ with h5py.File(fname, 'r') as h5file:
+ num_weight_arrays = len(
+ [attr for attr in h5file['model_weights']['nested_model'].attrs
+ if attr.startswith('weight_names')])
+ # The chunking of the weight names array should have happened.
+ self.assertGreater(num_weight_arrays, 0)
+ out2 = model.predict(x)
+ self.assertAllClose(out, out2, atol=1e-05)
+
+ # Cleanup
+ os.close(fd)
+ os.remove(fname)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/keras/_impl/keras/engine/training.py b/tensorflow/python/keras/_impl/keras/engine/training.py
index 8b82c0b313..57506f9aff 100644
--- a/tensorflow/python/keras/_impl/keras/engine/training.py
+++ b/tensorflow/python/keras/_impl/keras/engine/training.py
@@ -1542,20 +1542,19 @@ class Model(Network):
max_queue_size: Integer. Maximum size for the generator queue.
If unspecified, `max_queue_size` will default to 10.
workers: Integer. Maximum number of processes to spin up
- when using process based threading.
+ when using process-based threading.
If unspecified, `workers` will default to 1. If 0, will
execute the generator on the main thread.
- use_multiprocessing: Boolean. If True, use process based threading.
- If unspecified, `workers` will default to False.
- Note that because
- this implementation relies on multiprocessing,
- you should not pass
- non picklable arguments to the generator
- as they can't be passed
- easily to children processes.
- shuffle: Whether to shuffle the order of the batches at
+ use_multiprocessing: Boolean.
+ If `True`, use process-based threading.
+ If unspecified, `use_multiprocessing` will default to `False`.
+ Note that because this implementation relies on multiprocessing,
+ you should not pass non-picklable arguments to the generator
+ as they can't be passed easily to children processes.
+ shuffle: Boolean. Whether to shuffle the order of the batches at
the beginning of each epoch. Only used with instances
- of `Sequence` (keras.utils.Sequence).
+ of `Sequence` (`keras.utils.Sequence`).
+ Has no effect when `steps_per_epoch` is not `None`.
initial_epoch: Epoch at which to start training
(useful for resuming a previous training run)
@@ -1625,16 +1624,15 @@ class Model(Network):
the `len(generator)` as a number of steps.
max_queue_size: maximum size for the generator queue
workers: Integer. Maximum number of processes to spin up
- when using process based threading.
+ when using process-based threading.
If unspecified, `workers` will default to 1. If 0, will
execute the generator on the main thread.
- use_multiprocessing: if True, use process based threading.
- Note that because
- this implementation relies on multiprocessing,
- you should not pass
- non picklable arguments to the generator
- as they can't be passed
- easily to children processes.
+ use_multiprocessing: Boolean.
+ If `True`, use process-based threading.
+ If unspecified, `use_multiprocessing` will default to `False`.
+ Note that because this implementation relies on multiprocessing,
+ you should not pass non-picklable arguments to the generator
+ as they can't be passed easily to children processes.
Returns:
Scalar test loss (if the model has a single output and no metrics)
@@ -1684,16 +1682,15 @@ class Model(Network):
the `len(generator)` as a number of steps.
max_queue_size: Maximum size for the generator queue.
workers: Integer. Maximum number of processes to spin up
- when using process based threading.
+ when using process-based threading.
If unspecified, `workers` will default to 1. If 0, will
execute the generator on the main thread.
- use_multiprocessing: If `True`, use process based threading.
- Note that because
- this implementation relies on multiprocessing,
- you should not pass
- non picklable arguments to the generator
- as they can't be passed
- easily to children processes.
+ use_multiprocessing: Boolean.
+ If `True`, use process-based threading.
+ If unspecified, `use_multiprocessing` will default to `False`.
+ Note that because this implementation relies on multiprocessing,
+ you should not pass non-picklable arguments to the generator
+ as they can't be passed easily to children processes.
verbose: verbosity mode, 0 or 1.
Returns:
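
The training.py hunks are documentation-only: they tighten the descriptions of `workers`, `use_multiprocessing`, and `shuffle` for the generator-based methods. A hedged usage sketch of those arguments (the toy `Sequence` and model below are illustrative, not from the patch):

```
import numpy as np
from tensorflow.python.keras._impl import keras

class ToyBatches(keras.utils.Sequence):
  """A picklable Sequence, safe even with process-based workers."""

  def __len__(self):
    return 8

  def __getitem__(self, idx):
    return np.zeros((10, 2)), np.ones((10,))

model = keras.models.Sequential()
model.add(keras.layers.Dense(1, input_shape=(2,)))
model.compile(loss='mse', optimizer='sgd')

# `shuffle` only reorders Sequence batches; `workers=0` would run the
# generator on the main thread; `use_multiprocessing=False` uses threads.
model.fit_generator(ToyBatches(), steps_per_epoch=8, epochs=1,
                    workers=2, use_multiprocessing=False, shuffle=True)
```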
diff --git a/tensorflow/python/keras/_impl/keras/engine/training_arrays.py b/tensorflow/python/keras/_impl/keras/engine/training_arrays.py
index 9291ef5fe6..18116e3a14 100644
--- a/tensorflow/python/keras/_impl/keras/engine/training_arrays.py
+++ b/tensorflow/python/keras/_impl/keras/engine/training_arrays.py
@@ -298,20 +298,13 @@ def predict_loop(model, inputs, batch_size=32, verbose=0, steps=None):
else:
ins = inputs
- if hasattr(model, 'metrics'):
- for m in model.metrics:
- if isinstance(m, Layer):
- m.reset_states()
-
num_samples = training_utils.check_num_samples(
inputs, batch_size, steps, 'steps')
if verbose == 1:
if steps is not None:
- progbar = Progbar(target=steps,
- stateful_metrics=model.stateful_metric_names)
+ progbar = Progbar(target=steps)
else:
- progbar = Progbar(target=num_samples,
- stateful_metrics=model.stateful_metric_names)
+ progbar = Progbar(target=num_samples)
indices_for_conversion_to_dense = []
for i in range(len(model._feed_inputs)):
diff --git a/tensorflow/python/keras/_impl/keras/engine/training_generator.py b/tensorflow/python/keras/_impl/keras/engine/training_generator.py
index 4af62c85d5..58b5bc39c1 100644
--- a/tensorflow/python/keras/_impl/keras/engine/training_generator.py
+++ b/tensorflow/python/keras/_impl/keras/engine/training_generator.py
@@ -112,42 +112,25 @@ def fit_generator(model,
val_enqueuer = None
try:
- if do_validation:
- if val_gen:
- if workers > 0:
- if isinstance(validation_data, Sequence):
- val_enqueuer = OrderedEnqueuer(
- validation_data, use_multiprocessing=use_multiprocessing)
- if validation_steps is None:
- validation_steps = len(validation_data)
- else:
- val_enqueuer = GeneratorEnqueuer(
- validation_data,
- use_multiprocessing=use_multiprocessing,
- wait_time=wait_time)
- val_enqueuer.start(workers=workers, max_queue_size=max_queue_size)
- validation_generator = val_enqueuer.get()
- else:
- validation_generator = validation_data
+ if do_validation and not val_gen:
+ # Prepare data for validation
+ if len(validation_data) == 2:
+ val_x, val_y = validation_data # pylint: disable=unpacking-non-sequence
+ val_sample_weight = None
+ elif len(validation_data) == 3:
+ val_x, val_y, val_sample_weight = validation_data # pylint: disable=unpacking-non-sequence
else:
- if len(validation_data) == 2:
- val_x, val_y = validation_data # pylint: disable=unpacking-non-sequence
- val_sample_weight = None
- elif len(validation_data) == 3:
- val_x, val_y, val_sample_weight = validation_data # pylint: disable=unpacking-non-sequence
- else:
- raise ValueError(
- '`validation_data` should be a tuple '
- '`(val_x, val_y, val_sample_weight)` '
- 'or `(val_x, val_y)`. Found: ' + str(validation_data))
- val_x, val_y, val_sample_weights = model._standardize_user_data(
- val_x, val_y, val_sample_weight)
- val_data = val_x + val_y + val_sample_weights
- if model.uses_learning_phase and not isinstance(
- K.learning_phase(), int):
- val_data += [0]
- for cbk in callbacks:
- cbk.validation_data = val_data
+ raise ValueError(
+ '`validation_data` should be a tuple '
+ '`(val_x, val_y, val_sample_weight)` '
+ 'or `(val_x, val_y)`. Found: ' + str(validation_data))
+ val_x, val_y, val_sample_weights = model._standardize_user_data(
+ val_x, val_y, val_sample_weight)
+ val_data = val_x + val_y + val_sample_weights
+ if model.uses_learning_phase and not isinstance(K.learning_phase(), int):
+ val_data += [0.]
+ for cbk in callbacks:
+ cbk.validation_data = val_data
if workers > 0:
if is_sequence:
@@ -163,7 +146,10 @@ def fit_generator(model,
enqueuer.start(workers=workers, max_queue_size=max_queue_size)
output_generator = enqueuer.get()
else:
- output_generator = generator
+ if is_sequence:
+ output_generator = iter(generator)
+ else:
+ output_generator = generator
callback_model.stop_training = False
# Construct epoch logs.
@@ -218,7 +204,12 @@ def fit_generator(model,
if steps_done >= steps_per_epoch and do_validation:
if val_gen:
val_outs = evaluate_generator(
- model, validation_generator, validation_steps, workers=0)
+ model,
+ validation_data,
+ validation_steps,
+ workers=workers,
+ use_multiprocessing=use_multiprocessing,
+ max_queue_size=max_queue_size)
else:
# No need for try/except because
# data has already been validated.
@@ -297,7 +288,10 @@ def evaluate_generator(model,
enqueuer.start(workers=workers, max_queue_size=max_queue_size)
output_generator = enqueuer.get()
else:
- output_generator = generator
+ if is_sequence:
+ output_generator = iter(generator)
+ else:
+ output_generator = generator
while steps_done < steps:
generator_output = next(output_generator)
@@ -387,7 +381,10 @@ def predict_generator(model,
enqueuer.start(workers=workers, max_queue_size=max_queue_size)
output_generator = enqueuer.get()
else:
- output_generator = generator
+ if is_sequence:
+ output_generator = iter(generator)
+ else:
+ output_generator = generator
if verbose == 1:
progbar = Progbar(target=steps)
diff --git a/tensorflow/python/keras/_impl/keras/engine/training_test.py b/tensorflow/python/keras/_impl/keras/engine/training_test.py
index 38ba0f0eae..fd91dbba52 100644
--- a/tensorflow/python/keras/_impl/keras/engine/training_test.py
+++ b/tensorflow/python/keras/_impl/keras/engine/training_test.py
@@ -340,20 +340,21 @@ class TrainingTest(test.TestCase):
if scipy_sparse is None:
return
- test_inputs = [
- scipy_sparse.random(6, 3, density=0.25).tocsr() for _ in range(2)]
- test_outputs = [
- scipy_sparse.random(6, i, density=0.25).tocsr() for i in range(3, 5)]
- in1 = keras.layers.Input(shape=(3,))
- in2 = keras.layers.Input(shape=(3,))
- out1 = keras.layers.Dropout(0.5, name='dropout')(in1)
- out2 = keras.layers.Dense(4, name='dense_1')(in2)
- model = keras.Model([in1, in2], [out1, out2])
- model.predict(test_inputs, batch_size=2)
- model.compile('rmsprop', 'mse')
- model.fit(test_inputs, test_outputs,
- epochs=1, batch_size=2, validation_split=0.5)
- model.evaluate(test_inputs, test_outputs, batch_size=2)
+ with self.test_session():
+ test_inputs = [
+ scipy_sparse.random(6, 3, density=0.25).tocsr() for _ in range(2)]
+ test_outputs = [
+ scipy_sparse.random(6, i, density=0.25).tocsr() for i in range(3, 5)]
+ in1 = keras.layers.Input(shape=(3,))
+ in2 = keras.layers.Input(shape=(3,))
+ out1 = keras.layers.Dropout(0.5, name='dropout')(in1)
+ out2 = keras.layers.Dense(4, name='dense_1')(in2)
+ model = keras.Model([in1, in2], [out1, out2])
+ model.predict(test_inputs, batch_size=2)
+ model.compile('rmsprop', 'mse')
+ model.fit(test_inputs, test_outputs,
+ epochs=1, batch_size=2, validation_split=0.5)
+ model.evaluate(test_inputs, test_outputs, batch_size=2)
def test_that_trainable_disables_updates(self):
val_a = np.random.random((10, 4))
@@ -876,9 +877,9 @@ class TestGeneratorMethods(test.TestCase):
def custom_generator():
batch_size = 10
- n_samples = 50
+ num_samples = 50
while True:
- batch_index = np.random.randint(0, n_samples - batch_size)
+ batch_index = np.random.randint(0, num_samples - batch_size)
start = batch_index
end = start + batch_size
x = arr_data[start: end]
@@ -957,9 +958,9 @@ class TestGeneratorMethods(test.TestCase):
def custom_generator():
batch_size = 10
- n_samples = 50
+ num_samples = 50
while True:
- batch_index = np.random.randint(0, n_samples - batch_size)
+ batch_index = np.random.randint(0, num_samples - batch_size)
start = batch_index
end = start + batch_size
x = arr_data[start: end]
@@ -1033,6 +1034,52 @@ class TestGeneratorMethods(test.TestCase):
max_queue_size=10,
use_multiprocessing=False)
+ def test_training_with_sequences(self):
+
+ class DummySequence(keras.utils.Sequence):
+
+ def __getitem__(self, idx):
+ return np.zeros([10, 2]), np.ones([10])
+
+ def __len__(self):
+ return 10
+
+ arr_data = np.random.random((50, 2))
+ arr_labels = np.random.random((50,))
+ arr_sample_weights = np.random.random((50,))
+
+ def custom_generator():
+ batch_size = 10
+ num_samples = 50
+ while True:
+ batch_index = np.random.randint(0, num_samples - batch_size)
+ start = batch_index
+ end = start + batch_size
+ x = arr_data[start: end]
+ y = arr_labels[start: end]
+ w = arr_sample_weights[start: end]
+ yield x, y, w
+
+ with self.test_session():
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(1, input_shape=(2,)))
+ model.compile(loss='mse', optimizer='sgd')
+
+ model.fit_generator(DummySequence(),
+ steps_per_epoch=10,
+ validation_data=custom_generator(),
+ validation_steps=1,
+ max_queue_size=10,
+ workers=0,
+ use_multiprocessing=True)
+ model.fit_generator(DummySequence(),
+ steps_per_epoch=10,
+ validation_data=custom_generator(),
+ validation_steps=1,
+ max_queue_size=10,
+ workers=0,
+ use_multiprocessing=False)
+
class TestTrainingUtils(test.TestCase):
diff --git a/tensorflow/python/keras/_impl/keras/optimizers.py b/tensorflow/python/keras/_impl/keras/optimizers.py
index 6520128c5b..b715d722b9 100644
--- a/tensorflow/python/keras/_impl/keras/optimizers.py
+++ b/tensorflow/python/keras/_impl/keras/optimizers.py
@@ -95,7 +95,26 @@ class Optimizer(object):
raise NotImplementedError
def get_gradients(self, loss, params):
+ """Returns gradients of `loss` with respect to `params`.
+
+ Arguments:
+ loss: Loss tensor.
+ params: List of variables.
+
+ Returns:
+ List of gradient tensors.
+
+ Raises:
+ ValueError: In case any gradient cannot be computed (e.g. if gradient
+ function not implemented).
+ """
grads = K.gradients(loss, params)
+ if None in grads:
+ raise ValueError('An operation has `None` for gradient. '
+ 'Please make sure that all of your ops have a '
+ 'gradient defined (i.e. are differentiable). '
+ 'Common ops without gradient: '
+ 'K.argmax, K.round, K.eval.')
if hasattr(self, 'clipnorm') and self.clipnorm > 0:
norm = K.sqrt(sum([K.sum(K.square(g)) for g in grads]))
grads = [clip_norm(g, self.clipnorm, norm) for g in grads]
@@ -120,6 +139,11 @@ class Optimizer(object):
ValueError: in case of incompatible weight shapes.
"""
params = self.weights
+ if len(params) != len(weights):
+ raise ValueError(
+ 'Length of the specified weight list (' + str(len(weights)) +
+ ') does not match the number of weights '
+ 'of the optimizer (' + str(len(params)) + ')')
weight_value_tuples = []
param_values = K.batch_get_value(params)
for pv, p, w in zip(param_values, params, weights):
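
The optimizers.py changes turn two silent failure modes into explicit `ValueError`s: a `None` gradient produced by a non-differentiable op in `get_gradients`, and a weight list whose length does not match the optimizer's own weights in `set_weights`. A sketch of the first case, assuming TensorFlow yields a `None` gradient through `argmax` as the new error message suggests (the tensors below are illustrative, not from the patch):

```
from tensorflow.python.keras._impl import keras
from tensorflow.python.keras._impl.keras import backend as K

x = K.placeholder(shape=(None, 3))
w = K.variable([[1.], [2.], [3.]])
# argmax has no useful gradient, so d(loss)/dw comes back as None and
# the new check raises immediately instead of failing inside the update op.
loss = K.sum(K.cast(K.argmax(K.dot(x, w)), 'float32'))
try:
  keras.optimizers.SGD().get_gradients(loss, [w])
except ValueError as e:
  print(e)  # An operation has `None` for gradient. ...
```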
diff --git a/tensorflow/python/keras/_impl/keras/utils/data_utils.py b/tensorflow/python/keras/_impl/keras/utils/data_utils.py
index e87c8f48ef..4c49544c6a 100644
--- a/tensorflow/python/keras/_impl/keras/utils/data_utils.py
+++ b/tensorflow/python/keras/_impl/keras/utils/data_utils.py
@@ -393,6 +393,16 @@ class Sequence(object):
"""
pass
+ def __iter__(self):
+ """Creates an infinite generator that iterate over the Sequence.
+
+ Yields:
+ Sequence items.
+ """
+ while True:
+ for item in (self[i] for i in range(len(self))):
+ yield item
+
# Global variables to be shared across processes
_SHARED_SEQUENCES = {}
@@ -400,6 +410,11 @@ _SHARED_SEQUENCES = {}
_SEQUENCE_COUNTER = None
+def init_pool(seqs):
+ global _SHARED_SEQUENCES
+ _SHARED_SEQUENCES = seqs
+
+
def get_index(uid, i):
"""Get the value from the Sequence `uid` at index `i`.
@@ -532,9 +547,11 @@ class OrderedEnqueuer(SequenceEnqueuer):
(when full, workers could block on `put()`)
"""
if self.use_multiprocessing:
- self.executor_fn = lambda: multiprocessing.Pool(workers)
+ self.executor_fn = lambda seqs: multiprocessing.Pool( # pylint: disable=g-long-lambda
+ workers, initializer=init_pool, initargs=(seqs,))
else:
- self.executor_fn = lambda: ThreadPool(workers)
+ # We do not need the init since it's threads.
+ self.executor_fn = lambda _: ThreadPool(workers)
self.workers = workers
self.queue = queue.Queue(max_queue_size)
self.stop_signal = threading.Event()
@@ -557,7 +574,7 @@ class OrderedEnqueuer(SequenceEnqueuer):
if self.shuffle:
random.shuffle(sequence)
- with closing(self.executor_fn()) as executor:
+ with closing(self.executor_fn(_SHARED_SEQUENCES)) as executor:
for i in sequence:
if self.stop_signal.is_set():
return
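
Finally, the `Sequence.__iter__` method added near the top of data_utils.py is what lets the `workers=0` code paths in training_generator.py fall back to `iter(generator)`: any `Sequence` becomes an endless stream of batches that restarts from index 0 after each pass, while `init_pool` and the new `executor_fn(seqs)` signature hand the shared sequences to multiprocessing workers when the pool is created. A small sketch of the iteration behavior (the `Range` class is illustrative, not from the patch):

```
import itertools

from tensorflow.python.keras._impl.keras.utils.data_utils import Sequence

class Range(Sequence):
  """A trivial Sequence whose "batches" are just the indices 0, 1, 2."""

  def __len__(self):
    return 3

  def __getitem__(self, idx):
    return idx

# __iter__ cycles through the Sequence forever; take the first seven items.
print(list(itertools.islice(iter(Range()), 7)))  # [0, 1, 2, 0, 1, 2, 0]
```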