Diffstat (limited to 'tensorflow/python/keras/engine/training.py')
-rw-r--r--  tensorflow/python/keras/engine/training.py | 282
1 file changed, 230 insertions(+), 52 deletions(-)
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index d224dfffdd..fed07c4120 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -20,9 +20,11 @@ from __future__ import print_function
import weakref
import numpy as np
+import six
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
+from tensorflow.python.data.ops.dataset_ops import Dataset
from tensorflow.python.eager import context
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
@@ -39,6 +41,7 @@ from tensorflow.python.keras.engine import training_eager
from tensorflow.python.keras.engine import training_generator
from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.engine.network import Network
+from tensorflow.python.keras.utils import data_utils
from tensorflow.python.keras.utils.generic_utils import slice_arrays
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import weights_broadcast_ops
@@ -206,8 +209,27 @@ class Model(Network):
for metric in metrics:
metric_fn = training_utils.get_metric_function(
metric, output_shape=output_shape, loss_fn=loss_fn)
- metric_name = self._get_metric_name(
- metric, output_index, weighted=weights is not None)
+
+ if (context.executing_eagerly() and y_true is not None and
+ y_pred is not None):
+ # In eager mode, when executing metric_fn during training, we do not
+ # need to generate a unique metric name or add it to the model,
+ # since that was already done during compile.
+ prefix = 'weighted_' if weights is not None else ''
+ suffix = metric_fn.name if hasattr(metric_fn,
+ 'name') else metric_fn.__name__
+ metric_name = prefix + suffix
+ else:
+ # Get metric name that is to be added to the model.
+ metric_name = self._get_metric_name(
+ metric, output_index, weighted=weights is not None)
+ # Keep track of metric name.
+ self.metrics_names.append(metric_name)
+
+ # Keep track of stateful metric attributes (name and metric function).
+ if isinstance(metric_fn, base_layer.Layer) and metric_fn.stateful:
+ self.stateful_metric_names.append(metric_name)
+ self.stateful_metric_functions.append(metric_fn)
with K.name_scope(metric_name):
# If both outputs and targets are available, call the metric function.
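As a hedged, standalone illustration of the eager-mode naming rule added above (this is not code from the diff; the metric function and weights below are stand-ins): the metric name is just the optional 'weighted_' prefix joined with either `metric_fn.name` or `metric_fn.__name__`.

# Hypothetical sketch of the eager-mode metric naming; not part of this change.
def binary_accuracy(y_true, y_pred):
  # Body is irrelevant here; only the function's __name__ matters.
  return (y_true == y_pred).mean()

weights = [0.5, 1.0]  # any non-None sample weights trigger the prefix
prefix = 'weighted_' if weights is not None else ''
suffix = getattr(binary_accuracy, 'name', binary_accuracy.__name__)
metric_name = prefix + suffix  # -> 'weighted_binary_accuracy'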
@@ -247,16 +269,10 @@ class Model(Network):
self.metrics_tensors.append(metric_result)
metric_results.append(metric_result)
- # Keep track of metric name.
- self.metrics_names.append(metric_name)
-
- # Keep track of stateful metric attributes (name and metric function).
- if isinstance(metric_fn, base_layer.Layer) and metric_fn.stateful:
- self.stateful_metric_names.append(metric_name)
- self.stateful_metric_functions.append(metric_fn)
- if not context.executing_eagerly():
- # Keep track of updates created by stateful metrics.
- self.metrics_updates += metric_fn.updates
+ if (isinstance(metric_fn, base_layer.Layer) and metric_fn.stateful and
+ not context.executing_eagerly()):
+ # Keep track of updates created by stateful metrics.
+ self.metrics_updates += metric_fn.updates
return metric_results
def _handle_metrics(self,
@@ -754,9 +770,8 @@ class Model(Network):
the model.
Args:
- x: Input data. A `tf.data` dataset.
- y: Since `x` is a dataset, `y` should not be specified
- (since targets will be obtained from the iterator).
+ x: Input data. A numpy array or `tf.data` dataset.
+ y: Target data. A numpy array, or `None` if `x` is a `tf.data` dataset.
sample_weight: An optional sample-weight array passed by the user to
weight the importance of each sample in `x`.
class_weight: An optional class-weight array by the user to
@@ -786,12 +801,51 @@ class Model(Network):
raise NotImplementedError('`class_weight` is currently not supported '
'when using DistributionStrategy.')
+ # Validates `steps` argument right at the beginning since we use it to
+ # construct the dataset object.
+ # TODO(anjalisridhar): This may not be a valid error since we now accept
+ # numpy array inputs. We still want to assert that we have a populated steps
+ # parameter.
+ if check_steps:
+ if steps is None:
+ raise ValueError('When using DistributionStrategy, '
+ 'you should specify the `{steps_name}` argument.'
+ .format(steps_name=steps_name))
+
+ first_x_value = nest.flatten(x)[0]
+ if isinstance(first_x_value, np.ndarray):
+ x_shape = first_x_value.shape
+ x_dtype = first_x_value.dtype
+ if batch_size is None:
+ batch_size = x_shape[0] // steps
+ if y is not None:
+ first_y_value = nest.flatten(y)[0]
+ x = Dataset.from_generator(lambda x=x, y=y: six.moves.zip(x, y),
+ output_types=(x_dtype, first_y_value.dtype),
+ output_shapes=(x_shape[1:],
+ first_y_value.shape[1:]))
+ # TODO(anjalisridhar): What should the buffer size be?
+ x = x.shuffle(10000)
+ x = x.repeat()
+ x = x.batch(batch_size)
+ y = None
+ else:
+ # This case is for the predict call where the dataset only contains
+ # inputs and no targets, i.e. it does not return a tuple.
+ # TODO(anjalisridhar): Raise an error if we are not able to process
+ # all the predict samples. This can happen if the number of batches is
+ # not evenly divisible by the number of worker devices.
+ x = Dataset.from_generator(lambda x=x: x,
+ output_types=x_dtype,
+ output_shapes=x_shape[1:])
+ x = x.repeat()
+ x = x.batch(batch_size)
+
# TODO(anjalisridhar): Can we use the iterator and getnext op cache?
# We require users to pass Datasets since we distribute the dataset across
# multiple devices.
- if not isinstance(x, dataset_ops.Dataset):
- raise ValueError('When using DistributionStrategy, model inputs should be'
- ' Dataset instances; found instead %s.' % type(x))
+ assert isinstance(x, dataset_ops.Dataset)
+
# TODO(anjalisridhar): We want distribute_dataset() to accept a Dataset or a
# function which returns a Dataset. Currently distribute_dataset() only
# accepts a function that returns a Dataset. Once we add support for being
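To make the numpy-to-Dataset conversion above concrete, here is a hedged, self-contained sketch (not code from this change; data, shapes, and dtypes are arbitrary stand-ins): inputs and targets are zipped into a generator-backed `tf.data.Dataset`, shuffled with the same buffer size of 10000, repeated, and batched with `batch_size = num_samples // steps`.

import numpy as np
import tensorflow as tf

# Illustrative stand-in data.
x = np.random.rand(100, 8).astype(np.float32)
y = np.random.randint(0, 2, size=(100, 1))
steps = 10
batch_size = x.shape[0] // steps  # mirrors `x_shape[0] // steps` above

dataset = tf.data.Dataset.from_generator(
    lambda: zip(x, y),
    output_types=(x.dtype, y.dtype),
    output_shapes=(x.shape[1:], y.shape[1:]))
dataset = dataset.shuffle(10000).repeat().batch(batch_size)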
@@ -799,12 +853,6 @@ class Model(Network):
result = self._distribution_strategy.distribute_dataset(lambda: x)
iterator = result.make_initializable_iterator()
K.get_session().run(iterator.initializer)
- # Validates `steps` argument based on x's type.
- if check_steps:
- if steps is None:
- raise ValueError('When using a Dataset instance as input to a model, '
- 'you should specify the `{steps_name}` argument.'
- .format(steps_name=steps_name))
training_utils.validate_iterator_input(x, y, sample_weight,
validation_split)
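For context, a hedged minimal sketch of the initializable-iterator pattern used above, shown standalone and outside the distribution-strategy machinery (the dataset below is a hypothetical example, not the distributed dataset built by this code):

import numpy as np
import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices(
    np.arange(10, dtype=np.float32)).batch(2)
iterator = dataset.make_initializable_iterator()
next_batch = iterator.get_next()

with tf.Session() as sess:
  sess.run(iterator.initializer)  # must run before pulling any batches
  print(sess.run(next_batch))     # -> [0. 1.]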
@@ -1304,6 +1352,9 @@ class Model(Network):
initial_epoch=0,
steps_per_epoch=None,
validation_steps=None,
+ max_queue_size=10,
+ workers=1,
+ use_multiprocessing=False,
**kwargs):
"""Trains the model for a fixed number of epochs (iterations on a dataset).
@@ -1316,19 +1367,23 @@ class Model(Network):
- A dict mapping input names to the corresponding array/tensors,
if the model has named inputs.
- A `tf.data` dataset or a dataset iterator. Should return a tuple
- of either (inputs, targets) or (inputs, targets, sample_weights).
+ of either `(inputs, targets)` or
+ `(inputs, targets, sample_weights)`.
+ - A generator or `keras.utils.Sequence` returning `(inputs, targets)`
+ or `(inputs, targets, sample_weights)`.
y: Target data. Like the input data `x`,
it could be either Numpy array(s) or TensorFlow tensor(s).
It should be consistent with `x` (you cannot have Numpy inputs and
- tensor targets, or inversely). If `x` is a dataset or dataset
- iterator, `y` should not be specified
- (since targets will be obtained from the iterator).
+ tensor targets, or inversely). If `x` is a dataset, dataset
+ iterator, generator, or `keras.utils.Sequence` instance, `y` should
+ not be specified (since targets will be obtained from `x`).
batch_size: Integer or `None`.
Number of samples per gradient update.
If unspecified, `batch_size` will default to 32.
Do not specify the `batch_size` if your data is in the
- form of symbolic tensors, datasets, or dataset iterators
- (since they generate batches).
+ form of symbolic tensors, datasets, dataset iterators,
+ generators, or `keras.utils.Sequence` instances (since they generate
+ batches).
epochs: Integer. Number of epochs to train the model.
An epoch is an iteration over the entire `x` and `y`
data provided.
@@ -1350,7 +1405,8 @@ class Model(Network):
on this data at the end of each epoch.
The validation data is selected from the last samples
in the `x` and `y` data provided, before shuffling. This argument is
- not supported when `x` is a dataset or a dataset iterator.
+ not supported when `x` is a dataset, dataset iterator, generator or
+ `keras.utils.Sequence` instance.
validation_data: Data on which to evaluate
the loss and any model metrics at the end of each epoch.
The model will not be trained on this data.
@@ -1381,8 +1437,9 @@ class Model(Network):
to apply a different weight to every timestep of every sample.
In this case you should make sure to specify
`sample_weight_mode="temporal"` in `compile()`. This argument is not
- supported when `x` is a dataset or a dataset iterator, instead
- provide the sample_weights as the third element of `x`.
+ supported when `x` is a dataset, dataset iterator, generator, or
+ `keras.utils.Sequence` instance; instead, provide the sample_weights
+ as the third element of `x`.
initial_epoch: Integer.
Epoch at which to start training
(useful for resuming a previous training run).
@@ -1396,6 +1453,20 @@ class Model(Network):
validation_steps: Only relevant if `steps_per_epoch`
is specified. Total number of steps (batches of samples)
to validate before stopping.
+ max_queue_size: Integer. Used for generator or `keras.utils.Sequence`
+ input only. Maximum size for the generator queue.
+ If unspecified, `max_queue_size` will default to 10.
+ workers: Integer. Used for generator or `keras.utils.Sequence` input
+ only. Maximum number of processes to spin up
+ when using process-based threading. If unspecified, `workers`
+ will default to 1. If 0, will execute the generator on the main
+ thread.
+ use_multiprocessing: Boolean. Used for generator or
+ `keras.utils.Sequence` input only. If `True`, use process-based
+ threading. If unspecified, `use_multiprocessing` will default to
+ `False`. Note that because this implementation relies on
+ multiprocessing, you should not pass non-picklable arguments to
+ the generator as they can't be passed easily to child processes.
**kwargs: Used for backwards compatibility.
Returns:
@@ -1412,6 +1483,23 @@ class Model(Network):
# TODO(fchollet): this method may be creating reference cycles, which would
# lead to accumulating garbage in memory when called in a loop. Investigate.
+ if data_utils.is_generator_or_sequence(x):
+ training_utils.check_generator_arguments(y, sample_weight)
+ return self.fit_generator(
+ x,
+ steps_per_epoch=steps_per_epoch,
+ epochs=epochs,
+ verbose=verbose,
+ callbacks=callbacks,
+ validation_data=validation_data,
+ validation_steps=validation_steps,
+ class_weight=class_weight,
+ max_queue_size=max_queue_size,
+ workers=workers,
+ use_multiprocessing=use_multiprocessing,
+ shuffle=shuffle,
+ initial_epoch=initial_epoch)
+
# Backwards compatibility
if batch_size is None and steps_per_epoch is None:
batch_size = 32
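A hedged usage sketch of the new dispatch path (the model, data, and Sequence below are hypothetical, not part of this change): once `fit` detects a generator or `keras.utils.Sequence`, it simply forwards the call to `fit_generator` with the generator-related arguments.

import numpy as np
from tensorflow import keras

class ToySequence(keras.utils.Sequence):
  """Minimal Sequence yielding (inputs, targets) batches."""

  def __init__(self, batch_size=32):
    self.x = np.random.rand(256, 8).astype('float32')
    self.y = np.random.randint(0, 2, size=(256, 1))
    self.batch_size = batch_size

  def __len__(self):
    return len(self.x) // self.batch_size

  def __getitem__(self, idx):
    s = slice(idx * self.batch_size, (idx + 1) * self.batch_size)
    return self.x[s], self.y[s]

model = keras.Sequential([
    keras.layers.Dense(1, activation='sigmoid', input_shape=(8,))])
model.compile(optimizer='sgd', loss='binary_crossentropy')

# With this change, a Sequence can go straight into fit(); y must be None.
model.fit(ToySequence(), epochs=2, workers=2, max_queue_size=10)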
@@ -1428,6 +1516,13 @@ class Model(Network):
if self._distribution_strategy:
distributed_training_utils.validate_callbacks(callbacks)
+ distributed_training_utils.validate_inputs(x, y)
+
+ first_x_value = nest.flatten(x)[0]
+ if not steps_per_epoch and isinstance(first_x_value, np.ndarray):
+ steps_per_epoch = distributed_training_utils.get_input_batch_params(
+ first_x_value, batch_size, self._distribution_strategy)
+
x, y, sample_weights = self._standardize_user_data(
x,
y,
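The exact logic of `distributed_training_utils.get_input_batch_params` is not shown in this diff; as a rough, hypothetical illustration only, a steps-per-epoch value for numpy inputs could be derived by dividing the sample count by the batch size, which is consistent with the `batch_size = x_shape[0] // steps` relation used earlier in this diff.

# Hypothetical back-of-the-envelope sketch; NOT the implementation of
# get_input_batch_params.
num_samples = 1024   # e.g. first_x_value.shape[0]
batch_size = 32
steps_per_epoch = num_samples // batch_size  # -> 32 steps per epoch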
@@ -1462,6 +1557,13 @@ class Model(Network):
'However we received `validation_data=%s`' % validation_data)
# Validate and standardize validation data.
+ if self._distribution_strategy:
+ distributed_training_utils.validate_inputs(val_x, val_y)
+ first_valx_value = nest.flatten(val_x)[0]
+ if not validation_steps and isinstance(first_valx_value, np.ndarray):
+ validation_steps = distributed_training_utils.get_input_batch_params(
+ first_valx_value, batch_size, self._distribution_strategy)
+
val_x, val_y, val_sample_weights = self._standardize_user_data(
val_x,
val_y,
@@ -1540,7 +1642,10 @@ class Model(Network):
batch_size=None,
verbose=1,
sample_weight=None,
- steps=None):
+ steps=None,
+ max_queue_size=10,
+ workers=1,
+ use_multiprocessing=False):
"""Returns the loss value & metrics values for the model in test mode.
Computation is done in batches.
@@ -1554,18 +1659,21 @@ class Model(Network):
- A dict mapping input names to the corresponding array/tensors,
if the model has named inputs.
- A `tf.data` dataset or a dataset iterator.
+ - A generator or `keras.utils.Sequence` instance.
y: Target data. Like the input data `x`,
it could be either Numpy array(s) or TensorFlow tensor(s).
It should be consistent with `x` (you cannot have Numpy inputs and
tensor targets, or inversely).
- If `x` is a dataset or a dataset iterator, `y` should not be specified
- (since targets will be obtained from the iterator/dataset).
+ If `x` is a dataset, dataset iterator, generator or
+ `keras.utils.Sequence` instance, `y` should not be specified (since
+ targets will be obtained from the iterator/dataset).
batch_size: Integer or `None`.
Number of samples per batch of computation.
If unspecified, `batch_size` will default to 32.
Do not specify the `batch_size` if your data is in the
- form of symbolic tensors, datasets, or dataset iterators
- (since they generate batches).
+ form of symbolic tensors, datasets, dataset iterators,
+ generators, or `keras.utils.Sequence` instances (since they generate
+ batches).
verbose: 0 or 1. Verbosity mode.
0 = silent, 1 = progress bar.
sample_weight: Optional Numpy array of weights for
@@ -1579,11 +1687,25 @@ class Model(Network):
to apply a different weight to every timestep of every sample.
In this case you should make sure to specify
`sample_weight_mode="temporal"` in `compile()`. This argument is not
- supported when `x` is a dataset or a dataset iterator.
+ supported when `x` is a dataset or a dataset iterator; instead, pass
+ sample weights as the third element of `x`.
steps: Integer or `None`.
Total number of steps (batches of samples)
before declaring the evaluation round finished.
Ignored with the default value of `None`.
+ max_queue_size: Integer. Used for generator or `keras.utils.Sequence`
+ input only. Maximum size for the generator queue.
+ If unspecified, `max_queue_size` will default to 10.
+ workers: Integer. Used for generator or `keras.utils.Sequence` input
+ only. Maximum number of processes to spin up when using
+ process-based threading. If unspecified, `workers` will default
+ to 1. If 0, will execute the generator on the main thread.
+ use_multiprocessing: Boolean. Used for generator or
+ `keras.utils.Sequence` input only. If `True`, use process-based
+ threading. If unspecified, `use_multiprocessing` will default to
+ `False`. Note that because this implementation relies on
+ multiprocessing, you should not pass non-picklable arguments to
+ the generator as they can't be passed easily to child processes.
Returns:
Scalar test loss (if the model has a single output and no metrics)
@@ -1594,11 +1716,28 @@ class Model(Network):
Raises:
ValueError: in case of invalid arguments.
"""
+ if data_utils.is_generator_or_sequence(x):
+ training_utils.check_generator_arguments(y, sample_weight)
+ return self.evaluate_generator(
+ x,
+ steps=steps,
+ verbose=verbose,
+ max_queue_size=max_queue_size,
+ workers=workers,
+ use_multiprocessing=use_multiprocessing)
+
# Backwards compatibility.
if batch_size is None and steps is None:
batch_size = 32
# Validate and standardize user data.
+ if self._distribution_strategy:
+ distributed_training_utils.validate_inputs(x, y)
+ first_x_value = nest.flatten(x)[0]
+ if isinstance(first_x_value, np.ndarray) and not steps:
+ steps = distributed_training_utils.get_input_batch_params(
+ first_x_value, batch_size, self._distribution_strategy)
+
x, y, sample_weights = self._standardize_user_data(
x,
y,
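A hedged usage sketch of the corresponding `evaluate` dispatch (the model and generator below are hypothetical): a plain Python generator has no length, so `steps` must be passed explicitly, and `y`/`sample_weight` must be left unset.

import numpy as np
from tensorflow import keras

model = keras.Sequential([
    keras.layers.Dense(1, activation='sigmoid', input_shape=(8,))])
model.compile(optimizer='sgd', loss='binary_crossentropy', metrics=['acc'])

def eval_batches(batch_size=32):
  # Endless generator of (inputs, targets) batches.
  while True:
    x = np.random.rand(batch_size, 8).astype('float32')
    y = np.random.randint(0, 2, size=(batch_size, 1))
    yield x, y

# Forwarded to evaluate_generator(); steps is required for a raw generator.
loss, acc = model.evaluate(eval_batches(), steps=4)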
@@ -1633,7 +1772,14 @@ class Model(Network):
verbose=verbose,
steps=steps)
- def predict(self, x, batch_size=None, verbose=0, steps=None):
+ def predict(self,
+ x,
+ batch_size=None,
+ verbose=0,
+ steps=None,
+ max_queue_size=10,
+ workers=1,
+ use_multiprocessing=False):
"""Generates output predictions for the input samples.
Computation is done in batches.
@@ -1645,16 +1791,32 @@ class Model(Network):
- A TensorFlow tensor, or a list of tensors
(in case the model has multiple inputs).
- A `tf.data` dataset or a dataset iterator.
+ - A generator or `keras.utils.Sequence` instance.
batch_size: Integer or `None`.
Number of samples per batch.
If unspecified, `batch_size` will default to 32.
Do not specify the `batch_size` if your data is in the
- form of symbolic tensors, dataset, or dataset iterators
- (since they generate batches).
+ form of symbolic tensors, datasets, dataset iterators,
+ generators, or `keras.utils.Sequence` instances (since they generate
+ batches).
verbose: Verbosity mode, 0 or 1.
steps: Total number of steps (batches of samples)
before declaring the prediction round finished.
Ignored with the default value of `None`.
+ max_queue_size: Integer. Used for generator or `keras.utils.Sequence`
+ input only. Maximum size for the generator queue.
+ If unspecified, `max_queue_size` will default to 10.
+ workers: Integer. Used for generator or `keras.utils.Sequence` input
+ only. Maximum number of processes to spin up when using
+ process-based threading. If unspecified, `workers` will default
+ to 1. If 0, will execute the generator on the main thread.
+ use_multiprocessing: Boolean. Used for generator or
+ `keras.utils.Sequence` input only. If `True`, use process-based
+ threading. If unspecified, `use_multiprocessing` will default to
+ `False`. Note that because this implementation relies on
+ multiprocessing, you should not pass non-picklable arguments to
+ the generator as they can't be passed easily to child processes.
+
Returns:
Numpy array(s) of predictions.
@@ -1665,18 +1827,35 @@ class Model(Network):
or in case a stateful model receives a number of samples
that is not a multiple of the batch size.
"""
+ if data_utils.is_generator_or_sequence(x):
+ return self.predict_generator(
+ x,
+ steps=steps,
+ verbose=verbose,
+ max_queue_size=max_queue_size,
+ workers=workers,
+ use_multiprocessing=use_multiprocessing)
+
# Backwards compatibility.
if batch_size is None and steps is None:
batch_size = 32
- # Turn off prefetching since this is currently not deterministic. Once
- # b/112498930 is fixed we can turn it back on.
- # `_prefetch_on_device` is currently a property of only `MirroredStrategy`.
- if (self._distribution_strategy and
- hasattr(self._distribution_strategy, '_prefetch_on_device')):
- self._distribution_strategy._prefetch_on_device = False # pylint: disable=protected-access
+ if self._distribution_strategy:
+ # Turn off prefetching since this is currently not deterministic. Once
+ # b/112498930 is fixed we can turn it back on.
+ # `_prefetch_on_device` is currently a property of only
+ # `MirroredStrategy`.
+ if hasattr(self._distribution_strategy, '_prefetch_on_device'):
+ self._distribution_strategy._prefetch_on_device = False # pylint: disable=protected-access
+ distributed_training_utils.validate_inputs(x, None)
+ first_x_value = nest.flatten(x)[0]
+ if isinstance(first_x_value, np.ndarray) and not steps:
+ steps = distributed_training_utils.get_input_batch_params(
+ first_x_value, batch_size, self._distribution_strategy)
# Validate and standardize user data.
+ # TODO(anjalisridhar): We don't pass batch_size here for some reason. This
+ # means that we end up calculating it twice which we should avoid.
x, _, _ = self._standardize_user_data(
x, check_steps=True, steps_name='steps', steps=steps)
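And a matching hedged sketch for `predict` (model and generator are hypothetical): a generator of input-only batches is forwarded to `predict_generator`, again with an explicit `steps`.

import numpy as np
from tensorflow import keras

model = keras.Sequential([
    keras.layers.Dense(1, activation='sigmoid', input_shape=(8,))])

def input_batches(batch_size=32):
  # Endless generator of input-only batches (no targets for predict).
  while True:
    yield np.random.rand(batch_size, 8).astype('float32')

# Forwarded to predict_generator(); result has shape (4 * 32, 1).
preds = model.predict(input_batches(), steps=4)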
@@ -2008,7 +2187,7 @@ class Model(Network):
Arguments:
generator: Generator yielding tuples (inputs, targets)
or (inputs, targets, sample_weights)
- or an instance of Sequence (keras.utils.Sequence)
+ or an instance of `keras.utils.Sequence`
object in order to avoid duplicate data
when using multiprocessing.
steps: Total number of steps (batches of samples)
@@ -2072,9 +2251,8 @@ class Model(Network):
Arguments:
generator: Generator yielding batches of input samples
- or an instance of Sequence (keras.utils.Sequence)
- object in order to avoid duplicate data
- when using multiprocessing.
+ or an instance of `keras.utils.Sequence` object in order to
+ avoid duplicate data when using multiprocessing.
steps: Total number of steps (batches of samples)
to yield from `generator` before stopping.
Optional for `Sequence`: if unspecified, will use
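As a hedged note on the Sequence recommendation above (hypothetical code, not from this change): with `use_multiprocessing=True`, a `keras.utils.Sequence` is indexed by `__getitem__`, so each worker pulls distinct batch indices and duplicate data is avoided, unlike a shared stateful generator.

import numpy as np
from tensorflow import keras

class PredictSequence(keras.utils.Sequence):
  """Input-only Sequence; safe to use with use_multiprocessing=True."""

  def __init__(self, batch_size=32):
    self.x = np.random.rand(128, 8).astype('float32')
    self.batch_size = batch_size

  def __len__(self):
    return len(self.x) // self.batch_size

  def __getitem__(self, idx):
    s = slice(idx * self.batch_size, (idx + 1) * self.batch_size)
    return self.x[s]

model = keras.Sequential([keras.layers.Dense(1, input_shape=(8,))])
# steps is optional for a Sequence: its length is taken from __len__.
preds = model.predict_generator(PredictSequence(), workers=2,
                                use_multiprocessing=True)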