aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tpu
diff options
context:
space:
mode:
authorGravatar Jianwei Xie <xiejw@google.com>2018-09-25 13:36:56 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-25 13:41:04 -0700
commit976fb3105312bb17accebcbca2ebae906bcf99fb (patch)
treec74965a3429937982c21ebb42e4b7cedd16a52ff /tensorflow/contrib/tpu
parente51963ead78b3c1c4ab0077a3e43fb9c0f6ab374 (diff)
Add outputs and target cross replica concat, so each core sees the same output and targets and produces the same loss and metrics.
PiperOrigin-RevId: 214494877
Diffstat (limited to 'tensorflow/contrib/tpu')
-rw-r--r--tensorflow/contrib/tpu/python/tpu/keras_support.py155
1 files changed, 146 insertions, 9 deletions
diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py
index 93ae68d254..03e06b8142 100644
--- a/tensorflow/contrib/tpu/python/tpu/keras_support.py
+++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py
@@ -68,6 +68,7 @@ from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.eager import context
from tensorflow.python.estimator import model_fn as model_fn_lib
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
@@ -229,6 +230,39 @@ class TPUEmbedding(embeddings.Embedding):
return math_ops.tensordot(inputs, self.embeddings, 1)
def _cross_replica_concat(tensor, core_id, num_cores, name):
  """Concatenate `tensor` across cores.

  Each core places its local `tensor` into its own slot of an otherwise
  zero-valued buffer whose leading dimension is `num_cores * batch_size`.
  Summing these buffers across replicas then gives every core the same
  concatenated result.

  Args:
    tensor: The tensor to be concatenated. Must be int32 or float32.
    core_id: Tensor indicating the current TPU core.
    num_cores: Python int. The total number of TPU cores in the system.
    name: The string name to print for debugging.

  Returns:
    The same concatenated Tensor on each core, with the dtype of `tensor`.

  Raises:
    TypeError: If the dtype of `tensor` is neither float32 nor int32.
  """

  input_dtype = tensor.dtype
  if input_dtype not in [dtypes.float32, dtypes.int32]:
    raise TypeError('For model replication, only (float32 and int32) is '
                    'supported for model outputs and targets. Got {} for '
                    '{}.'.format(input_dtype, name))

  batch_size = tensor.shape[0]

  # One-hot selector over cores: 1.0 at this core's index, 0.0 elsewhere.
  mask = math_ops.to_float(math_ops.equal(range(num_cores), core_id))
  # Reshape to [num_cores, 1, ..., 1] so it broadcasts against `tensor`.
  mask = array_ops.reshape(mask, [num_cores] + [1] * tensor.shape.ndims)

  # Broadcasting gives one copy of `tensor` per core slot, zeroed everywhere
  # except the slot that belongs to this core.
  # NOTE(review): the sum runs in float, so int32 inputs make a round trip
  # through float; values float32 cannot represent exactly would presumably
  # be altered -- confirm this is acceptable for the targets in use.
  result = mask * math_ops.to_float(tensor)

  # Merge the leading (core, batch) dimensions into a single dimension.
  local_tensor_with_holes = array_ops.reshape(result,
                                              [-1] + result.shape.as_list()[2:])
  concat_tensor = tpu_ops.cross_replica_sum(local_tensor_with_holes)
  concat_tensor.set_shape((num_cores * batch_size,) + tuple(tensor.shape[1:]))

  # Compare dtypes (not the Tensor object itself) to decide whether the float
  # result has to be cast back to the caller's dtype.
  if concat_tensor.dtype != input_dtype:
    concat_tensor = math_ops.cast(concat_tensor, input_dtype)
  return concat_tensor
+
+
class KerasCrossShardOptimizer(keras_optimizers.Optimizer):
"""An optimizer that averages gradients across TPU shards."""
@@ -617,7 +651,7 @@ class TPUDatasetInfeedManager(TPUInfeedManager):
return {}
# pylint: disable=redefined-outer-name
- def __init__(self, dataset, tpu_assignment, tpu_session):
+ def __init__(self, dataset, tpu_assignment, tpu_session, mode):
"""Constructs a TPUDatasetInfeedManager.
Must be called within a `KerasTPUModel.tpu_session` context!
@@ -627,8 +661,10 @@ class TPUDatasetInfeedManager(TPUInfeedManager):
tpu_assignment: The `TPUAssignment` used to configure the
Keras TPU model.
tpu_session: The `tf.Session` object used for running the TPU model.
+ mode: ModeKeys enum.
"""
self._verify_dataset_shape(dataset)
+
self._dataset = dataset
self._tpu_assignment = tpu_assignment
dummy_x_shape = dataset.output_shapes[0].as_list()
@@ -668,6 +704,10 @@ class TPUDatasetInfeedManager(TPUInfeedManager):
self._iterator.output_types)
input_specs.append(spec)
+ # Pre-process the inputs and get_next_ops before caching.
+ input_specs, self._get_next_ops = (
+ _inject_tpu_inputs_for_dataset(
+ tpu_assignment, mode, input_specs, self._get_next_ops))
self._infeed_instance = self.DatasetInfeedInstance(input_specs)
def _verify_dataset_shape(self, dataset):
@@ -735,6 +775,71 @@ class TPUDatasetInfeedManager(TPUInfeedManager):
sharded_infeed_tensors=shard_infeed_tensors)
def _inject_tpu_inputs_for_dataset(tpu_assignment, mode,
                                   input_specs, get_next_ops):
  """Prepend a core-id input to the dataset inputs of every tower.

  The core id is used during compilation to identify the current TPU core
  and enable concatenation operations across cores.

  Args:
    tpu_assignment: Object whose `num_towers` gives the number of towers.
    mode: ModeKeys value. Only TRAIN and EVAL receive the extra input; for
      any other mode the arguments are returned untouched.
    input_specs: List of input specs. The leading dimension of the first
      spec is taken as the per-core batch size.
    get_next_ops: One sequence of tensors per tower. Each entry is replaced
      in place by a list that starts with that tower's core-id constant.

  Returns:
    A `(input_specs, get_next_ops)` pair, with the core-id spec and tensors
    at the head in TRAIN and EVAL modes.
  """
  modes_with_core_id = (model_fn_lib.ModeKeys.TRAIN,
                        model_fn_lib.ModeKeys.EVAL)
  if mode not in modes_with_core_id:
    return input_specs, get_next_ops

  # Dataset inputs are fed per core, so the first spec already carries the
  # per-core batch size.
  first_spec_dims = input_specs[0].shape.as_list()
  per_core_batch_size = first_spec_dims[0]

  assert len(get_next_ops) == tpu_assignment.num_towers

  # Put the core-id tensor in front of each tower's get_next tensors.
  for tower in range(tpu_assignment.num_towers):
    tower_ids = np.array([tower] * per_core_batch_size).astype('int32')
    core_id_constant = constant_op.constant(
        tower_ids, dtype=dtypes.int32, name='cord_id_constant')
    get_next_ops[tower] = [core_id_constant] + list(get_next_ops[tower])

  # The matching spec goes in front of the input specs as well.
  core_id_spec = tensor_spec.TensorSpec([per_core_batch_size], dtypes.int32)
  input_specs = [core_id_spec] + input_specs

  return input_specs, get_next_ops
+
+
def _inject_tpu_inputs_for_infeed(tpu_assignment, mode, input_tensors, inputs):
  """Prepend a core-id placeholder and its values to the infeed inputs.

  The core id is used during compilation to identify the current TPU core
  and enable concatenation operations across cores.

  Args:
    tpu_assignment: Object whose `num_towers` gives the number of cores.
    mode: ModeKeys value. Only TRAIN and EVAL receive the extra input; for
      any other mode the arguments are returned untouched.
    input_tensors: List of input tensors.
    inputs: List of input values. The leading dimension of the first one is
      divided by the number of cores to get the per-core batch size.

  Returns:
    A `(input_tensors, inputs)` pair, with the core-id placeholder and the
    core-id values at the head in TRAIN and EVAL modes.
  """
  modes_with_core_id = (model_fn_lib.ModeKeys.TRAIN,
                        model_fn_lib.ModeKeys.EVAL)
  if mode not in modes_with_core_id:
    return input_tensors, inputs

  # The placeholder stands in for the core id among the input tensors.
  core_id_place_holder = array_ops.placeholder(
      dtype=dtypes.int32, shape=[1], name='core_id')
  input_tensors = [core_id_place_holder] + input_tensors

  # Build the core-id values. With `num_cores` = 2 and `batch_size` = 8 this
  # is [0, 0, 0, 0, 1, 1, 1, 1], so each core sees its own id (duplicated).
  num_cores = tpu_assignment.num_towers
  total_batch_size = inputs[0].shape[0]
  per_core_batch_size = total_batch_size // num_cores
  core_ids = np.repeat(np.arange(num_cores), per_core_batch_size)
  inputs = [core_ids] + inputs

  return input_tensors, inputs
+
+
def _read_tpu_coreid_from_infeed(mode, infeed_tensors):
  """Split the core id off the front of the infeed tensors.

  Args:
    mode: ModeKeys value. Only TRAIN and EVAL carry a core id.
    infeed_tensors: Sequence of infeed tensors. In TRAIN and EVAL modes the
      first entry holds the core id.

  Returns:
    A `(core_id, rest)` pair. `core_id` is the first element of the first
    infeed tensor and `rest` is the remaining tensors. For any other mode,
    `(None, infeed_tensors)` is returned.

  Raises:
    RuntimeError: If, in TRAIN or EVAL mode, there are not at least two
      infeed tensors.
  """
  carries_core_id = mode in (model_fn_lib.ModeKeys.TRAIN,
                             model_fn_lib.ModeKeys.EVAL)
  if not carries_core_id:
    return None, infeed_tensors

  tensor_count = len(infeed_tensors)
  if tensor_count <= 1:
    raise RuntimeError(
        'The infeed tensors on TPU core has only {} tensors. '
        'This is not expected. Please report a bug.\nTensors: {}'.format(
            tensor_count, infeed_tensors))

  core_id_tensor = infeed_tensors[0]
  remaining_tensors = infeed_tensors[1:]
  # Index element 0 to get the scalar version of the core id.
  return core_id_tensor[0], remaining_tensors
+
+
class TPUFunction(object):
"""K.function compatible interface for invoking a TPU compiled function.
@@ -785,6 +890,10 @@ class TPUFunction(object):
shapes=[spec.shape for spec in input_specs],
name='infeed-%s' % self.execution_mode)
+ core_id, infeed_tensors = (
+ _read_tpu_coreid_from_infeed(
+ mode=self.execution_mode, infeed_tensors=infeed_tensors))
+
assert len(infeed_tensors) == len(infeed_layers), (
'Infeed inputs did not match model: %s vs %s' % (infeed_layers,
infeed_tensors))
@@ -806,6 +915,28 @@ class TPUFunction(object):
self._tpu_assignment.num_towers):
self._cloned_model = models.clone_model(self.model)
+ # When running on more than one core, concatenate outputs at the end of
+ # processing. In the backprop stage, the gradients will be calculated
+ # according to the local inputs, because the gradient of
+ # cross-replica-concat is zero for any outputs other than those from the
+ # local core, so the loss calculation is identical.
+ num_towers = self.model._tpu_assignment.num_towers
+ if num_towers > 1 and (is_training or is_test):
+ new_outputs = [
+ _cross_replica_concat(
+ o, core_id, num_towers, name='model output ({})'.format(o.name))
+ for o in self._cloned_model.outputs
+ ]
+ self._cloned_model.outputs = new_outputs
+ tpu_targets = [
+ _cross_replica_concat(
+ tensor,
+ core_id,
+ num_towers,
+ name='model target ({})'.format(tensor.name))
+ for tensor in tpu_targets
+ ]
+
# Create a copy of the optimizer for this graph.
if isinstance(self.model.optimizer, keras_optimizers.TFOptimizer):
cloned_optimizer = keras_optimizers.TFOptimizer(
@@ -933,6 +1064,7 @@ class TPUFunction(object):
for x, mgr in self.model._numpy_to_infeed_manager_list:
if inputs[0] is x:
return mgr
+
return TPUNumpyInfeedManager(self.model._tpu_assignment)
def _tpu_model_ops_for_input_specs(self, input_specs, infeed_manager):
@@ -958,8 +1090,10 @@ class TPUFunction(object):
shape_key = tuple([tuple(spec.shape.as_list()) for spec in input_specs])
if shape_key not in self._compilation_cache:
with self.model.tpu_session():
- logging.info('New input shapes; (re-)compiling: mode=%s, %s',
- self.execution_mode, input_specs)
+ logging.info(
+ 'New input shapes; (re-)compiling: mode=%s '
+ '(# of cores %d), %s', self.execution_mode,
+ self._tpu_assignment.num_towers, input_specs)
new_tpu_model_ops = self._specialize_model(input_specs,
infeed_manager)
self._compilation_cache[shape_key] = new_tpu_model_ops
@@ -998,6 +1132,9 @@ class TPUFunction(object):
input_tensors = self.model._feed_inputs
inputs = inputs[:len(input_tensors)]
+ input_tensors, inputs = (
+ _inject_tpu_inputs_for_infeed(
+ self._tpu_assignment, self.execution_mode, input_tensors, inputs))
return input_tensors, inputs
def _process_outputs(self, outfeed_outputs):
@@ -1272,8 +1409,8 @@ class KerasTPUModel(models.Model):
if y is not None:
raise ValueError('When using tf.data as input to a model, y must be '
'None')
- infeed_manager = TPUDatasetInfeedManager(dataset, self._tpu_assignment,
- sess)
+ infeed_manager = TPUDatasetInfeedManager(
+ dataset, self._tpu_assignment, sess, model_fn_lib.ModeKeys.TRAIN)
# Use dummy numpy inputs for the rest of Keras' shape checking. We
# intercept them when building the model.
x = infeed_manager.dummy_x
@@ -1294,8 +1431,8 @@ class KerasTPUModel(models.Model):
if validation_steps is None:
raise ValueError('When using tf.data as validation for a model, you '
'should specify the validation_steps argument.')
- infeed_manager = TPUDatasetInfeedManager(dataset, self._tpu_assignment,
- sess)
+ infeed_manager = TPUDatasetInfeedManager(
+ dataset, self._tpu_assignment, sess, model_fn_lib.ModeKeys.EVAL)
# Use dummy numpy inputs for the rest of Keras' shape checking. We
# intercept them when building the model.
val_x = infeed_manager.dummy_x
@@ -1372,8 +1509,8 @@ class KerasTPUModel(models.Model):
if y is not None:
raise ValueError('When using tf.data as input to a model, y must be '
'None')
- infeed_manager = TPUDatasetInfeedManager(dataset, self._tpu_assignment,
- sess)
+ infeed_manager = TPUDatasetInfeedManager(
+ dataset, self._tpu_assignment, sess, model_fn_lib.ModeKeys.EVAL)
# Use dummy numpy inputs for the rest of Keras' shape checking. We
# intercept them when building the model.
x = infeed_manager.dummy_x