diff options
author | 2018-03-09 14:40:18 -0800 | |
---|---|---|
committer | 2018-03-09 14:44:24 -0800 | |
commit | 8044288df687b07004624275295b93dca07b267b (patch) | |
tree | 132c842e882a4e01fefb5f26e4ecadb072fd1d5b | |
parent | 4ffc1043866d688023ed2942bb8b02e803c42891 (diff) |
Part of the update of tf.keras to the Keras 2.1.5 API.
PiperOrigin-RevId: 188540513
11 files changed, 479 insertions, 143 deletions
diff --git a/tensorflow/python/keras/_impl/keras/__init__.py b/tensorflow/python/keras/_impl/keras/__init__.py index b63907b2e6..53f5d31e9c 100644 --- a/tensorflow/python/keras/_impl/keras/__init__.py +++ b/tensorflow/python/keras/_impl/keras/__init__.py @@ -40,4 +40,4 @@ from tensorflow.python.keras._impl.keras.layers import Input from tensorflow.python.keras._impl.keras.models import Model from tensorflow.python.keras._impl.keras.models import Sequential -__version__ = '2.1.4-tf' +__version__ = '2.1.5-tf' diff --git a/tensorflow/python/keras/_impl/keras/backend.py b/tensorflow/python/keras/_impl/keras/backend.py index 688dc070e6..04866fbe0f 100644 --- a/tensorflow/python/keras/_impl/keras/backend.py +++ b/tensorflow/python/keras/_impl/keras/backend.py @@ -423,8 +423,9 @@ def get_session(): A TensorFlow session. """ global _SESSION - if ops.get_default_session() is not None: - session = ops.get_default_session() + default_session = ops.get_default_session() + if default_session is not None: + session = default_session else: if _SESSION is None: if not os.environ.get('OMP_NUM_THREADS'): @@ -495,7 +496,7 @@ def _is_current_explicit_device(device_type): """ device_type = device_type.upper() if device_type not in ['CPU', 'GPU']: - raise ValueError('device_type should be either "CPU" or "GPU".') + raise ValueError('`device_type` should be either "CPU" or "GPU".') device = _get_current_tf_device() return device is not None and device.device_type == device_type.upper() @@ -3514,7 +3515,7 @@ def l2_normalize(x, axis=None): Returns: A tensor. """ - return nn.l2_normalize(x, dim=axis) + return nn.l2_normalize(x, axis=axis) @tf_export('keras.backend.in_top_k') diff --git a/tensorflow/python/keras/_impl/keras/engine/network.py b/tensorflow/python/keras/_impl/keras/engine/network.py index bde16cdeb0..bf82390438 100644 --- a/tensorflow/python/keras/_impl/keras/engine/network.py +++ b/tensorflow/python/keras/_impl/keras/engine/network.py @@ -406,6 +406,7 @@ class Network(base_layer.Layer): def get_layer(self, name=None, index=None): """Retrieves a layer based on either its name (unique) or index. + If `name` and `index` are both provided, `index` will take precedence. Indices are based on order of horizontal graph traversal (bottom-up). Arguments: @@ -437,7 +438,7 @@ class Network(base_layer.Layer): @property def updates(self): - """Retrieve the network's updates. + """Retrieves the network's updates. Will only include updates that are either unconditional, or conditional on inputs to this model @@ -517,7 +518,7 @@ class Network(base_layer.Layer): @property def losses(self): - """Retrieve the network's losses. + """Retrieves the network's losses. Will only include losses that are either unconditional, or conditional on inputs to this model @@ -600,7 +601,7 @@ class Network(base_layer.Layer): return specs def call(self, inputs, training=None, mask=None): - """Call the model on new inputs. + """Calls the model on new inputs. In this case `call` just reapplies all ops in the graph to the new inputs @@ -1030,7 +1031,7 @@ class Network(base_layer.Layer): layer(input_tensors, **kwargs) def process_layer(layer_data): - """Deserialize a layer, then call it on appropriate inputs. + """Deserializes a layer, then call it on appropriate inputs. Arguments: layer_data: layer config dict. @@ -1087,7 +1088,7 @@ class Network(base_layer.Layer): return cls(inputs=input_tensors, outputs=output_tensors, name=name) def save(self, filepath, overwrite=True, include_optimizer=True): - """Save the model to a single HDF5 file. + """Saves the model to a single HDF5 file. The savefile includes: - The model architecture, allowing to re-instantiate the model. @@ -1193,7 +1194,7 @@ class Network(base_layer.Layer): saving.load_weights_from_hdf5_group(f, self.layers) def _updated_config(self): - """Util hared between different serialization methods. + """Util shared between different serialization methods. Returns: Model config with Keras version information added. @@ -1333,7 +1334,7 @@ def _make_node_key(layer_name, node_index): def _map_graph_network(inputs, outputs): - """Validate a network's topology and gather its layers and nodes. + """Validates a network's topology and gather its layers and nodes. Arguments: inputs: List of input tensors. diff --git a/tensorflow/python/keras/_impl/keras/engine/saving.py b/tensorflow/python/keras/_impl/keras/engine/saving.py index 52522e6935..2ad06ca4fd 100644 --- a/tensorflow/python/keras/_impl/keras/engine/saving.py +++ b/tensorflow/python/keras/_impl/keras/engine/saving.py @@ -35,6 +35,7 @@ from tensorflow.python.util.tf_export import tf_export # pylint: disable=g-import-not-at-top try: import h5py + HDF5_OBJECT_HEADER_LIMIT = 64512 except ImportError: h5py = None @@ -47,7 +48,7 @@ except ImportError: @tf_export('keras.models.save_model') def save_model(model, filepath, overwrite=True, include_optimizer=True): - """Save a model to a HDF5 file. + """Saves a model to a HDF5 file. The saved model contains: - the model's configuration (topology) @@ -74,7 +75,7 @@ def save_model(model, filepath, overwrite=True, include_optimizer=True): raise ImportError('`save_model` requires h5py.') def get_json_type(obj): - """Serialize any object to a JSON-serializable structure. + """Serializes any object to a JSON-serializable structure. Arguments: obj: the object to serialize @@ -358,34 +359,6 @@ def model_from_json(json_string, custom_objects=None): return deserialize(config, custom_objects=custom_objects) -def save_weights_to_hdf5_group(f, layers): - from tensorflow.python.keras._impl.keras import __version__ as keras_version # pylint: disable=g-import-not-at-top - - f.attrs['layer_names'] = [layer.name.encode('utf8') for layer in layers] - f.attrs['backend'] = K.backend().encode('utf8') - f.attrs['keras_version'] = str(keras_version).encode('utf8') - - for layer in layers: - g = f.create_group(layer.name) - symbolic_weights = layer.weights - weight_values = K.batch_get_value(symbolic_weights) - weight_names = [] - for i, (w, val) in enumerate(zip(symbolic_weights, weight_values)): - if hasattr(w, 'name') and w.name: - name = str(w.name) - else: - name = 'param_' + str(i) - weight_names.append(name.encode('utf8')) - g.attrs['weight_names'] = weight_names - for name, val in zip(weight_names, weight_values): - param_dset = g.create_dataset(name, val.shape, dtype=val.dtype) - if not val.shape: - # scalar - param_dset[()] = val - else: - param_dset[:] = val - - def preprocess_weights_for_loading(layer, weights, original_keras_version=None, @@ -549,9 +522,140 @@ def preprocess_weights_for_loading(layer, # split the bias into half and merge weights[2] = bias[:units * 4] + bias[units * 4:] + return convert_rnn_weights(layer, weights) + + +def convert_rnn_weights(layer, weights): + """Converts weights for RNN layers between native and CuDNN format. + + Input kernels for each gate are transposed and converted between Fortran + and C layout, recurrent kernels are transposed. For LSTM biases are summed/ + split in half, for GRU biases are reshaped. + + Weights can be converted in both directions between `LSTM` and`CuDNNSLTM` + and between `CuDNNGRU` and `GRU(reset_after=True)`. Default `GRU` is not + compatible with `CuDNNGRU`. + + For missing biases in `LSTM`/`GRU` (`use_bias=False`) no conversion is made. + + Arguments: + layer: Target layer instance. + weights: List of source weights values (input kernels, recurrent + kernels, [biases]) (Numpy arrays). + + Returns: + A list of converted weights values (Numpy arrays). + + Raises: + ValueError: for incompatible GRU layer/weights or incompatible biases + """ + + def transform_kernels(kernels, func, n_gates): + """Transforms kernel for each gate separately using given function. + + Arguments: + kernels: Stacked array of kernels for individual gates. + func: Function applied to kernel of each gate. + n_gates: Number of gates (4 for LSTM, 3 for GRU). + Returns: + Stacked array of transformed kernels. + """ + return np.hstack([func(k) for k in np.hsplit(kernels, n_gates)]) + + def transpose_input(from_cudnn): + """Makes a function that transforms input kernels from/to CuDNN format. + + It keeps the shape, but changes between the layout (Fortran/C). Eg.: + + ``` + Keras CuDNN + [[0, 1, 2], <---> [[0, 2, 4], + [3, 4, 5]] [1, 3, 5]] + ``` + + It can be passed to `transform_kernels()`. + + Arguments: + from_cudnn: `True` if source weights are in CuDNN format, `False` + if they're in plain Keras format. + Returns: + Function that converts input kernel to the other format. + """ + order = 'F' if from_cudnn else 'C' + + def transform(kernel): + return kernel.T.reshape(kernel.shape, order=order) + + return transform + + target_class = layer.__class__.__name__ + + # convert the weights between CuDNNLSTM and LSTM + if target_class in ['LSTM', 'CuDNNLSTM'] and len(weights) == 3: + # determine if we're loading a CuDNNLSTM layer + # from the number of bias weights: + # CuDNNLSTM has (units * 8) weights; while LSTM has (units * 4) + # if there's no bias weight in the file, skip this conversion + units = weights[1].shape[0] + bias_shape = weights[2].shape + n_gates = 4 + + if bias_shape == (2 * units * n_gates,): + source = 'CuDNNLSTM' + elif bias_shape == (units * n_gates,): + source = 'LSTM' + else: + raise ValueError('Invalid bias shape: ' + str(bias_shape)) + + def convert_lstm_weights(weights, from_cudnn=True): + # Transpose (and reshape) input and recurrent kernels. + kernels = transform_kernels(weights[0], transpose_input(from_cudnn), + n_gates) + recurrent_kernels = transform_kernels(weights[1], lambda k: k.T, n_gates) + if from_cudnn: # Merge input and recurrent biases into a single set. + biases = np.sum(np.split(weights[2], 2, axis=0), axis=0) + else: + # Split single set of biases evenly to two sets. + biases = np.tile(0.5 * weights[2], 2) + return [kernels, recurrent_kernels, biases] + + if source != target_class: + weights = convert_lstm_weights(weights, from_cudnn=source == 'CuDNNLSTM') + + # TODO(fchollet): add feature after GRU is refactored: + # convert the weights between `CuDNNGRU` and `GRU(reset_after=True)` return weights +def save_weights_to_hdf5_group(f, layers): + from tensorflow.python.keras._impl.keras import __version__ as keras_version # pylint: disable=g-import-not-at-top + + save_attributes_to_hdf5_group( + f, 'layer_names', [layer.name.encode('utf8') for layer in layers]) + f.attrs['backend'] = K.backend().encode('utf8') + f.attrs['keras_version'] = str(keras_version).encode('utf8') + + for layer in layers: + g = f.create_group(layer.name) + symbolic_weights = layer.weights + weight_values = K.batch_get_value(symbolic_weights) + weight_names = [] + for i, (w, val) in enumerate(zip(symbolic_weights, weight_values)): + if hasattr(w, 'name') and w.name: + name = str(w.name) + else: + name = 'param_' + str(i) + weight_names.append(name.encode('utf8')) + save_attributes_to_hdf5_group(g, 'weight_names', weight_names) + for name, val in zip(weight_names, weight_values): + param_dset = g.create_dataset(name, val.shape, dtype=val.dtype) + if not val.shape: + # scalar + param_dset[()] = val + else: + param_dset[:] = val + + def load_weights_from_hdf5_group(f, layers): """Implements topological (order-based) weight loading. @@ -578,11 +682,11 @@ def load_weights_from_hdf5_group(f, layers): if weights: filtered_layers.append(layer) - layer_names = [n.decode('utf8') for n in f.attrs['layer_names']] + layer_names = load_attributes_from_hdf5_group(f, 'layer_names') filtered_layer_names = [] for name in layer_names: g = f[name] - weight_names = [n.decode('utf8') for n in g.attrs['weight_names']] + weight_names = load_attributes_from_hdf5_group(g, 'weight_names') if weight_names: filtered_layer_names.append(name) layer_names = filtered_layer_names @@ -597,7 +701,7 @@ def load_weights_from_hdf5_group(f, layers): weight_value_tuples = [] for k, name in enumerate(layer_names): g = f[name] - weight_names = [n.decode('utf8') for n in g.attrs['weight_names']] + weight_names = load_attributes_from_hdf5_group(g, 'weight_names') weight_values = [g[weight_name] for weight_name in weight_names] layer = filtered_layers[k] symbolic_weights = layer.weights @@ -640,7 +744,7 @@ def load_weights_from_hdf5_group_by_name(f, layers): original_backend = None # New file format. - layer_names = [n.decode('utf8') for n in f.attrs['layer_names']] + layer_names = load_attributes_from_hdf5_group(f, 'layer_names') # Reverse index of layer name to list of layers with name. index = {} @@ -653,7 +757,7 @@ def load_weights_from_hdf5_group_by_name(f, layers): weight_value_tuples = [] for k, name in enumerate(layer_names): g = f[name] - weight_names = [n.decode('utf8') for n in g.attrs['weight_names']] + weight_names = load_attributes_from_hdf5_group(g, 'weight_names') weight_values = [g[weight_name] for weight_name in weight_names] for layer in index.get(name, []): @@ -669,3 +773,72 @@ def load_weights_from_hdf5_group_by_name(f, layers): for i in range(len(weight_values)): weight_value_tuples.append((symbolic_weights[i], weight_values[i])) K.batch_set_value(weight_value_tuples) + + +def save_attributes_to_hdf5_group(group, name, data): + """Saves attributes (data) of the specified name into the HDF5 group. + + This method deals with an inherent problem of HDF5 file which is not + able to store data larger than HDF5_OBJECT_HEADER_LIMIT bytes. + + Arguments: + group: A pointer to a HDF5 group. + name: A name of the attributes to save. + data: Attributes data to store. + + Raises: + RuntimeError: If any single attribute is too large to be saved. + """ + # Check that no item in `data` is larger than `HDF5_OBJECT_HEADER_LIMIT` + # because in that case even chunking the array would not make the saving + # possible. + bad_attributes = [x for x in data if len(x) > HDF5_OBJECT_HEADER_LIMIT] + + # Expecting this to never be true. + if bad_attributes: + raise RuntimeError('The following attributes cannot be saved to HDF5 ' + 'file because they are larger than %d bytes: %s' % + (HDF5_OBJECT_HEADER_LIMIT, + ', '.join([x for x in bad_attributes]))) + + data_npy = np.asarray(data) + + num_chunks = 1 + chunked_data = np.array_split(data_npy, num_chunks) + + # This will never loop forever thanks to the test above. + while any([x.nbytes > HDF5_OBJECT_HEADER_LIMIT for x in chunked_data]): + num_chunks += 1 + chunked_data = np.array_split(data_npy, num_chunks) + + if num_chunks > 1: + for chunk_id, chunk_data in enumerate(chunked_data): + group.attrs['%s%d' % (name, chunk_id)] = chunk_data + else: + group.attrs[name] = data + + +def load_attributes_from_hdf5_group(group, name): + """Loads attributes of the specified name from the HDF5 group. + + This method deals with an inherent problem + of HDF5 file which is not able to store + data larger than HDF5_OBJECT_HEADER_LIMIT bytes. + + Arguments: + group: A pointer to a HDF5 group. + name: A name of the attributes to load. + + Returns: + data: Attributes data. + """ + if name in group.attrs: + data = [n.decode('utf8') for n in group.attrs[name]] + else: + data = [] + chunk_id = 0 + while '%s%d' % (name, chunk_id) in group.attrs: + data.extend( + [n.decode('utf8') for n in group.attrs['%s%d' % (name, chunk_id)]]) + chunk_id += 1 + return data diff --git a/tensorflow/python/keras/_impl/keras/engine/saving_test.py b/tensorflow/python/keras/_impl/keras/engine/saving_test.py index bdb17641b0..4a18cc2e11 100644 --- a/tensorflow/python/keras/_impl/keras/engine/saving_test.py +++ b/tensorflow/python/keras/_impl/keras/engine/saving_test.py @@ -370,6 +370,92 @@ class TestWholeModelSaving(test.TestCase): self.assertAllClose(mean, model.layers[1].arguments['mu']) self.assertAllClose(std, model.layers[1].arguments['std']) + def test_saving_model_with_long_layer_names(self): + if h5py is None: + return # Skip test if models cannot be saved. + + with self.test_session(): + # This layer name will make the `layers_name` HDF5 attribute blow + # out of proportion. Note that it fits into the internal HDF5 + # attribute memory limit on its own but because h5py converts + # the list of layer names into numpy array, which uses the same + # amout of memory for every item, it increases the memory + # requirements substantially. + x = keras.Input(shape=(2,), name='input_' + ('x' * (2**15))) + f = x + for i in range(4): + f = keras.layers.Dense(2, name='dense_%d' % (i,))(f) + model = keras.Model(inputs=[x], outputs=[f]) + model.compile(loss='mse', optimizer='adam', metrics=['acc']) + + x = np.random.random((1, 2)) + y = np.random.random((1, 2)) + model.train_on_batch(x, y) + out = model.predict(x) + + fd, fname = tempfile.mkstemp('.h5') + keras.models.save_model(model, fname) + model = keras.models.load_model(fname) + + # Check that the HDF5 files contains chunked array + # of layer names. + with h5py.File(fname, 'r') as h5file: + num_names_arrays = len([attr for attr in h5file['model_weights'].attrs + if attr.startswith('layer_names')]) + # The chunking of layer names array should have happend. + self.assertGreater(num_names_arrays, 0) + out2 = model.predict(x) + self.assertAllClose(out, out2, atol=1e-05) + + # Cleanup + os.close(fd) + os.remove(fname) + + def test_saving_model_with_long_weights_names(self): + if h5py is None: + return # Skip test if models cannot be saved. + + with self.test_session(): + x = keras.Input(shape=(2,), name='nested_model_input') + f = x + for i in range(4): + f = keras.layers.Dense(2, name='nested_model_dense_%d' % (i,))(f) + # This layer name will make the `weights_name` + # HDF5 attribute blow out of proportion. + f = keras.layers.Dense(2, name='nested_model_output' + ('x' * (2**15)))(f) + nested_model = keras.Model(inputs=[x], outputs=[f], name='nested_model') + + x = keras.Input(shape=(2,), name='outer_model_input') + f = nested_model(x) + f = keras.layers.Dense(2, name='outer_model_output')(f) + + model = keras.Model(inputs=[x], outputs=[f]) + model.compile(loss='mse', optimizer='adam', metrics=['acc']) + + x = np.random.random((1, 2)) + y = np.random.random((1, 2)) + model.train_on_batch(x, y) + out = model.predict(x) + + fd, fname = tempfile.mkstemp('.h5') + keras.models.save_model(model, fname) + model = keras.models.load_model(fname) + + # Check that the HDF5 files contains chunked array + # of weight names. + with h5py.File(fname, 'r') as h5file: + num_weight_arrays = len( + [attr for attr in h5file['model_weights']['nested_model'].attrs + if attr.startswith('weight_names')]) + # The chunking of layer names array should have happend. + self.assertGreater(num_weight_arrays, 0) + out2 = model.predict(x) + self.assertAllClose(out, out2, atol=1e-05) + + # Cleanup + os.close(fd) + os.remove(fname) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/keras/_impl/keras/engine/training.py b/tensorflow/python/keras/_impl/keras/engine/training.py index 8b82c0b313..57506f9aff 100644 --- a/tensorflow/python/keras/_impl/keras/engine/training.py +++ b/tensorflow/python/keras/_impl/keras/engine/training.py @@ -1542,20 +1542,19 @@ class Model(Network): max_queue_size: Integer. Maximum size for the generator queue. If unspecified, `max_queue_size` will default to 10. workers: Integer. Maximum number of processes to spin up - when using process based threading. + when using process-based threading. If unspecified, `workers` will default to 1. If 0, will execute the generator on the main thread. - use_multiprocessing: Boolean. If True, use process based threading. - If unspecified, `workers` will default to False. - Note that because - this implementation relies on multiprocessing, - you should not pass - non picklable arguments to the generator - as they can't be passed - easily to children processes. - shuffle: Whether to shuffle the order of the batches at + use_multiprocessing: Boolean. + If `True`, use process-based threading. + If unspecified, `use_multiprocessing` will default to `False`. + Note that because this implementation relies on multiprocessing, + you should not pass non-picklable arguments to the generator + as they can't be passed easily to children processes. + shuffle: Boolean. Whether to shuffle the order of the batches at the beginning of each epoch. Only used with instances - of `Sequence` (keras.utils.Sequence). + of `Sequence` (`keras.utils.Sequence`). + Has no effect when `steps_per_epoch` is not `None`. initial_epoch: Epoch at which to start training (useful for resuming a previous training run) @@ -1625,16 +1624,15 @@ class Model(Network): the `len(generator)` as a number of steps. max_queue_size: maximum size for the generator queue workers: Integer. Maximum number of processes to spin up - when using process based threading. + when using process-based threading. If unspecified, `workers` will default to 1. If 0, will execute the generator on the main thread. - use_multiprocessing: if True, use process based threading. - Note that because - this implementation relies on multiprocessing, - you should not pass - non picklable arguments to the generator - as they can't be passed - easily to children processes. + use_multiprocessing: Boolean. + If `True`, use process-based threading. + If unspecified, `use_multiprocessing` will default to `False`. + Note that because this implementation relies on multiprocessing, + you should not pass non-picklable arguments to the generator + as they can't be passed easily to children processes. Returns: Scalar test loss (if the model has a single output and no metrics) @@ -1684,16 +1682,15 @@ class Model(Network): the `len(generator)` as a number of steps. max_queue_size: Maximum size for the generator queue. workers: Integer. Maximum number of processes to spin up - when using process based threading. + when using process-based threading. If unspecified, `workers` will default to 1. If 0, will execute the generator on the main thread. - use_multiprocessing: If `True`, use process based threading. - Note that because - this implementation relies on multiprocessing, - you should not pass - non picklable arguments to the generator - as they can't be passed - easily to children processes. + use_multiprocessing: Boolean. + If `True`, use process-based threading. + If unspecified, `use_multiprocessing` will default to `False`. + Note that because this implementation relies on multiprocessing, + you should not pass non-picklable arguments to the generator + as they can't be passed easily to children processes. verbose: verbosity mode, 0 or 1. Returns: diff --git a/tensorflow/python/keras/_impl/keras/engine/training_arrays.py b/tensorflow/python/keras/_impl/keras/engine/training_arrays.py index 9291ef5fe6..18116e3a14 100644 --- a/tensorflow/python/keras/_impl/keras/engine/training_arrays.py +++ b/tensorflow/python/keras/_impl/keras/engine/training_arrays.py @@ -298,20 +298,13 @@ def predict_loop(model, inputs, batch_size=32, verbose=0, steps=None): else: ins = inputs - if hasattr(model, 'metrics'): - for m in model.metrics: - if isinstance(m, Layer): - m.reset_states() - num_samples = training_utils.check_num_samples( inputs, batch_size, steps, 'steps') if verbose == 1: if steps is not None: - progbar = Progbar(target=steps, - stateful_metrics=model.stateful_metric_names) + progbar = Progbar(target=steps) else: - progbar = Progbar(target=num_samples, - stateful_metrics=model.stateful_metric_names) + progbar = Progbar(target=num_samples) indices_for_conversion_to_dense = [] for i in range(len(model._feed_inputs)): diff --git a/tensorflow/python/keras/_impl/keras/engine/training_generator.py b/tensorflow/python/keras/_impl/keras/engine/training_generator.py index 4af62c85d5..58b5bc39c1 100644 --- a/tensorflow/python/keras/_impl/keras/engine/training_generator.py +++ b/tensorflow/python/keras/_impl/keras/engine/training_generator.py @@ -112,42 +112,25 @@ def fit_generator(model, val_enqueuer = None try: - if do_validation: - if val_gen: - if workers > 0: - if isinstance(validation_data, Sequence): - val_enqueuer = OrderedEnqueuer( - validation_data, use_multiprocessing=use_multiprocessing) - if validation_steps is None: - validation_steps = len(validation_data) - else: - val_enqueuer = GeneratorEnqueuer( - validation_data, - use_multiprocessing=use_multiprocessing, - wait_time=wait_time) - val_enqueuer.start(workers=workers, max_queue_size=max_queue_size) - validation_generator = val_enqueuer.get() - else: - validation_generator = validation_data + if do_validation and not val_gen: + # Prepare data for validation + if len(validation_data) == 2: + val_x, val_y = validation_data # pylint: disable=unpacking-non-sequence + val_sample_weight = None + elif len(validation_data) == 3: + val_x, val_y, val_sample_weight = validation_data # pylint: disable=unpacking-non-sequence else: - if len(validation_data) == 2: - val_x, val_y = validation_data # pylint: disable=unpacking-non-sequence - val_sample_weight = None - elif len(validation_data) == 3: - val_x, val_y, val_sample_weight = validation_data # pylint: disable=unpacking-non-sequence - else: - raise ValueError( - '`validation_data` should be a tuple ' - '`(val_x, val_y, val_sample_weight)` ' - 'or `(val_x, val_y)`. Found: ' + str(validation_data)) - val_x, val_y, val_sample_weights = model._standardize_user_data( - val_x, val_y, val_sample_weight) - val_data = val_x + val_y + val_sample_weights - if model.uses_learning_phase and not isinstance( - K.learning_phase(), int): - val_data += [0] - for cbk in callbacks: - cbk.validation_data = val_data + raise ValueError( + '`validation_data` should be a tuple ' + '`(val_x, val_y, val_sample_weight)` ' + 'or `(val_x, val_y)`. Found: ' + str(validation_data)) + val_x, val_y, val_sample_weights = model._standardize_user_data( + val_x, val_y, val_sample_weight) + val_data = val_x + val_y + val_sample_weights + if model.uses_learning_phase and not isinstance(K.learning_phase(), int): + val_data += [0.] + for cbk in callbacks: + cbk.validation_data = val_data if workers > 0: if is_sequence: @@ -163,7 +146,10 @@ def fit_generator(model, enqueuer.start(workers=workers, max_queue_size=max_queue_size) output_generator = enqueuer.get() else: - output_generator = generator + if is_sequence: + output_generator = iter(generator) + else: + output_generator = generator callback_model.stop_training = False # Construct epoch logs. @@ -218,7 +204,12 @@ def fit_generator(model, if steps_done >= steps_per_epoch and do_validation: if val_gen: val_outs = evaluate_generator( - model, validation_generator, validation_steps, workers=0) + model, + validation_data, + validation_steps, + workers=workers, + use_multiprocessing=use_multiprocessing, + max_queue_size=max_queue_size) else: # No need for try/except because # data has already been validated. @@ -297,7 +288,10 @@ def evaluate_generator(model, enqueuer.start(workers=workers, max_queue_size=max_queue_size) output_generator = enqueuer.get() else: - output_generator = generator + if is_sequence: + output_generator = iter(generator) + else: + output_generator = generator while steps_done < steps: generator_output = next(output_generator) @@ -387,7 +381,10 @@ def predict_generator(model, enqueuer.start(workers=workers, max_queue_size=max_queue_size) output_generator = enqueuer.get() else: - output_generator = generator + if is_sequence: + output_generator = iter(generator) + else: + output_generator = generator if verbose == 1: progbar = Progbar(target=steps) diff --git a/tensorflow/python/keras/_impl/keras/engine/training_test.py b/tensorflow/python/keras/_impl/keras/engine/training_test.py index 38ba0f0eae..fd91dbba52 100644 --- a/tensorflow/python/keras/_impl/keras/engine/training_test.py +++ b/tensorflow/python/keras/_impl/keras/engine/training_test.py @@ -340,20 +340,21 @@ class TrainingTest(test.TestCase): if scipy_sparse is None: return - test_inputs = [ - scipy_sparse.random(6, 3, density=0.25).tocsr() for _ in range(2)] - test_outputs = [ - scipy_sparse.random(6, i, density=0.25).tocsr() for i in range(3, 5)] - in1 = keras.layers.Input(shape=(3,)) - in2 = keras.layers.Input(shape=(3,)) - out1 = keras.layers.Dropout(0.5, name='dropout')(in1) - out2 = keras.layers.Dense(4, name='dense_1')(in2) - model = keras.Model([in1, in2], [out1, out2]) - model.predict(test_inputs, batch_size=2) - model.compile('rmsprop', 'mse') - model.fit(test_inputs, test_outputs, - epochs=1, batch_size=2, validation_split=0.5) - model.evaluate(test_inputs, test_outputs, batch_size=2) + with self.test_session(): + test_inputs = [ + scipy_sparse.random(6, 3, density=0.25).tocsr() for _ in range(2)] + test_outputs = [ + scipy_sparse.random(6, i, density=0.25).tocsr() for i in range(3, 5)] + in1 = keras.layers.Input(shape=(3,)) + in2 = keras.layers.Input(shape=(3,)) + out1 = keras.layers.Dropout(0.5, name='dropout')(in1) + out2 = keras.layers.Dense(4, name='dense_1')(in2) + model = keras.Model([in1, in2], [out1, out2]) + model.predict(test_inputs, batch_size=2) + model.compile('rmsprop', 'mse') + model.fit(test_inputs, test_outputs, + epochs=1, batch_size=2, validation_split=0.5) + model.evaluate(test_inputs, test_outputs, batch_size=2) def test_that_trainable_disables_updates(self): val_a = np.random.random((10, 4)) @@ -876,9 +877,9 @@ class TestGeneratorMethods(test.TestCase): def custom_generator(): batch_size = 10 - n_samples = 50 + num_samples = 50 while True: - batch_index = np.random.randint(0, n_samples - batch_size) + batch_index = np.random.randint(0, num_samples - batch_size) start = batch_index end = start + batch_size x = arr_data[start: end] @@ -957,9 +958,9 @@ class TestGeneratorMethods(test.TestCase): def custom_generator(): batch_size = 10 - n_samples = 50 + num_samples = 50 while True: - batch_index = np.random.randint(0, n_samples - batch_size) + batch_index = np.random.randint(0, num_samples - batch_size) start = batch_index end = start + batch_size x = arr_data[start: end] @@ -1033,6 +1034,52 @@ class TestGeneratorMethods(test.TestCase): max_queue_size=10, use_multiprocessing=False) + def test_training_with_sequences(self): + + class DummySequence(keras.utils.Sequence): + + def __getitem__(self, idx): + return np.zeros([10, 2]), np.ones([10]) + + def __len__(self): + return 10 + + arr_data = np.random.random((50, 2)) + arr_labels = np.random.random((50,)) + arr_sample_weights = np.random.random((50,)) + + def custom_generator(): + batch_size = 10 + num_samples = 50 + while True: + batch_index = np.random.randint(0, num_samples - batch_size) + start = batch_index + end = start + batch_size + x = arr_data[start: end] + y = arr_labels[start: end] + w = arr_sample_weights[start: end] + yield x, y, w + + with self.test_session(): + model = keras.models.Sequential() + model.add(keras.layers.Dense(1, input_shape=(2,))) + model.compile(loss='mse', optimizer='sgd') + + model.fit_generator(DummySequence(), + steps_per_epoch=10, + validation_data=custom_generator(), + validation_steps=1, + max_queue_size=10, + workers=0, + use_multiprocessing=True) + model.fit_generator(DummySequence(), + steps_per_epoch=10, + validation_data=custom_generator(), + validation_steps=1, + max_queue_size=10, + workers=0, + use_multiprocessing=False) + class TestTrainingUtils(test.TestCase): diff --git a/tensorflow/python/keras/_impl/keras/optimizers.py b/tensorflow/python/keras/_impl/keras/optimizers.py index 6520128c5b..b715d722b9 100644 --- a/tensorflow/python/keras/_impl/keras/optimizers.py +++ b/tensorflow/python/keras/_impl/keras/optimizers.py @@ -95,7 +95,26 @@ class Optimizer(object): raise NotImplementedError def get_gradients(self, loss, params): + """Returns gradients of `loss` with respect to `params`. + + Arguments: + loss: Loss tensor. + params: List of variables. + + Returns: + List of gradient tensors. + + Raises: + ValueError: In case any gradient cannot be computed (e.g. if gradient + function not implemented). + """ grads = K.gradients(loss, params) + if None in grads: + raise ValueError('An operation has `None` for gradient. ' + 'Please make sure that all of your ops have a ' + 'gradient defined (i.e. are differentiable). ' + 'Common ops without gradient: ' + 'K.argmax, K.round, K.eval.') if hasattr(self, 'clipnorm') and self.clipnorm > 0: norm = K.sqrt(sum([K.sum(K.square(g)) for g in grads])) grads = [clip_norm(g, self.clipnorm, norm) for g in grads] @@ -120,6 +139,11 @@ class Optimizer(object): ValueError: in case of incompatible weight shapes. """ params = self.weights + if len(params) != len(weights): + raise ValueError( + 'Length of the specified weight list (' + str(len(weights)) + + ') does not match the number of weights ' + 'of the optimizer (' + str(len(params)) + ')') weight_value_tuples = [] param_values = K.batch_get_value(params) for pv, p, w in zip(param_values, params, weights): diff --git a/tensorflow/python/keras/_impl/keras/utils/data_utils.py b/tensorflow/python/keras/_impl/keras/utils/data_utils.py index e87c8f48ef..4c49544c6a 100644 --- a/tensorflow/python/keras/_impl/keras/utils/data_utils.py +++ b/tensorflow/python/keras/_impl/keras/utils/data_utils.py @@ -393,6 +393,16 @@ class Sequence(object): """ pass + def __iter__(self): + """Creates an infinite generator that iterate over the Sequence. + + Yields: + Sequence items. + """ + while True: + for item in (self[i] for i in range(len(self))): + yield item + # Global variables to be shared across processes _SHARED_SEQUENCES = {} @@ -400,6 +410,11 @@ _SHARED_SEQUENCES = {} _SEQUENCE_COUNTER = None +def init_pool(seqs): + global _SHARED_SEQUENCES + _SHARED_SEQUENCES = seqs + + def get_index(uid, i): """Get the value from the Sequence `uid` at index `i`. @@ -532,9 +547,11 @@ class OrderedEnqueuer(SequenceEnqueuer): (when full, workers could block on `put()`) """ if self.use_multiprocessing: - self.executor_fn = lambda: multiprocessing.Pool(workers) + self.executor_fn = lambda seqs: multiprocessing.Pool( # pylint: disable=g-long-lambda + workers, initializer=init_pool, initargs=(seqs,)) else: - self.executor_fn = lambda: ThreadPool(workers) + # We do not need the init since it's threads. + self.executor_fn = lambda _: ThreadPool(workers) self.workers = workers self.queue = queue.Queue(max_queue_size) self.stop_signal = threading.Event() @@ -557,7 +574,7 @@ class OrderedEnqueuer(SequenceEnqueuer): if self.shuffle: random.shuffle(sequence) - with closing(self.executor_fn()) as executor: + with closing(self.executor_fn(_SHARED_SEQUENCES)) as executor: for i in sequence: if self.stop_signal.is_set(): return |