diff options
Diffstat (limited to 'tensorflow/python/keras/engine/training.py')
-rw-r--r-- | tensorflow/python/keras/engine/training.py | 282 |
1 files changed, 230 insertions, 52 deletions
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py index d224dfffdd..fed07c4120 100644 --- a/tensorflow/python/keras/engine/training.py +++ b/tensorflow/python/keras/engine/training.py @@ -20,9 +20,11 @@ from __future__ import print_function import weakref import numpy as np +import six from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import iterator_ops +from tensorflow.python.data.ops.dataset_ops import Dataset from tensorflow.python.eager import context from tensorflow.python.framework import errors from tensorflow.python.framework import ops @@ -39,6 +41,7 @@ from tensorflow.python.keras.engine import training_eager from tensorflow.python.keras.engine import training_generator from tensorflow.python.keras.engine import training_utils from tensorflow.python.keras.engine.network import Network +from tensorflow.python.keras.utils import data_utils from tensorflow.python.keras.utils.generic_utils import slice_arrays from tensorflow.python.ops import math_ops from tensorflow.python.ops import weights_broadcast_ops @@ -206,8 +209,27 @@ class Model(Network): for metric in metrics: metric_fn = training_utils.get_metric_function( metric, output_shape=output_shape, loss_fn=loss_fn) - metric_name = self._get_metric_name( - metric, output_index, weighted=weights is not None) + + if (context.executing_eagerly() and y_true is not None and + y_pred is not None): + # In eager mode, when executing metric_fn during training, we do not + # need to generate unique metric name and add it to the model + # as we have done that during compile already. + prefix = 'weighted_' if weights is not None else '' + suffix = metric_fn.name if hasattr(metric_fn, + 'name') else metric_fn.__name__ + metric_name = prefix + suffix + else: + # Get metric name that is to be added to the model. + metric_name = self._get_metric_name( + metric, output_index, weighted=weights is not None) + # Keep track of metric name. + self.metrics_names.append(metric_name) + + # Keep track of stateful metric attributes (name and metric function). + if isinstance(metric_fn, base_layer.Layer) and metric_fn.stateful: + self.stateful_metric_names.append(metric_name) + self.stateful_metric_functions.append(metric_fn) with K.name_scope(metric_name): # If both outputs and targets are available, call the metric function. @@ -247,16 +269,10 @@ class Model(Network): self.metrics_tensors.append(metric_result) metric_results.append(metric_result) - # Keep track of metric name. - self.metrics_names.append(metric_name) - - # Keep track of stateful metric attributes (name and metric function). - if isinstance(metric_fn, base_layer.Layer) and metric_fn.stateful: - self.stateful_metric_names.append(metric_name) - self.stateful_metric_functions.append(metric_fn) - if not context.executing_eagerly(): - # Keep track of updates created by stateful metrics. - self.metrics_updates += metric_fn.updates + if (isinstance(metric_fn, base_layer.Layer) and metric_fn.stateful and + not context.executing_eagerly()): + # Keep track of updates created by stateful metrics. + self.metrics_updates += metric_fn.updates return metric_results def _handle_metrics(self, @@ -754,9 +770,8 @@ class Model(Network): the model. Args: - x: Input data. A `tf.data` dataset. - y: Since `x` is a dataset, `y` should not be specified - (since targets will be obtained from the iterator). + x: Input data. A numpy array or `tf.data` dataset. + y: Target data. A numpy array or None if x is a `tf.data` dataset. sample_weight: An optional sample-weight array passed by the user to weight the importance of each sample in `x`. class_weight: An optional class-weight array by the user to @@ -786,12 +801,51 @@ class Model(Network): raise NotImplementedError('`class_weight` is currently not supported ' 'when using DistributionStrategy.') + # Validates `steps` argument right at the beginning since we use it to + # construct the dataset object. + # TODO(anjalisridhar): This may not be a valid error since we now accept + # numpy array inputs. We still want to assert that we have a populated steps + # parameter. + if check_steps: + if steps is None: + raise ValueError('When using DistributionStrategy, ' + 'you should specify the `{steps_name}` argument.' + .format(steps_name=steps_name)) + + first_x_value = nest.flatten(x)[0] + if isinstance(first_x_value, np.ndarray): + x_shape = first_x_value.shape + x_dtype = first_x_value.dtype + if batch_size is None: + batch_size = x_shape[0] // steps + if y is not None: + first_y_value = nest.flatten(y)[0] + x = Dataset.from_generator(lambda x=x, y=y: six.moves.zip(x, y), + output_types=(x_dtype, first_y_value.dtype), + output_shapes=(x_shape[1:], + first_y_value.shape[1:])) + # TODO(anjalisridhar): What should the buffer size be? + x = x.shuffle(10000) + x = x.repeat() + x = x.batch(batch_size) + y = None + else: + # This case is for the predict call where the dataset only contains + # inputs and no targets i.e it does not return a tuple. + # TODO(anjalisridhar): Raise an error if we are not able to process + # all the predict samples. This can happen if the number of batches is + # not evenly divisible by the number of worker devices. + x = Dataset.from_generator(lambda x=x: x, + output_types=x_dtype, + output_shapes=x_shape[1:]) + x = x.repeat() + x = x.batch(batch_size) + # TODO(anjalisridhar): Can we use the iterator and getnext op cache? # We require users to pass Datasets since we distribute the dataset across # multiple devices. - if not isinstance(x, dataset_ops.Dataset): - raise ValueError('When using DistributionStrategy, model inputs should be' - ' Dataset instances; found instead %s.' % type(x)) + assert isinstance(x, dataset_ops.Dataset) + # TODO(anjalisridhar): We want distribute_dataset() to accept a Dataset or a # function which returns a Dataset. Currently distribute_dataset() only # accepts a function that returns a Dataset. Once we add support for being @@ -799,12 +853,6 @@ class Model(Network): result = self._distribution_strategy.distribute_dataset(lambda: x) iterator = result.make_initializable_iterator() K.get_session().run(iterator.initializer) - # Validates `steps` argument based on x's type. - if check_steps: - if steps is None: - raise ValueError('When using a Dataset instance as input to a model, ' - 'you should specify the `{steps_name}` argument.' - .format(steps_name=steps_name)) training_utils.validate_iterator_input(x, y, sample_weight, validation_split) @@ -1304,6 +1352,9 @@ class Model(Network): initial_epoch=0, steps_per_epoch=None, validation_steps=None, + max_queue_size=10, + workers=1, + use_multiprocessing=False, **kwargs): """Trains the model for a fixed number of epochs (iterations on a dataset). @@ -1316,19 +1367,23 @@ class Model(Network): - A dict mapping input names to the corresponding array/tensors, if the model has named inputs. - A `tf.data` dataset or a dataset iterator. Should return a tuple - of either (inputs, targets) or (inputs, targets, sample_weights). + of either `(inputs, targets)` or + `(inputs, targets, sample_weights)`. + - A generator or `keras.utils.Sequence` returning `(inputs, targets)` + or `(inputs, targets, sample weights)`. y: Target data. Like the input data `x`, it could be either Numpy array(s) or TensorFlow tensor(s). It should be consistent with `x` (you cannot have Numpy inputs and - tensor targets, or inversely). If `x` is a dataset or dataset - iterator, `y` should not be specified - (since targets will be obtained from the iterator). + tensor targets, or inversely). If `x` is a dataset, dataset + iterator, generator, or `keras.utils.Sequence` instance, `y` should + not be specified (since targets will be obtained from `x`). batch_size: Integer or `None`. Number of samples per gradient update. If unspecified, `batch_size` will default to 32. Do not specify the `batch_size` if your data is in the - form of symbolic tensors, datasets, or dataset iterators - (since they generate batches). + form of symbolic tensors, dataset, dataset iterators, + generators, or `keras.utils.Sequence` instances (since they generate + batches). epochs: Integer. Number of epochs to train the model. An epoch is an iteration over the entire `x` and `y` data provided. @@ -1350,7 +1405,8 @@ class Model(Network): on this data at the end of each epoch. The validation data is selected from the last samples in the `x` and `y` data provided, before shuffling. This argument is - not supported when `x` is a dataset or a dataset iterator. + not supported when `x` is a dataset, dataset iterator, generator or + `keras.utils.Sequence` instance. validation_data: Data on which to evaluate the loss and any model metrics at the end of each epoch. The model will not be trained on this data. @@ -1381,8 +1437,9 @@ class Model(Network): to apply a different weight to every timestep of every sample. In this case you should make sure to specify `sample_weight_mode="temporal"` in `compile()`. This argument is not - supported when `x` is a dataset or a dataset iterator, instead - provide the sample_weights as the third element of `x`. + supported when `x` is a dataset, dataset iterator, generator, or + `keras.utils.Sequence` instance, instead provide the sample_weights + as the third element of `x`. initial_epoch: Integer. Epoch at which to start training (useful for resuming a previous training run). @@ -1396,6 +1453,20 @@ class Model(Network): validation_steps: Only relevant if `steps_per_epoch` is specified. Total number of steps (batches of samples) to validate before stopping. + max_queue_size: Integer. Used for generator or `keras.utils.Sequence` + input only. Maximum size for the generator queue. + If unspecified, `max_queue_size` will default to 10. + workers: Integer. Used for generator or `keras.utils.Sequence` input + only. Maximum number of processes to spin up + when using process-based threading. If unspecified, `workers` + will default to 1. If 0, will execute the generator on the main + thread. + use_multiprocessing: Boolean. Used for generator or + `keras.utils.Sequence` input only. If `True`, use process-based + threading. If unspecified, `use_multiprocessing` will default to + `False`. Note that because this implementation relies on + multiprocessing, you should not pass non-picklable arguments to + the generator as they can't be passed easily to children processes. **kwargs: Used for backwards compatibility. Returns: @@ -1412,6 +1483,23 @@ class Model(Network): # TODO(fchollet): this method may be creating reference cycles, which would # lead to accumulating garbage in memory when called in a loop. Investigate. + if data_utils.is_generator_or_sequence(x): + training_utils.check_generator_arguments(y, sample_weight) + return self.fit_generator( + x, + steps_per_epoch=steps_per_epoch, + epochs=epochs, + verbose=verbose, + callbacks=callbacks, + validation_data=validation_data, + validation_steps=validation_steps, + class_weight=class_weight, + max_queue_size=max_queue_size, + workers=workers, + use_multiprocessing=use_multiprocessing, + shuffle=shuffle, + initial_epoch=initial_epoch) + # Backwards compatibility if batch_size is None and steps_per_epoch is None: batch_size = 32 @@ -1428,6 +1516,13 @@ class Model(Network): if self._distribution_strategy: distributed_training_utils.validate_callbacks(callbacks) + distributed_training_utils.validate_inputs(x, y) + + first_x_value = nest.flatten(x)[0] + if not steps_per_epoch and isinstance(first_x_value, np.ndarray): + steps_per_epoch = distributed_training_utils.get_input_batch_params( + first_x_value, batch_size, self._distribution_strategy) + x, y, sample_weights = self._standardize_user_data( x, y, @@ -1462,6 +1557,13 @@ class Model(Network): 'However we received `validation_data=%s`' % validation_data) # Validate and standardize validation data. + if self._distribution_strategy: + distributed_training_utils.validate_inputs(val_x, val_y) + first_valx_value = nest.flatten(val_x)[0] + if not validation_steps and isinstance(first_valx_value, np.ndarray): + validation_steps = distributed_training_utils.get_input_batch_params( + first_valx_value, batch_size, self._distribution_strategy) + val_x, val_y, val_sample_weights = self._standardize_user_data( val_x, val_y, @@ -1540,7 +1642,10 @@ class Model(Network): batch_size=None, verbose=1, sample_weight=None, - steps=None): + steps=None, + max_queue_size=10, + workers=1, + use_multiprocessing=False): """Returns the loss value & metrics values for the model in test mode. Computation is done in batches. @@ -1554,18 +1659,21 @@ class Model(Network): - A dict mapping input names to the corresponding array/tensors, if the model has named inputs. - A `tf.data` dataset or a dataset iterator. + - A generator or `keras.utils.Sequence` instance. y: Target data. Like the input data `x`, it could be either Numpy array(s) or TensorFlow tensor(s). It should be consistent with `x` (you cannot have Numpy inputs and tensor targets, or inversely). - If `x` is a dataset or a dataset iterator, `y` should not be specified - (since targets will be obtained from the iterator/dataset). + If `x` is a dataset, dataset iterator, generator or + `keras.utils.Sequence` instance, `y` should not be specified (since + targets will be obtained from the iterator/dataset). batch_size: Integer or `None`. Number of samples per gradient update. If unspecified, `batch_size` will default to 32. Do not specify the `batch_size` is your data is in the - form of symbolic tensors, datasets, or dataset iterators - (since they generate batches). + form of symbolic tensors, dataset, dataset iterators, + generators, or `keras.utils.Sequence` instances (since they generate + batches). verbose: 0 or 1. Verbosity mode. 0 = silent, 1 = progress bar. sample_weight: Optional Numpy array of weights for @@ -1579,11 +1687,25 @@ class Model(Network): to apply a different weight to every timestep of every sample. In this case you should make sure to specify `sample_weight_mode="temporal"` in `compile()`. This argument is not - supported when `x` is a dataset or a dataset iterator. + supported when `x` is a dataset or a dataset iterator, instead pass + sample weights as the third element of `x`. steps: Integer or `None`. Total number of steps (batches of samples) before declaring the evaluation round finished. Ignored with the default value of `None`. + max_queue_size: Integer. Used for generator or `keras.utils.Sequence` + input only. Maximum size for the generator queue. + If unspecified, `max_queue_size` will default to 10. + workers: Integer. Used for generator or `keras.utils.Sequence` input + only. Maximum number of processes to spin up when using + process-based threading. If unspecified, `workers` will default + to 1. If 0, will execute the generator on the main thread. + use_multiprocessing: Boolean. Used for generator or + `keras.utils.Sequence` input only. If `True`, use process-based + threading. If unspecified, `use_multiprocessing` will default to + `False`. Note that because this implementation relies on + multiprocessing, you should not pass non-picklable arguments to + the generator as they can't be passed easily to children processes. Returns: Scalar test loss (if the model has a single output and no metrics) @@ -1594,11 +1716,28 @@ class Model(Network): Raises: ValueError: in case of invalid arguments. """ + if data_utils.is_generator_or_sequence(x): + training_utils.check_generator_arguments(y, sample_weight) + return self.evaluate_generator( + x, + steps=steps, + verbose=verbose, + max_queue_size=max_queue_size, + workers=workers, + use_multiprocessing=use_multiprocessing) + # Backwards compatibility. if batch_size is None and steps is None: batch_size = 32 # Validate and standardize user data. + if self._distribution_strategy: + distributed_training_utils.validate_inputs(x, y) + first_x_value = nest.flatten(x)[0] + if isinstance(first_x_value, np.ndarray) and not steps: + steps = distributed_training_utils.get_input_batch_params( + first_x_value, batch_size, self._distribution_strategy) + x, y, sample_weights = self._standardize_user_data( x, y, @@ -1633,7 +1772,14 @@ class Model(Network): verbose=verbose, steps=steps) - def predict(self, x, batch_size=None, verbose=0, steps=None): + def predict(self, + x, + batch_size=None, + verbose=0, + steps=None, + max_queue_size=10, + workers=1, + use_multiprocessing=False): """Generates output predictions for the input samples. Computation is done in batches. @@ -1645,16 +1791,32 @@ class Model(Network): - A TensorFlow tensor, or a list of tensors (in case the model has multiple inputs). - A `tf.data` dataset or a dataset iterator. + - A generator or `keras.utils.Sequence` instance. batch_size: Integer or `None`. Number of samples per gradient update. If unspecified, `batch_size` will default to 32. Do not specify the `batch_size` is your data is in the - form of symbolic tensors, dataset, or dataset iterators - (since they generate batches). + form of symbolic tensors, dataset, dataset iterators, + generators, or `keras.utils.Sequence` instances (since they generate + batches). verbose: Verbosity mode, 0 or 1. steps: Total number of steps (batches of samples) before declaring the prediction round finished. Ignored with the default value of `None`. + max_queue_size: Integer. Used for generator or `keras.utils.Sequence` + input only. Maximum size for the generator queue. + If unspecified, `max_queue_size` will default to 10. + workers: Integer. Used for generator or `keras.utils.Sequence` input + only. Maximum number of processes to spin up when using + process-based threading. If unspecified, `workers` will default + to 1. If 0, will execute the generator on the main thread. + use_multiprocessing: Boolean. Used for generator or + `keras.utils.Sequence` input only. If `True`, use process-based + threading. If unspecified, `use_multiprocessing` will default to + `False`. Note that because this implementation relies on + multiprocessing, you should not pass non-picklable arguments to + the generator as they can't be passed easily to children processes. + Returns: Numpy array(s) of predictions. @@ -1665,18 +1827,35 @@ class Model(Network): or in case a stateful model receives a number of samples that is not a multiple of the batch size. """ + if data_utils.is_generator_or_sequence(x): + return self.predict_generator( + x, + steps=steps, + verbose=verbose, + max_queue_size=max_queue_size, + workers=workers, + use_multiprocessing=use_multiprocessing) + # Backwards compatibility. if batch_size is None and steps is None: batch_size = 32 - # Turn off prefetching since this is currently not deterministic. Once - # b/112498930 is fixed we can turn it back on. - # `_prefetch_on_device` is currently a property of only `MirroredStrategy`. - if (self._distribution_strategy and - hasattr(self._distribution_strategy, '_prefetch_on_device')): - self._distribution_strategy._prefetch_on_device = False # pylint: disable=protected-access + if self._distribution_strategy: + # Turn off prefetching since this is currently not deterministic. Once + # b/112498930 is fixed we can turn it back on. + # `_prefetch_on_device` is currently a property of only + # `MirroredStrategy`. + if hasattr(self._distribution_strategy, '_prefetch_on_device'): + self._distribution_strategy._prefetch_on_device = False # pylint: disable=protected-access + distributed_training_utils.validate_inputs(x, None) + first_x_value = nest.flatten(x)[0] + if isinstance(first_x_value, np.ndarray) and not steps: + steps = distributed_training_utils.get_input_batch_params( + first_x_value, batch_size, self._distribution_strategy) # Validate and standardize user data. + # TODO(anjalisridhar): We don't pass batch_size here for some reason. This + # means that we end up calculating it twice which we should avoid. x, _, _ = self._standardize_user_data( x, check_steps=True, steps_name='steps', steps=steps) @@ -2008,7 +2187,7 @@ class Model(Network): Arguments: generator: Generator yielding tuples (inputs, targets) or (inputs, targets, sample_weights) - or an instance of Sequence (keras.utils.Sequence) + or an instance of `keras.utils.Sequence` object in order to avoid duplicate data when using multiprocessing. steps: Total number of steps (batches of samples) @@ -2072,9 +2251,8 @@ class Model(Network): Arguments: generator: Generator yielding batches of input samples - or an instance of Sequence (keras.utils.Sequence) - object in order to avoid duplicate data - when using multiprocessing. + or an instance of `keras.utils.Sequence` object in order to + avoid duplicate data when using multiprocessing. steps: Total number of steps (batches of samples) to yield from `generator` before stopping. Optional for `Sequence`: if unspecified, will use |