diff options
author | 2018-09-26 17:58:14 -0700 | |
---|---|---|
committer | 2018-09-26 18:10:03 -0700 | |
commit | 5d61748f4e9998c9d2017bd01864b8fcb6d2127a (patch) | |
tree | 891094fbc8cfdad89769f2bb6ed802b932d43549 /tensorflow/contrib/tpu | |
parent | c3af9dc70ae6c5df811c91c29da432469cb471fc (diff) |
Fixed the bug which slows the TPU traning.
PiperOrigin-RevId: 214702243
Diffstat (limited to 'tensorflow/contrib/tpu')
-rw-r--r-- | tensorflow/contrib/tpu/python/tpu/keras_support.py | 12 |
1 files changed, 8 insertions, 4 deletions
diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py index f67e0e6aca..448676c95e 100644 --- a/tensorflow/contrib/tpu/python/tpu/keras_support.py +++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py @@ -817,7 +817,8 @@ def _inject_tpu_inputs_for_dataset(tpu_assignment, mode, return input_specs, get_next_ops -def _inject_tpu_inputs_for_infeed(tpu_assignment, mode, input_tensors, inputs): +def _inject_tpu_inputs_for_infeed(tpu_assignment, mode, + core_id_place_holder, input_tensors, inputs): """Append core information to the set of inputs.""" # This is used during compilation to identify the current TPU core and enable # concatenation operations across cores. @@ -825,8 +826,6 @@ def _inject_tpu_inputs_for_infeed(tpu_assignment, mode, input_tensors, inputs): return input_tensors, inputs # Puts a place holder in input spec. - core_id_place_holder = array_ops.placeholder( - dtype=dtypes.int32, shape=[1], name='core_id') input_tensors = [core_id_place_holder] + input_tensors # Now fill the core id. For `num_cores` = 2, `batch_size` = 8, we fill the @@ -874,6 +873,10 @@ class TPUFunction(object): self._compilation_cache = {} self._cloned_model = None self._cloned_optimizer = None + # Create a placeholder for the TPU core ID. Cache the placeholder to avoid + # modifying the graph for every batch. + self._core_id_place_holder = array_ops.placeholder( + dtype=dtypes.int32, shape=[1], name='core_id') def _specialize_model(self, input_specs, infeed_manager): """Specialize `self.model` (a Keras model) for the given input shapes.""" @@ -1141,7 +1144,8 @@ class TPUFunction(object): inputs = inputs[:len(input_tensors)] input_tensors, inputs = ( _inject_tpu_inputs_for_infeed( - self._tpu_assignment, self.execution_mode, input_tensors, inputs)) + self._tpu_assignment, self.execution_mode, + self._core_id_place_holder, input_tensors, inputs)) return input_tensors, inputs def _process_outputs(self, outfeed_outputs): |