aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/contrib/distribute/python/examples/keras_mnist.py4
-rw-r--r--tensorflow/python/keras/engine/training_distributed.py237
2 files changed, 193 insertions, 48 deletions
diff --git a/tensorflow/contrib/distribute/python/examples/keras_mnist.py b/tensorflow/contrib/distribute/python/examples/keras_mnist.py
index a20069c4fe..0495134636 100644
--- a/tensorflow/contrib/distribute/python/examples/keras_mnist.py
+++ b/tensorflow/contrib/distribute/python/examples/keras_mnist.py
@@ -58,13 +58,13 @@ def get_input_datasets():
train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_ds = train_ds.repeat()
train_ds = train_ds.shuffle(100)
- train_ds = train_ds.batch(64)
+ train_ds = train_ds.batch(64, drop_remainder=True)
# eval dataset
eval_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))
eval_ds = eval_ds.repeat()
eval_ds = eval_ds.shuffle(100)
- eval_ds = eval_ds.batch(64)
+ eval_ds = eval_ds.batch(64, drop_remainder=True)
return train_ds, eval_ds, input_shape
diff --git a/tensorflow/python/keras/engine/training_distributed.py b/tensorflow/python/keras/engine/training_distributed.py
index a7bb1f8177..e440e02bfb 100644
--- a/tensorflow/python/keras/engine/training_distributed.py
+++ b/tensorflow/python/keras/engine/training_distributed.py
@@ -19,13 +19,16 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import callbacks as cbks
from tensorflow.python.keras import optimizers
from tensorflow.python.keras.engine import distributed_training_utils
from tensorflow.python.keras.utils.generic_utils import Progbar
+from tensorflow.python.ops import array_ops
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import distribute as distribute_lib
def fit_loop(
@@ -64,6 +67,11 @@ def fit_loop(
"""
current_strategy = model._distribution_strategy
+ # TODO(priyag, sourabhbajaj): Remove this when the codepaths are merged.
+ if current_strategy.__class__.__name__ == 'TPUStrategy':
+ return _experimental_fit_loop(
+ model, iterator, epochs, initial_epoch, steps_per_epoch)
+
clone_model_on_towers(
model, current_strategy, make_callback_model=True)
@@ -116,11 +124,6 @@ def fit_loop(
do_validation = False
if validation_steps:
do_validation = True
- if steps_per_epoch is None:
- raise ValueError('Can only use `validation_steps` '
- 'when doing step-wise '
- 'training, i.e. `steps_per_epoch` '
- 'must be set.')
# Copy the weights from the original model to each of the replicated models.
orig_model_weights = model.get_weights()
@@ -140,44 +143,46 @@ def fit_loop(
verbose=verbose)
out_labels = model.metrics_names or []
callbacks.on_train_begin()
+
+ assert steps_per_epoch is not None
+
for epoch in range(initial_epoch, epochs):
callbacks.on_epoch_begin(epoch)
- if steps_per_epoch is not None:
- epoch_logs = {}
- for step_index in range(steps_per_epoch):
- batch_logs = {'batch': step_index, 'size': 1}
- callbacks.on_batch_begin(step_index, batch_logs)
- try:
- outs = distributed_train_function(ins)
- except errors.OutOfRangeError:
- logging.warning('Your dataset iterator ran out of data; '
- 'interrupting training. Make sure that your dataset '
- 'can generate at least `steps_per_epoch * epochs` '
- 'batches (in this case, %d batches).' %
- steps_per_epoch * epochs)
- break
-
- if not isinstance(outs, list):
- outs = [outs]
-
- outs = _aggregate_metrics_across_towers(
- current_strategy.num_towers, out_labels, outs)
- for l, o in zip(out_labels, outs):
- batch_logs[l] = o
- callbacks.on_batch_end(step_index, batch_logs)
- if callbacks.model.stop_training:
- break
- if do_validation:
- val_outs = test_loop(
- model,
- val_iterator,
- steps=validation_steps,
- verbose=0)
- if not isinstance(val_outs, list):
- val_outs = [val_outs]
- # Same labels assumed.
- for l, o in zip(out_labels, val_outs):
- epoch_logs['val_' + l] = o
+ epoch_logs = {}
+ for step_index in range(steps_per_epoch):
+ batch_logs = {'batch': step_index, 'size': 1}
+ callbacks.on_batch_begin(step_index, batch_logs)
+ try:
+ outs = distributed_train_function(ins)
+ except errors.OutOfRangeError:
+ logging.warning('Your dataset iterator ran out of data; '
+ 'interrupting training. Make sure that your dataset '
+ 'can generate at least `steps_per_epoch * epochs` '
+ 'batches (in this case, %d batches).' %
+                      (steps_per_epoch * epochs))
+ break
+
+ if not isinstance(outs, list):
+ outs = [outs]
+
+ outs = _aggregate_metrics_across_towers(
+ current_strategy.num_towers, out_labels, outs)
+ for l, o in zip(out_labels, outs):
+ batch_logs[l] = o
+ callbacks.on_batch_end(step_index, batch_logs)
+ if callbacks.model.stop_training:
+ break
+ if do_validation:
+ val_outs = test_loop(
+ model,
+ val_iterator,
+ steps=validation_steps,
+ verbose=0)
+ if not isinstance(val_outs, list):
+ val_outs = [val_outs]
+ # Same labels assumed.
+ for l, o in zip(out_labels, val_outs):
+ epoch_logs['val_' + l] = o
callbacks.on_epoch_end(epoch, epoch_logs)
if callbacks.model.stop_training:
@@ -192,6 +197,139 @@ def fit_loop(
return model.history
+def _experimental_fit_loop(
+ model,
+ iterator,
+ epochs=100,
+ initial_epoch=0,
+ steps_per_epoch=None):
+ """fit function when using TPU DistributionStrategy for training.
+
+ Arguments:
+ model: Keras Model instance.
+      iterator: Iterator that returns inputs and targets.
+      epochs: Number of times to iterate over the data.
+      initial_epoch: Epoch at which to start training
+          (useful for resuming a previous training run).
+      steps_per_epoch: Total number of steps (batches of samples)
+          before declaring one epoch finished and starting the
+          next epoch. Must be specified; `None` is not supported here.
+
+ Returns:
+ Returns `None`.
+
+ Raises:
+ ValueError: in case of invalid arguments.
+ """
+ current_strategy = model._distribution_strategy
+
+ # TODO(priyag): Add validation that shapes are fully defined for TPU case.
+
+ # TODO(priyag, sourabhbajaj): This should be moved into a callback instead.
+ K.get_session().run(current_strategy.initialize())
+
+ def _per_device_train_function(model):
+ model._make_train_function()
+ return (model.train_function.inputs,
+ model.train_function.outputs,
+ model.train_function.updates_op,
+ model.train_function.session_kwargs)
+
+ # TODO(priyag, sourabhbajaj): This should likely not be hardcoded here.
+ K.set_learning_phase(1)
+
+ def step_fn(ctx, inputs, targets):
+ """Clones the model and calls make_train_function."""
+ # TODO(priyag, sourabhbajaj): Should cache this keyed on input shapes.
+ clone_model_on_towers(
+ model,
+ current_strategy,
+ make_callback_model=True,
+ inputs=inputs,
+ targets=targets)
+
+ (grouped_inputs, grouped_outputs, grouped_updates,
+ grouped_session_args) = current_strategy.call_for_each_tower(
+ _per_device_train_function, model._grouped_model)
+ (all_inputs, all_outputs, all_updates,
+ all_session_args) = distributed_training_utils.unwrap_values(
+ current_strategy, grouped_inputs, grouped_outputs,
+ grouped_updates, grouped_session_args, with_loss_tensor=True)
+ combined_fn = K.Function(
+ all_inputs, all_outputs,
+ updates=all_updates,
+ name='distributed_train_function',
+ **all_session_args)
+
+ # TODO(priyag, sourabhbajaj): Perhaps the aggregation type needs to be
+ # something else for different outputs.
+ out_labels = model.metrics_names or []
+ for label, output in zip(out_labels, combined_fn.outputs):
+ ctx.set_last_step_output(label, output,
+ aggregation=distribute_lib.get_loss_reduction())
+
+ # TODO(priyag, sourabhbajaj): Ignoring these things from the combined_fn:
+ # feed_dict, session kwargs, run options, run_metadata for now. These should
+ # be handled appropriately
+ return combined_fn.updates_op
+
+ # Add initial dummy values for loss and other metric tensors.
+ initial_loop_values = {}
+ initial_loop_values['loss'] = constant_op.constant(1e7)
+ for name, tensor in zip(model.metrics_names[1:], model.metrics_tensors):
+ initial_loop_values[name] = array_ops.zeros(tensor.shape, tensor.dtype)
+
+ with current_strategy.scope():
+ # TODO(priyag, sourabhbajaj): Adjust steps_per_run appropriately based on
+ # steps_per_epoch and number of epochs.
+ ctx = current_strategy.run_steps_on_dataset(
+ step_fn, iterator, iterations=current_strategy.steps_per_run,
+ initial_loop_values=initial_loop_values)
+
+ train_op = ctx.run_op
+ output_tensors = ctx.last_step_outputs
+
+ # Copy the weights from the original model to each of the replicated models.
+ orig_model_weights = model.get_weights()
+ with current_strategy.scope():
+ distributed_model = current_strategy.unwrap(model._grouped_model)[0]
+ distributed_training_utils.set_weights(
+ current_strategy, distributed_model, orig_model_weights)
+
+ assert steps_per_epoch is not None
+
+ # TODO(priyag, sourabhbajaj): Add callbacks support.
+ # TODO(priyag, sourabhbajaj): Add validation.
+ for epoch in range(initial_epoch, epochs):
+ for step_index in range(
+ 0, steps_per_epoch, current_strategy.steps_per_run):
+ try:
+ _, outs = K.get_session().run([train_op, output_tensors])
+ # TODO(priyag, sourabhbajaj): Remove this logging in favor of proper
+ # summaries through callbacks.
+ print('Epoch: {}, step_index: {}, loss: {}'.format(
+ epoch, step_index, outs['loss']))
+ for label, out in outs.items():
+ print(label, ': ', out)
+ except errors.OutOfRangeError:
+ logging.warning('Your dataset iterator ran out of data; '
+ 'interrupting training. Make sure that your dataset '
+ 'can generate at least `steps_per_epoch * epochs` '
+ 'batches (in this case, %d batches).' %
+                        (steps_per_epoch * epochs))
+ break
+
+ # Copy the weights back from the replicated model to the original model.
+ with current_strategy.scope():
+ updated_weights = current_strategy.unwrap(
+ model._grouped_model)[0].get_weights()
+ model.set_weights(updated_weights)
+
+ K.get_session().run(current_strategy.finalize())
+
+ # TODO(priyag, sourabhbajaj): Return history.
+
+
def test_loop(model, iterator, verbose=0, steps=None):
"""evaluate method to validate a model that uses DistributionStrategy.
@@ -373,12 +511,12 @@ def predict_loop(model, iterator, verbose=0, steps=None):
]
-def _clone_and_build_model(model):
+def _clone_and_build_model(model, inputs=None, targets=None):
"""Clone and build the given keras_model."""
# We need to set the import here since we run into a circular dependency
# error.
from tensorflow.python.keras import models # pylint: disable=g-import-not-at-top
- cloned_model = models.clone_model(model, input_tensors=None)
+ cloned_model = models.clone_model(model, input_tensors=inputs)
# Compile and build model.
if isinstance(model.optimizer, optimizers.TFOptimizer):
@@ -387,22 +525,29 @@ def _clone_and_build_model(model):
optimizer_config = model.optimizer.get_config()
optimizer = model.optimizer.__class__.from_config(optimizer_config)
+ # TODO(priyag): Is there a cleaner way to do this? The API doc suggests a
+ # single tensor should be OK but it throws an error in that case.
+ if (targets is not None and not isinstance(targets, list) and
+ not isinstance(targets, dict)):
+ targets = [targets]
cloned_model.compile(
optimizer,
model.loss,
metrics=model.metrics,
loss_weights=model.loss_weights,
sample_weight_mode=model.sample_weight_mode,
- weighted_metrics=model.weighted_metrics)
+ weighted_metrics=model.weighted_metrics,
+ target_tensors=targets)
return cloned_model
-def clone_model_on_towers(model, strategy, make_callback_model=False):
+def clone_model_on_towers(
+ model, strategy, make_callback_model=False, inputs=None, targets=None):
"""Create a cloned model on each tower, unless already created."""
if not model._grouped_model:
with strategy.scope():
model._grouped_model = strategy.call_for_each_tower(
- _clone_and_build_model, model)
+ _clone_and_build_model, model, inputs, targets)
if make_callback_model:
model._make_callback_model()