Diffstat (limited to 'tensorflow/contrib/tpu/python/tpu/keras_support.py')
-rw-r--r--  tensorflow/contrib/tpu/python/tpu/keras_support.py | 664 ++++++++++++------
 1 file changed, 573 insertions(+), 91 deletions(-)
diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py
index 7541544382..81798ee423 100644
--- a/tensorflow/contrib/tpu/python/tpu/keras_support.py
+++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py
@@ -45,6 +45,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import abc
import collections
import contextlib
import re
@@ -59,11 +60,15 @@ from tensorflow.contrib.framework.python.framework import experimental
from tensorflow.contrib.tpu.proto import compilation_result_pb2 as tpu_compilation_result
from tensorflow.contrib.tpu.python.ops import tpu_ops
from tensorflow.contrib.tpu.python.tpu import tpu
+from tensorflow.contrib.tpu.python.tpu import tpu_function
from tensorflow.contrib.tpu.python.tpu import tpu_optimizer
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session as tf_session
+from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.estimator import model_fn as model_fn_lib
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import models
@@ -71,7 +76,9 @@ from tensorflow.python.keras import optimizers as keras_optimizers
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.layers import embeddings
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_linalg_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
@@ -99,6 +106,45 @@ class TPUEmbedding(embeddings.Embedding):
return math_ops.tensordot(inputs, self.embeddings, 1)
+class KerasCrossShardOptimizer(keras_optimizers.Optimizer):
+ """An optimizer that averages gradients across TPU shards."""
+
+ def __init__(self, opt, name='KerasCrossShardOptimizer'):
+ """Construct a new cross-shard optimizer.
+
+ Args:
+ opt: An existing `Optimizer` to encapsulate.
+ name: Optional name prefix for the operations created when applying
+ gradients. Defaults to "KerasCrossShardOptimizer".
+    """
+ super(KerasCrossShardOptimizer, self).__init__()
+ self._name = name
+ self._opt = opt
+
+ def get_updates(self, loss, params):
+ logging.info('Get updates: %s', loss)
+ self._opt.get_gradients = self.get_gradients
+ return self._opt.get_updates(loss, params)
+
+ def get_gradients(self, loss, params):
+ num_shards = tpu_function.get_tpu_context().number_of_shards
+ grads = super(KerasCrossShardOptimizer, self).get_gradients(loss, params)
+ return [tpu_ops.cross_replica_sum(grad) / num_shards for grad in grads]
+
+ def set_weights(self, weights):
+    self._opt.set_weights(weights)
+
+ def get_weights(self):
+ return self._opt.get_weights()
+
+ @property
+ def lr(self):
+ return self._opt.lr
+
+
class TPUModelOp(
collections.namedtuple('TPUModelOp', [
'compile_op', 'execute_op', 'infeed_tensors', 'infeed_op', 'outfeed_op'
@@ -113,8 +159,13 @@ def _valid_name(tensor_name):
def _replicated_optimizer(opt):
"""Wrap the optimizer `opt` with CrossShardOptimizer if applicable."""
- return keras_optimizers.TFOptimizer(
- optimizer=tpu_optimizer.CrossShardOptimizer(opt.optimizer))
+ if tpu_function.get_tpu_context().number_of_shards == 1:
+ return opt
+
+ if isinstance(opt, keras_optimizers.TFOptimizer):
+ return tpu_optimizer.CrossShardOptimizer(opt.optimizer)
+ else:
+ return KerasCrossShardOptimizer(opt)
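
For intuition, `cross_replica_sum` adds the per-shard gradients element-wise,
and dividing by the shard count recovers the global-batch average that a
single-device step would have computed. A minimal NumPy sketch of that
arithmetic, assuming two shards (values are illustrative only):

    import numpy as np

    # Hypothetical per-shard gradients for one variable, one per replica.
    local_grads = [np.array([0.2, 0.4]), np.array([0.6, 0.8])]
    num_shards = len(local_grads)

    # cross_replica_sum: element-wise sum across replicas...
    summed = np.add.reduce(local_grads)

    # ...and dividing by the shard count yields the average gradient that
    # every shard applies, matching a single-machine global-batch update.
    averaged = summed / num_shards  # -> array([0.4, 0.6])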
class TPURewriteContext(object):
@@ -154,7 +205,6 @@ class TPURewriteContext(object):
caller_obj = caller_frame.f_locals.get('self')
if (caller_obj is not None and
isinstance(caller_obj, base_layer.Layer) and name is not None):
- logging.info('Intercepted name_scope: %s', caller_obj)
return variable_scope.variable_scope(
name, default_name, values, reuse=variable_scope.AUTO_REUSE)
@@ -163,8 +213,51 @@ class TPURewriteContext(object):
self._default_placeholder = array_ops.placeholder
self._default_name_scope = ops.name_scope
self._default_make_variable = base_layer.make_variable
+ self._default_random_normal = random_ops.random_normal
+ self._default_qr = gen_linalg_ops.qr
array_ops.placeholder = _placeholder
+
+ # Replace random_ops.random_normal with a dummy function because
+    # `random_normal` isn't yet implemented on the TPU. This is safe because
+    # the initial values are overwritten by weights copied from the CPU model.
+ def random_normal(shape,
+ mean=0.0,
+ stddev=1.0,
+ dtype=dtypes.float32,
+ seed=None,
+ name=None):
+ del mean
+ del stddev
+ del seed
+ return array_ops.zeros(shape, dtype=dtype, name=name)
+
+ random_ops.random_normal = random_normal
+
+    # Replace gen_linalg_ops.qr because QR decomposition is not yet
+    # implemented on the TPU.
+ # TODO(saeta): Remove qr override once we confirm the qr implementation is
+ # ok.
+ # pylint: disable=redefined-builtin
+ def qr(input, full_matrices=False, name=None):
+ """Dummy implementation of qr decomposition."""
+ del full_matrices # TODO(saeta): Properly handle the full matrix case.
+ input_shape = input.shape
+ if len(input_shape) < 2:
+ raise ValueError('Invalid shape passed to qr: %s' % input_shape)
+ p = min(input_shape[-1], input_shape[-2])
+ if len(input_shape) == 2:
+ q = array_ops.zeros((p, p), name=name)
+ r = array_ops.zeros(input_shape, name=name)
+ return (r, q)
+ elif len(input_shape) == 3:
+ n = input_shape[0]
+ q = array_ops.zeros((n, p, p), name=name)
+ r = array_ops.zeros(input_shape, name=name)
+ return (r, q)
+ else:
+ raise ValueError('Invalid shape passed to qr: %s' % input_shape)
+ gen_linalg_ops.qr = qr
+
ops.name_scope = _name_scope
base_layer.make_variable = variable_scope.get_variable
logging.info('Overriding default placeholder.')
@@ -174,6 +267,334 @@ class TPURewriteContext(object):
array_ops.placeholder = self._default_placeholder
ops.name_scope = self._default_name_scope
base_layer.make_variable = self._default_make_variable
+ random_ops.random_normal = self._default_random_normal
+ gen_linalg_ops.qr = self._default_qr
+
+
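The enter/exit pair above is a save-patch-restore monkey-patching pattern:
module-level functions are swapped for TPU-safe stand-ins on entry and the
originals are reinstated on exit. A self-contained sketch of the same pattern
(the `patched` helper is hypothetical, not part of this file):

    import contextlib
    import math

    @contextlib.contextmanager
    def patched(module, name, replacement):
      # Save the original attribute, install the replacement, and always
      # restore the original, mirroring what TPURewriteContext does for
      # array_ops.placeholder, random_ops.random_normal and gen_linalg_ops.qr.
      original = getattr(module, name)
      setattr(module, name, replacement)
      try:
        yield
      finally:
        setattr(module, name, original)

    with patched(math, 'sqrt', lambda x: 0.0):
      assert math.sqrt(4.0) == 0.0   # replacement active inside the block
    assert math.sqrt(4.0) == 2.0     # original restored afterwards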
+class SizedInfeed(collections.namedtuple('SizedInfeed',
+ ['sharded_infeed_tensors',
+ 'infeed_ops'])):
+ """Represents an instantiation of the infeed ops for a concrete input shape.
+
+ sharded_infeed_tensors: A data structure of Tensors used to represent the
+ placeholder tensors that must be fed when using feed_dicts.
+
+  infeed_ops: The list of ops run to drive infeed for a single step.
+ """
+ pass
+
+
+class TPUInfeedInstance(object):
+ """TPUInfeedInstance represents the logic to manage feeding in a single step.
+
+ See the comments on the `TPUInfeedManager` for a description for how infeed
+ is managed.
+ """
+
+ @abc.abstractmethod
+ def make_input_specs(self, input_tensors):
+ """Constructs the infeed_specs for the given Infeed instance.
+
+ Args:
+ input_tensors: The inputs to the model.
+
+ Returns:
+      A list of `TensorSpec` objects describing the model's input tensors.
+ """
+ pass
+
+  @abc.abstractmethod
+  def make_feed_dict(self, tpu_model_op):
+ """Constructs a feed_dict for this instance, given the tpu_model_op.
+
+ Args:
+ tpu_model_op: A `TPUModelOp` representing the TPU Model for this
+ instance's input spec.
+
+ Returns:
+ A dictionary to use as the feed_dict of a `session.run` call.
+ """
+ pass
+
+
+class TPUInfeedManager(object):
+ """TPUInfeedManager manages the data infeeding of data to a TPU computation.
+
+ Because there are multiple data sources (e.g. in-memory NumPy arrays,
+ `tf.data.Dataset`s), we abstract the different logic behind a single
+ interface: the `TPUInfeedManager`.
+
+ (1) A `TPUFunction` is called with a set of inputs. Based on the inputs,
+ `TPUFunction` retrieves the corresponding `TPUInfeedManager` (or constructs a
+ new one if required).
+
+ (2) The `TPUFunction` calls `make_infeed_instance` on the `TPUInfeedManager`
+ which returns a `TPUInfeedInstance`.
+
+  (3) The `TPUFunction` checks the shape cache for a pre-compiled instance of
+ the model based on the returned `input_specs` from `TPUInfeedInstance`.
+
+ (4) [Optional.] If the model has not already been instantiated for the given
+ input spec, the `TPUFunction` compiles the model for the input spec (using the
+ `TPUInfeedManager`).
+
+ (5) The `TPUInfeedInstance` constructs the session.run's feed_dict given the
+ compiled model instance corresponding to its shape.
+ """
+
+ @abc.abstractmethod
+ def make_infeed_instance(self, inputs):
+ """Given a single step's input, construct a `TPUInfeedInstance`.
+
+ Args:
+ inputs: The inputs to a given step.
+
+ Returns:
+ A subclass of `TPUInfeedInstance`.
+ """
+ pass
+
+ @abc.abstractmethod
+ def build_infeed_from_input_specs(self, input_specs, execution_mode):
+ """For a given input specification (size, type), construct the infeed ops.
+
+ This is called only once for a given input specification and builds the
+ graph ops. It does not have a pointer to the actual infeed data.
+
+ Args:
+ input_specs: TODO(saeta): Document me!
+ execution_mode: TODO(saeta): Document me!
+
+ Returns:
+ A `SizedInfeed` instance.
+ """
+ pass
+
+
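A new data source plugs in by implementing the two abstract hooks above. The
skeleton below, with hypothetical names, shows the shape of the contract that
`TPUNumpyInfeedManager` and `TPUDatasetInfeedManager` fulfil:

    class MyInfeedManager(TPUInfeedManager):
      """Hypothetical infeed manager for some new input source."""

      def make_infeed_instance(self, inputs):
        # Step (2): wrap this step's inputs in a TPUInfeedInstance that can
        # later report input_specs and build a per-step feed_dict.
        return MyInfeedInstance(inputs)  # hypothetical TPUInfeedInstance

      def build_infeed_from_input_specs(self, input_specs, execution_mode):
        # Step (4): called once per distinct input signature to build the
        # enqueue ops; it sees placeholders/specs, never the actual data.
        return SizedInfeed(sharded_infeed_tensors=[], infeed_ops=[])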
+class TPUNumpyInfeedManager(TPUInfeedManager):
+ """TPU Infeed manager for Numpy inputs."""
+
+ class NumpyInfeedInstance(TPUInfeedInstance):
+ """Infeed instance for Numpy inputs."""
+
+ def __init__(self, sharded_inputs):
+ self._sharded_inputs = sharded_inputs
+
+ def make_input_specs(self, input_tensors):
+ # Compute an input specification (used to generate infeed enqueue and
+ # dequeue operations). We use the shape from our input array and the
+ # dtype from our model. A user may pass in a float64 for a float32
+ # input: for model compatibility we still must generate a float32 infeed.
+ input_specs = []
+ # We use the shape and dtype from the first shard to compute the input
+ # metadata (`input_specs`); all replicas have the same type and shape.
+ for tensor, ary in zip(input_tensors, self._sharded_inputs[0]):
+ input_specs.append(
+ tensor_spec.TensorSpec(ary.shape, tensor.dtype,
+ _valid_name(tensor.name)))
+
+ return input_specs
+
+ def make_feed_dict(self, tpu_model_op):
+ infeed_dict = {}
+ for infeed_tensors, inputs in zip(tpu_model_op.infeed_tensors,
+ self._sharded_inputs):
+ for tensor, value in zip(infeed_tensors, inputs):
+ infeed_dict[tensor] = value
+ return infeed_dict
+
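The spec takes its shape from the fed array but its dtype from the model's
input tensor, so a float64 ndarray against a float32 model still produces a
float32 infeed. A small sketch of that rule using the same `tensor_spec`
module (the name 'dense_input' is illustrative):

    import numpy as np
    from tensorflow.python.framework import dtypes
    from tensorflow.python.framework import tensor_spec

    ary = np.zeros((32, 4), dtype=np.float64)  # user-supplied float64 batch
    model_dtype = dtypes.float32               # dtype of the model input
    spec = tensor_spec.TensorSpec(ary.shape, model_dtype, 'dense_input')
    # spec.shape == (32, 4); spec.dtype == tf.float32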
+ def __init__(self, distribution_strategy):
+ self._strategy = distribution_strategy
+
+ def _split_tensors(self, inputs):
+ """Split input data across shards.
+
+ Each input is sliced along the batch axis.
+
+ Args:
+ inputs: List of Numpy arrays to run on the TPU.
+
+ Returns:
+ List of lists containing the input to feed to each TPU shard.
+ """
+ if self._strategy.num_towers == 1:
+ return [inputs]
+
+ batch_size = inputs[0].shape[0]
+ assert batch_size % self._strategy.num_towers == 0, (
+ 'batch_size must be divisible by strategy.num_towers (%s vs %s)' %
+ (batch_size, self._strategy.num_towers))
+ shard_size = batch_size // self._strategy.num_towers
+ input_list = []
+ for index in range(self._strategy.num_towers):
+ shard_inputs = [
+ x[index * shard_size:(index + 1) * shard_size] for x in inputs
+ ]
+ input_list.append(shard_inputs)
+ return input_list
+
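Concretely, a global batch of 8 split across 4 towers yields 4 shards of 2
rows each, sliced along the batch axis. A runnable NumPy sketch of the slicing
performed by `_split_tensors`:

    import numpy as np

    inputs = [np.arange(8).reshape(8, 1)]  # one input, batch size 8
    num_towers = 4
    shard_size = inputs[0].shape[0] // num_towers
    input_list = [[x[i * shard_size:(i + 1) * shard_size] for x in inputs]
                  for i in range(num_towers)]
    assert [shard[0].shape[0] for shard in input_list] == [2, 2, 2, 2]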
+ def make_infeed_instance(self, inputs):
+ sharded_inputs = self._split_tensors(inputs)
+ return self.NumpyInfeedInstance(sharded_inputs)
+
+ def build_infeed_from_input_specs(self, input_specs, execution_mode):
+ infeed_op = []
+ shard_infeed_tensors = []
+
+ for shard_id in range(self._strategy.num_towers):
+ with ops.device('/device:CPU:0'):
+ infeed_tensors = []
+ with ops.device('/device:TPU:%d' % shard_id):
+ for spec in input_specs:
+ # Construct placeholders for each of the inputs.
+ infeed_tensors.append(
+ array_ops.placeholder(
+ dtype=spec.dtype,
+ shape=spec.shape,
+ name='infeed-enqueue-%s-%d' % (spec.name, shard_id)))
+ shard_infeed_tensors.append(infeed_tensors)
+
+ infeed_op.append(
+ tpu_ops.infeed_enqueue_tuple(
+ infeed_tensors, [spec.shape for spec in input_specs],
+ name='infeed-enqueue-%s-%d' % (execution_mode, shard_id),
+ device_ordinal=shard_id))
+ return SizedInfeed(infeed_ops=infeed_op,
+ sharded_infeed_tensors=shard_infeed_tensors)
+
+
+class TPUDatasetInfeedManager(TPUInfeedManager):
+ """Manages infeed for a `tf.data.Dataset` into a TPU computation.
+ """
+
+ class DatasetInfeedInstance(TPUInfeedInstance):
+ """An instance of the TPU infeed."""
+
+ def __init__(self, input_specs):
+ self._input_specs = input_specs
+
+ def make_input_specs(self, input_tensors):
+ # TODO(saeta): Do error checking here!
+ return self._input_specs
+
+ def make_feed_dict(self, tpu_model_op):
+ # TODO(saeta): Verify tpu_model_op is as expected!
+ return {}
+
+ def __init__(self, dataset, distribution_strategy, tpu_session):
+ """Constructs a TPUDatasetInfeedManager.
+
+ Must be called within a `KerasTPUModel.tpu_session` context!
+
+ Args:
+ dataset: A `tf.data.Dataset` to infeed.
+ distribution_strategy: The `TPUDistributionStrategy` used to configure the
+ Keras TPU model.
+ tpu_session: The `tf.Session` object used for running the TPU model.
+ """
+ self._verify_dataset_shape(dataset)
+ self._dataset = dataset
+ self._strategy = distribution_strategy
+ dummy_x_shape = dataset.output_shapes[0].as_list()
+ dummy_x_shape[0] *= distribution_strategy.num_towers
+ dummy_y_shape = dataset.output_shapes[1].as_list()
+ dummy_y_shape[0] *= distribution_strategy.num_towers
+ self._iterator = dataset.make_initializable_iterator()
+ tpu_session.run(self._iterator.initializer)
+
+ self._get_next_ops = []
+ ctrl_deps = []
+ for i in range(distribution_strategy.num_towers):
+      with ops.control_dependencies(ctrl_deps):  # Ensure deterministic order.
+ # TODO(saeta): Ensure correct placement!
+ get_next_op = self._iterator.get_next()
+ self._get_next_ops.append(get_next_op)
+ ctrl_deps.extend(get_next_op)
+
+ # Use dummy numpy inputs for the rest of Keras' shape checking. We
+ # intercept them when building the model.
+ self._dummy_x = np.zeros(dummy_x_shape,
+ dtype=dataset.output_types[0].as_numpy_dtype)
+ self._dummy_y = np.zeros(dummy_y_shape,
+ dtype=dataset.output_types[1].as_numpy_dtype)
+
+ input_specs = []
+ if isinstance(self._iterator.output_shapes, tuple):
+ assert isinstance(self._iterator.output_types, tuple)
+ assert len(self._iterator.output_shapes) == len(
+ self._iterator.output_types)
+ for i in range(len(self._iterator.output_shapes)):
+ spec = tensor_spec.TensorSpec(self._iterator.output_shapes[i],
+ self._iterator.output_types[i])
+ input_specs.append(spec)
+ elif isinstance(self._iterator.output_shapes, tensor_shape.TensorShape):
+ spec = tensor_spec.TensorSpec(self._iterator.output_shapes,
+ self._iterator.output_types)
+ input_specs.append(spec)
+
+ self._infeed_instance = self.DatasetInfeedInstance(input_specs)
+
+ def _verify_dataset_shape(self, dataset):
+ """Verifies a dataset is of an appropriate shape for TPUs."""
+ if not isinstance(dataset, dataset_ops.Dataset):
+ raise ValueError('The function passed as the `x` parameter did not '
+ 'return a `tf.data.Dataset`.')
+ if not isinstance(dataset.output_classes, tuple):
+ raise ValueError('The dataset must return a tuple of tf.Tensors, '
+ 'instead it returns: %s' % dataset.output_classes)
+ if len(dataset.output_classes) != 2:
+ raise ValueError(
+ 'The dataset must return a 2-element tuple, got '
+ '%s output classes instead.' % (dataset.output_classes,))
+ for i, cls in enumerate(dataset.output_classes):
+ if cls != ops.Tensor:
+ raise ValueError('The dataset returned a non-Tensor type (%s) at '
+ 'index %d.' % (cls, i))
+ for i, shape in enumerate(dataset.output_shapes):
+ if not shape:
+ raise ValueError('The dataset returns a scalar tensor in '
+ 'tuple index %d. Did you forget to batch? '
+ '(Output shapes: %s).' % (i,
+ dataset.output_shapes))
+ for j, dim in enumerate(shape):
+ if dim.value is None:
+ if j == 0:
+ hint = (' Hint: did you use `ds.batch(BATCH_SIZE, '
+ 'drop_remainder=True)`?')
+ else:
+ hint = ''
+ raise ValueError(
+ 'The Keras-TPU integration for `tf.data` '
+ 'currently requires static shapes. The provided '
+ 'dataset only has a partially defined shape. '
+ '(Dimension %d of output tensor %d is not statically known '
+              'for output shapes: %s.%s)' % (j, i, dataset.output_shapes, hint))
+
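As the hint above suggests, the checks pass for a dataset that yields a
(features, labels) tuple of tensors with fully static shapes;
`drop_remainder=True` is what pins the batch dimension. A hedged example of a
conforming dataset function (names and sizes are illustrative, TF 1.x
`tf.data` API assumed):

    import numpy as np
    import tensorflow as tf

    def make_dataset():
      x = np.random.rand(128, 10).astype(np.float32)
      y = np.random.rand(128, 1).astype(np.float32)
      # repeat() so training never exhausts the iterator; drop_remainder=True
      # makes every batch exactly 32 rows, giving a static batch dimension.
      return (tf.data.Dataset.from_tensor_slices((x, y))
              .repeat()
              .batch(32, drop_remainder=True))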
+ @property
+ def dummy_x(self):
+ return self._dummy_x
+
+ @property
+ def dummy_y(self):
+ return self._dummy_y
+
+ def make_infeed_instance(self, inputs):
+ # TODO(saeta): Verify inputs is as expected.
+ return self._infeed_instance
+
+ def build_infeed_from_input_specs(self, input_specs, execution_mode):
+ shard_infeed_tensors = self._get_next_ops
+ assert len(shard_infeed_tensors) == self._strategy.num_towers
+ infeed_ops = []
+ for shard_id in range(self._strategy.num_towers):
+ with ops.device('/device:CPU:0'):
+ infeed_ops.append(
+ tpu_ops.infeed_enqueue_tuple(
+ shard_infeed_tensors[shard_id],
+ [spec.shape for spec in input_specs],
+ name='infeed-enqueue-%s-%d' % (execution_mode, shard_id),
+ device_ordinal=shard_id))
+ return SizedInfeed(infeed_ops=infeed_ops,
+ sharded_infeed_tensors=shard_infeed_tensors)
class TPUFunction(object):
@@ -195,7 +616,13 @@ class TPUFunction(object):
self._compilation_cache = {}
self._cloned_model = None
- def _specialize_model(self, input_specs):
+ # Copy optimizer configuration. This is done prior to `_specialize_model`
+ # as the configuration may require evaluating variables in the CPU session.
+ self._optimizer_config = None
+ if not isinstance(self.model.optimizer, keras_optimizers.TFOptimizer):
+ self._optimizer_config = self.model.optimizer.get_config()
+
+ def _specialize_model(self, input_specs, infeed_manager):
"""Specialize `self.model` (a Keras model) for the given input shapes."""
# Re-create our input and output layers inside our subgraph. They will be
# attached to the true computation when we clone our model in `tpu_fn`.
@@ -221,8 +648,8 @@ class TPUFunction(object):
name='infeed-%s' % self.execution_mode)
assert len(infeed_tensors) == len(infeed_layers), (
- 'Infeed inputs did not match model: %s vs %s', (infeed_layers,
- infeed_tensors))
+ 'Infeed inputs did not match model: %s vs %s' % (infeed_layers,
+ infeed_tensors))
tpu_targets = []
tpu_input_map = {}
@@ -236,11 +663,23 @@ class TPUFunction(object):
# Clone our CPU model, running within the TPU device context.
with TPURewriteContext(tpu_input_map):
- self._cloned_model = models.clone_model(self.model)
+ # TODO(power): Replicate variables.
+ with ops.device('/device:TPU:0'):
+ self._cloned_model = models.clone_model(self.model)
+
+ # Create a copy of the optimizer for this graph.
+ if isinstance(self.model.optimizer, keras_optimizers.TFOptimizer):
+ cloned_optimizer = keras_optimizers.TFOptimizer(
+ self.model.optimizer.optimizer)
+ else:
+ logging.info('Cloning %s %s', self.model.optimizer.__class__.__name__,
+ self._optimizer_config)
+ cloned_optimizer = self.model.optimizer.__class__.from_config(
+ self._optimizer_config)
if is_training or is_test:
self._cloned_model.compile(
- optimizer=_replicated_optimizer(self.model.optimizer),
+ optimizer=_replicated_optimizer(cloned_optimizer),
loss=self.model.loss,
loss_weights=self.model.loss_weights,
metrics=self.model.metrics,
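
The clone path relies on the standard Keras `get_config()`/`from_config()`
round-trip: the config is captured up front (it may read variables in the CPU
session) and replayed inside the TPU graph. A minimal sketch, assuming a stock
Keras SGD optimizer:

    from tensorflow.python.keras import optimizers as keras_optimizers

    opt = keras_optimizers.SGD(lr=0.01, momentum=0.9)
    config = opt.get_config()                   # captured on the CPU side
    clone = opt.__class__.from_config(config)   # rebuilt in the TPU graph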
@@ -299,37 +738,24 @@ class TPUFunction(object):
# Generate CPU side operations to enqueue features/labels and dequeue
# outputs from the model call.
- infeed_op = []
+ sized_infeed = infeed_manager.build_infeed_from_input_specs(
+ input_specs, self.execution_mode)
+ # Build output ops.
outfeed_op = []
- shard_infeed_tensors = []
-
for shard_id in range(self._strategy.num_towers):
- with ops.device('/device:TPU:%d' % shard_id):
- infeed_tensors = []
- for spec in input_specs:
- infeed_tensors.append(
- array_ops.placeholder(
- dtype=spec.dtype,
- shape=spec.shape,
- name='infeed-enqueue-%s-%d' % (spec.name, shard_id)))
- shard_infeed_tensors.append(infeed_tensors)
-
- infeed_op.append(
- tpu_ops.infeed_enqueue_tuple(
- infeed_tensors, [spec.shape for spec in input_specs],
- name='infeed-enqueue-%s-%d' % (self.execution_mode, shard_id)))
-
+ with ops.device('/device:CPU:0'):
outfeed_op.extend(
tpu_ops.outfeed_dequeue_tuple(
dtypes=[spec.dtype for spec in self._outfeed_spec],
shapes=[spec.shape for spec in self._outfeed_spec],
- name='outfeed-dequeue-%s-%d' % (self.execution_mode, shard_id)))
+ name='outfeed-dequeue-%s-%d' % (self.execution_mode, shard_id),
+ device_ordinal=shard_id))
return TPUModelOp(
compile_op,
execute_op,
- infeed_tensors=shard_infeed_tensors,
- infeed_op=infeed_op,
+ infeed_tensors=sized_infeed.sharded_infeed_tensors,
+ infeed_op=sized_infeed.infeed_ops,
outfeed_op=outfeed_op)
def _test_model_compiles(self, tpu_model_ops):
@@ -348,37 +774,17 @@ class TPUFunction(object):
logging.info('Finished compiling. Time elapsed: %s secs',
end_time - start_time)
- def _split_tensors(self, inputs):
- """Split input data across shards.
-
- Each input is sliced along the batch axis.
-
- Args:
- inputs: List of Numpy arrays to run on the TPU.
-
- Returns:
- List of lists containing the input to feed to each TPU shard.
- """
- if self._strategy.num_towers == 1:
- return [inputs]
-
- batch_size = inputs[0].shape[0]
- assert batch_size % self._strategy.num_towers == 0, (
- 'batch_size must be divisible by strategy.num_towers (%s vs %s)' %
- (batch_size, self._strategy.num_towers)
- )
- shard_size = batch_size // self._strategy.num_towers
- input_list = []
- for index in range(self._strategy.num_towers):
- shard_inputs = [
- x[index * shard_size:(index + 1) * shard_size] for x in inputs
- ]
- input_list.append(shard_inputs)
- return input_list
-
def __call__(self, inputs):
assert isinstance(inputs, list)
+ infeed_manager = None
+ for x, mgr in self.model._numpy_to_infeed_manager_list:
+ if inputs[0] is x:
+ infeed_manager = mgr
+ break
+ if infeed_manager is None:
+ infeed_manager = TPUNumpyInfeedManager(self.model._strategy)
+
# Strip sample weight from inputs
if (self.execution_mode == model_fn_lib.ModeKeys.TRAIN or
self.execution_mode == model_fn_lib.ModeKeys.EVAL):
@@ -387,21 +793,9 @@ class TPUFunction(object):
else:
input_tensors = self.model._feed_inputs
- shard_inputs = self._split_tensors(inputs)
+ infeed_instance = infeed_manager.make_infeed_instance(inputs)
del inputs # To avoid accident usage.
-
- # Compute an input specification (used to generate infeed enqueue and
- # dequeue operations). We use the shape from our input array and the
- # dtype from our model. A user may pass in a float64 for a float32
- # input: for model compatibility we still must generate a float32 infeed.
- input_specs = []
-
- # We use the shape and dtype from the first shard to compute the input
- # metadata (`input_specs`); all replicas have the same type and shape.
- for tensor, ary in zip(input_tensors, shard_inputs[0]):
- input_specs.append(
- tensor_spec.TensorSpec(ary.shape, tensor.dtype,
- _valid_name(tensor.name)))
+ input_specs = infeed_instance.make_input_specs(input_tensors)
# XLA requires every operation in the graph has a fixed shape. To
# handle varying batch sizes we recompile a new sub-graph for each
@@ -412,7 +806,8 @@ class TPUFunction(object):
with self.model.tpu_session():
logging.info('New input shapes; (re-)compiling: mode=%s, %s',
self.execution_mode, input_specs)
- new_tpu_model_ops = self._specialize_model(input_specs)
+ new_tpu_model_ops = self._specialize_model(input_specs,
+ infeed_manager)
self._compilation_cache[shape_key] = new_tpu_model_ops
self._test_model_compiles(new_tpu_model_ops)
@@ -420,11 +815,7 @@ class TPUFunction(object):
self.model._initialize_weights(self._cloned_model)
tpu_model_ops = self._compilation_cache[shape_key]
- infeed_dict = {}
- for infeed_tensors, inputs in zip(tpu_model_ops.infeed_tensors,
- shard_inputs):
- for tensor, value in zip(infeed_tensors, inputs):
- infeed_dict[tensor] = value
+ infeed_dict = infeed_instance.make_feed_dict(tpu_model_ops)
with self.model.tpu_session() as session:
_, _, outfeed_outputs = session.run([
@@ -438,9 +829,8 @@ class TPUFunction(object):
outputs_per_replica = len(self._outfeed_spec)
for i in range(self._strategy.num_towers):
- output_group = outfeed_outputs[
- i * outputs_per_replica:(i+1) * outputs_per_replica
- ]
+ output_group = outfeed_outputs[i * outputs_per_replica:(i + 1) *
+ outputs_per_replica]
for j in range(outputs_per_replica):
outputs[j].append(output_group[j])
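
The outfeed arrives as one flat list of num_towers * outputs_per_replica
tensors; the loop above regroups it by output index. A plain-Python sketch of
the demultiplexing, with placeholder strings standing in for tensors:

    flat = ['t0_o0', 't0_o1', 't0_o2', 't1_o0', 't1_o1', 't1_o2']
    num_towers, outputs_per_replica = 2, 3
    outputs = [[] for _ in range(outputs_per_replica)]
    for i in range(num_towers):
      group = flat[i * outputs_per_replica:(i + 1) * outputs_per_replica]
      for j in range(outputs_per_replica):
        outputs[j].append(group[j])
    # outputs == [['t0_o0', 't1_o0'], ['t0_o1', 't1_o1'], ['t0_o2', 't1_o2']]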
@@ -459,6 +849,11 @@ class KerasTPUModel(models.Model):
name=cpu_model.name,
)
+ # Create a mapping from numpy arrays to infeed managers.
+ # Note: uses a list of tuples instead of a map because numpy arrays are
+ # not hashable.
+ self._numpy_to_infeed_manager_list = []
+
self.predict_function = None
self.test_function = None
self.train_function = None
@@ -470,14 +865,16 @@ class KerasTPUModel(models.Model):
self._tpu_weights_initialized = False
self._graph = ops.Graph()
- cluster_resolver = tpu_cluster_resolver.TPUClusterResolver(
+ self._cluster_resolver = tpu_cluster_resolver.TPUClusterResolver(
tpu_name_or_address)
- cluster_spec = cluster_resolver.cluster_spec()
+ master = self._cluster_resolver.master()
+ cluster_spec = self._cluster_resolver.cluster_spec()
self._session = tf_session.Session(
graph=self._graph,
- target=cluster_resolver.master(),
+ target=master,
config=config_pb2.ConfigProto(isolate_session_state=True))
+ # TODO(saeta): Confirm the lines below work in ClusterSpec propagation env.
if cluster_spec:
self._session.cluster_def.CopyFrom(cluster_spec.as_cluster_def())
@@ -529,10 +926,91 @@ class KerasTPUModel(models.Model):
sample_weight_mode, weighted_metrics,
target_tensors, **kwargs)
- # Keras optimizers are not compatible with TPU rewrite
- if not isinstance(self.optimizer, keras_optimizers.TFOptimizer):
+ def fit(self,
+ x=None,
+ y=None,
+ batch_size=None,
+ epochs=1,
+ verbose=1,
+ callbacks=None,
+ validation_split=0.,
+ validation_data=None,
+ shuffle=True,
+ class_weight=None,
+ sample_weight=None,
+ initial_epoch=0,
+ steps_per_epoch=None,
+ validation_steps=None,
+ **kwargs):
+ assert not self._numpy_to_infeed_manager_list # Ensure empty.
+
+ infeed_managers = [] # Managers to clean up at the end of the fit call.
+ if isinstance(x, dataset_ops.Dataset):
+ # TODO(b/111413240): Support taking a tf.data.Dataset directly.
+ raise ValueError(
+ 'Taking a Dataset directly is not yet supported. Please '
+ 'wrap your dataset construction code in a function and '
+ 'pass that to fit instead. For examples, see: '
+ 'https://github.com/tensorflow/tpu/tree/master/models/experimental'
+ '/keras')
+ if callable(x):
+ with self.tpu_session() as sess:
+ dataset = x()
+ if steps_per_epoch is None:
+ raise ValueError('When using tf.data as input to a model, you '
+ 'should specify the steps_per_epoch argument.')
+ if y is not None:
+ raise ValueError('When using tf.data as input to a model, y must be '
+ 'None')
+ infeed_manager = TPUDatasetInfeedManager(dataset, self._strategy, sess)
+ # Use dummy numpy inputs for the rest of Keras' shape checking. We
+ # intercept them when building the model.
+ x = infeed_manager.dummy_x
+ y = infeed_manager.dummy_y
+ infeed_managers.append((x, infeed_manager))
+
+ if isinstance(validation_data, dataset_ops.Dataset):
+ # TODO(b/111413240): Support taking a tf.data.Dataset directly.
raise ValueError(
- 'Optimizer must be a TFOptimizer, got: %s' % self.optimizer)
+ 'Taking a Dataset directly is not yet supported. Please '
+ 'wrap your dataset construction code in a function and '
+ 'pass that to fit instead. For examples, see: '
+ 'https://github.com/tensorflow/tpu/tree/master/models/experimental'
+ '/keras')
+ if callable(validation_data):
+ with self.tpu_session() as sess:
+ dataset = validation_data()
+ if validation_steps is None:
+ raise ValueError('When using tf.data as validation for a model, you '
+ 'should specify the validation_steps argument.')
+ infeed_manager = TPUDatasetInfeedManager(dataset, self._strategy, sess)
+ # Use dummy numpy inputs for the rest of Keras' shape checking. We
+ # intercept them when building the model.
+ val_x = infeed_manager.dummy_x
+ val_y = infeed_manager.dummy_y
+ infeed_managers.append((val_x, infeed_manager))
+ validation_data = (val_x, val_y)
+
+ self._numpy_to_infeed_manager_list = infeed_managers
+ try:
+ return super(KerasTPUModel, self).fit(
+ x,
+ y,
+ batch_size,
+ epochs,
+ verbose,
+ callbacks,
+ validation_split,
+ validation_data,
+ shuffle,
+ class_weight,
+ sample_weight,
+ initial_epoch,
+ steps_per_epoch,
+ validation_steps,
+ **kwargs)
+ finally:
+ self._numpy_to_infeed_manager_list = []
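
Putting the above together, training from `tf.data` means passing a
zero-argument function (not the dataset itself) plus `steps_per_epoch`,
exactly as the error messages direct. A hedged usage sketch; `cpu_model` and
`make_dataset` are assumed to exist already:

    model = tpu_model(cpu_model)       # wrap the Keras CPU model for TPU
    model.compile(optimizer='sgd', loss='mse')
    model.fit(make_dataset,            # callable returning a tf.data.Dataset
              steps_per_epoch=100,
              epochs=2)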
def _make_train_function(self):
if not self.train_function:
@@ -615,10 +1093,10 @@ class KerasTPUModel(models.Model):
K.set_session(default_session)
def shutdown(self):
- logging.info('Shutting down TPU session.')
- with self.tpu_session() as session:
- session.run(tpu.shutdown_system())
-
+ # TODO(b/111364423): Actually shut down the system.
+ logging.info('Skipping shutting down TPU system.')
+ # with self.tpu_session() as session:
+ # session.run(tpu.shutdown_system())
self._session.close()
@@ -652,7 +1130,7 @@ Output shape: %(output_shape)s
'layer': layer,
'input_shape': layer.input_shape,
'output_shape': layer.output_shape
- })
+ })
@experimental
@@ -687,6 +1165,10 @@ def tpu_model(model, tpu_name_or_address=None, strategy=None):
Returns:
A new `KerasTPUModel` instance.
"""
+ # Force initialization of the CPU model.
+ model.get_weights()
+ model.reset_states()
+
_validate_shapes(model)
# TODO(xiejw): Validate TPU model. TPUModel only?
# TODO(xiejw): Validate replicas. Full or 1. Shall we allow subset?