diff options
author | Jianwei Xie <xiejw@google.com> | 2018-09-25 13:36:56 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-25 13:41:04 -0700 |
commit | 976fb3105312bb17accebcbca2ebae906bcf99fb (patch) | |
tree | c74965a3429937982c21ebb42e4b7cedd16a52ff /tensorflow/contrib/tpu | |
parent | e51963ead78b3c1c4ab0077a3e43fb9c0f6ab374 (diff) |
Add cross-replica concatenation of outputs and targets, so that each core sees the same outputs and targets and produces the same loss and metrics.
PiperOrigin-RevId: 214494877
Diffstat (limited to 'tensorflow/contrib/tpu')
-rw-r--r-- | tensorflow/contrib/tpu/python/tpu/keras_support.py | 155 |
1 files changed, 146 insertions, 9 deletions
diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py index 93ae68d254..03e06b8142 100644 --- a/tensorflow/contrib/tpu/python/tpu/keras_support.py +++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py @@ -68,6 +68,7 @@ from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import iterator_ops from tensorflow.python.eager import context from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops @@ -229,6 +230,39 @@ class TPUEmbedding(embeddings.Embedding): return math_ops.tensordot(inputs, self.embeddings, 1) +def _cross_replica_concat(tensor, core_id, num_cores, name): + """Concatenate `tensor` across cores. + + Args: + tensor: The tensor to be concatenated. Must be [int32 and float32]. + core_id: Tensor indicating the current TPU core. + num_cores: Python int. The total number of TPU cores in the system. + name: The string name to print for debugging. + + Returns: + The same concatenated Tensor on each core. + """ + + input_dtype = tensor.dtype + if input_dtype not in [dtypes.float32, dtypes.int32]: + raise TypeError('For model replication, only (float32 and int32) is ' + 'supported for model outputs and targets. 
Got {} for ' + '{}.'.format(input_dtype, name)) + + batch_size = tensor.shape[0] + mask = math_ops.to_float(math_ops.equal(range(num_cores), core_id)) + mask = array_ops.reshape(mask, [num_cores] + [1] * tensor.shape.ndims) + result = mask * math_ops.to_float(tensor) + local_tensor_with_holes = array_ops.reshape(result, + [-1] + result.shape.as_list()[2:]) + concat_tensor = tpu_ops.cross_replica_sum(local_tensor_with_holes) + concat_tensor.set_shape((num_cores * batch_size,) + tuple(tensor.shape[1:])) + + if concat_tensor != input_dtype: + concat_tensor = math_ops.cast(concat_tensor, input_dtype) + return concat_tensor + + class KerasCrossShardOptimizer(keras_optimizers.Optimizer): """An optimizer that averages gradients across TPU shards.""" @@ -617,7 +651,7 @@ class TPUDatasetInfeedManager(TPUInfeedManager): return {} # pylint: disable=redefined-outer-name - def __init__(self, dataset, tpu_assignment, tpu_session): + def __init__(self, dataset, tpu_assignment, tpu_session, mode): """Constructs a TPUDatasetInfeedManager. Must be called within a `KerasTPUModel.tpu_session` context! @@ -627,8 +661,10 @@ class TPUDatasetInfeedManager(TPUInfeedManager): tpu_assignment: The `TPUAssignment` used to configure the Keras TPU model. tpu_session: The `tf.Session` object used for running the TPU model. + mode: ModeKeys enum. """ self._verify_dataset_shape(dataset) + self._dataset = dataset self._tpu_assignment = tpu_assignment dummy_x_shape = dataset.output_shapes[0].as_list() @@ -668,6 +704,10 @@ class TPUDatasetInfeedManager(TPUInfeedManager): self._iterator.output_types) input_specs.append(spec) + # Pre-process the inputs and get_next_ops before caching. 
+ input_specs, self._get_next_ops = ( + _inject_tpu_inputs_for_dataset( + tpu_assignment, mode, input_specs, self._get_next_ops)) self._infeed_instance = self.DatasetInfeedInstance(input_specs) def _verify_dataset_shape(self, dataset): @@ -735,6 +775,71 @@ class TPUDatasetInfeedManager(TPUInfeedManager): sharded_infeed_tensors=shard_infeed_tensors) +def _inject_tpu_inputs_for_dataset(tpu_assignment, mode, + input_specs, get_next_ops): + """Append core information to the set of dataset inputs.""" + # This is used during compilation to identify the current TPU core and enable + # concatenation operations across cores. + if mode not in [model_fn_lib.ModeKeys.TRAIN, model_fn_lib.ModeKeys.EVAL]: + return input_specs, get_next_ops + + # Dataset inputs operate on per core basis. + per_core_batch_size = input_specs[0].shape.as_list()[0] + + # Insert, at head, the tensor for core_id. + assert len(get_next_ops) == tpu_assignment.num_towers + for i in range(tpu_assignment.num_towers): + core_id_constant = constant_op.constant( + np.array([i] * per_core_batch_size).astype('int32'), + dtype=dtypes.int32, + name='cord_id_constant') + get_next_ops[i] = [core_id_constant] + list(get_next_ops[i]) + + # Insert the input spec at head also. + input_specs = [tensor_spec.TensorSpec([per_core_batch_size], dtypes.int32) + ] + input_specs + + return input_specs, get_next_ops + + +def _inject_tpu_inputs_for_infeed(tpu_assignment, mode, input_tensors, inputs): + """Append core information to the set of inputs.""" + # This is used during compilation to identify the current TPU core and enable + # concatenation operations across cores. + if mode not in [model_fn_lib.ModeKeys.TRAIN, model_fn_lib.ModeKeys.EVAL]: + return input_tensors, inputs + + # Puts a place holder in input spec. + core_id_place_holder = array_ops.placeholder( + dtype=dtypes.int32, shape=[1], name='core_id') + input_tensors = [core_id_place_holder] + input_tensors + + # Now fill the core id. 
For `num_cores` = 2, `batch_size` = 8, we fill the + # core id inputs as [0, 0, 0, 0, 1, 1, 1, 1], so each core sees its core id + # (duplicated). + num_cores = tpu_assignment.num_towers + per_core_batch_size = inputs[0].shape[0] // num_cores + core_ids = np.arange(num_cores).repeat(per_core_batch_size) + inputs = [core_ids] + inputs + return input_tensors, inputs + + +def _read_tpu_coreid_from_infeed(mode, infeed_tensors): + """Popping out the core ids from infeed.""" + if mode not in [model_fn_lib.ModeKeys.TRAIN, model_fn_lib.ModeKeys.EVAL]: + return None, infeed_tensors + + if len(infeed_tensors) <= 1: + raise RuntimeError( + 'The infeed tensors on TPU core has only {} tensors. ' + 'This is not expected. Please report a bug.\nTensors: {}'.format( + len(infeed_tensors), infeed_tensors)) + + core_id = infeed_tensors[0][0] # Pop out the scalar version. + rest = infeed_tensors[1:] + return core_id, rest + + class TPUFunction(object): """K.function compatible interface for invoking a TPU compiled function. @@ -785,6 +890,10 @@ class TPUFunction(object): shapes=[spec.shape for spec in input_specs], name='infeed-%s' % self.execution_mode) + core_id, infeed_tensors = ( + _read_tpu_coreid_from_infeed( + mode=self.execution_mode, infeed_tensors=infeed_tensors)) + assert len(infeed_tensors) == len(infeed_layers), ( 'Infeed inputs did not match model: %s vs %s' % (infeed_layers, infeed_tensors)) @@ -806,6 +915,28 @@ class TPUFunction(object): self._tpu_assignment.num_towers): self._cloned_model = models.clone_model(self.model) + # When running on more than one core, concatenate outputs at the end of + # processing. In backprop stage, the gradients will be calculdated + # according to the local inputs as gradient of cross-replica-concat being + # zero for any outputs other than those from mlocal core so the loss + # calculation is identical. 
+ num_towers = self.model._tpu_assignment.num_towers + if num_towers > 1 and (is_training or is_test): + new_outputs = [ + _cross_replica_concat( + o, core_id, num_towers, name='model output ({})'.format(o.name)) + for o in self._cloned_model.outputs + ] + self._cloned_model.outputs = new_outputs + tpu_targets = [ + _cross_replica_concat( + tensor, + core_id, + num_towers, + name='model target ({})'.format(tensor.name)) + for tensor in tpu_targets + ] + # Create a copy of the optimizer for this graph. if isinstance(self.model.optimizer, keras_optimizers.TFOptimizer): cloned_optimizer = keras_optimizers.TFOptimizer( @@ -933,6 +1064,7 @@ class TPUFunction(object): for x, mgr in self.model._numpy_to_infeed_manager_list: if inputs[0] is x: return mgr + return TPUNumpyInfeedManager(self.model._tpu_assignment) def _tpu_model_ops_for_input_specs(self, input_specs, infeed_manager): @@ -958,8 +1090,10 @@ class TPUFunction(object): shape_key = tuple([tuple(spec.shape.as_list()) for spec in input_specs]) if shape_key not in self._compilation_cache: with self.model.tpu_session(): - logging.info('New input shapes; (re-)compiling: mode=%s, %s', - self.execution_mode, input_specs) + logging.info( + 'New input shapes; (re-)compiling: mode=%s ' + '(# of cores %d), %s', self.execution_mode, + self._tpu_assignment.num_towers, input_specs) new_tpu_model_ops = self._specialize_model(input_specs, infeed_manager) self._compilation_cache[shape_key] = new_tpu_model_ops @@ -998,6 +1132,9 @@ class TPUFunction(object): input_tensors = self.model._feed_inputs inputs = inputs[:len(input_tensors)] + input_tensors, inputs = ( + _inject_tpu_inputs_for_infeed( + self._tpu_assignment, self.execution_mode, input_tensors, inputs)) return input_tensors, inputs def _process_outputs(self, outfeed_outputs): @@ -1272,8 +1409,8 @@ class KerasTPUModel(models.Model): if y is not None: raise ValueError('When using tf.data as input to a model, y must be ' 'None') - infeed_manager = 
TPUDatasetInfeedManager(dataset, self._tpu_assignment, - sess) + infeed_manager = TPUDatasetInfeedManager( + dataset, self._tpu_assignment, sess, model_fn_lib.ModeKeys.TRAIN) # Use dummy numpy inputs for the rest of Keras' shape checking. We # intercept them when building the model. x = infeed_manager.dummy_x @@ -1294,8 +1431,8 @@ class KerasTPUModel(models.Model): if validation_steps is None: raise ValueError('When using tf.data as validation for a model, you ' 'should specify the validation_steps argument.') - infeed_manager = TPUDatasetInfeedManager(dataset, self._tpu_assignment, - sess) + infeed_manager = TPUDatasetInfeedManager( + dataset, self._tpu_assignment, sess, model_fn_lib.ModeKeys.EVAL) # Use dummy numpy inputs for the rest of Keras' shape checking. We # intercept them when building the model. val_x = infeed_manager.dummy_x @@ -1372,8 +1509,8 @@ class KerasTPUModel(models.Model): if y is not None: raise ValueError('When using tf.data as input to a model, y must be ' 'None') - infeed_manager = TPUDatasetInfeedManager(dataset, self._tpu_assignment, - sess) + infeed_manager = TPUDatasetInfeedManager( + dataset, self._tpu_assignment, sess, model_fn_lib.ModeKeys.EVAL) # Use dummy numpy inputs for the rest of Keras' shape checking. We # intercept them when building the model. x = infeed_manager.dummy_x |