aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tpu
diff options
context:
space:
mode:
authorGravatar Jianwei Xie <xiejw@google.com>2018-09-26 17:58:14 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-26 18:10:03 -0700
commit5d61748f4e9998c9d2017bd01864b8fcb6d2127a (patch)
tree891094fbc8cfdad89769f2bb6ed802b932d43549 /tensorflow/contrib/tpu
parentc3af9dc70ae6c5df811c91c29da432469cb471fc (diff)
Fixed the bug which slows the TPU traning.
PiperOrigin-RevId: 214702243
Diffstat (limited to 'tensorflow/contrib/tpu')
-rw-r--r--tensorflow/contrib/tpu/python/tpu/keras_support.py12
1 files changed, 8 insertions, 4 deletions
diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py
index f67e0e6aca..448676c95e 100644
--- a/tensorflow/contrib/tpu/python/tpu/keras_support.py
+++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py
@@ -817,7 +817,8 @@ def _inject_tpu_inputs_for_dataset(tpu_assignment, mode,
return input_specs, get_next_ops
-def _inject_tpu_inputs_for_infeed(tpu_assignment, mode, input_tensors, inputs):
+def _inject_tpu_inputs_for_infeed(tpu_assignment, mode,
+ core_id_place_holder, input_tensors, inputs):
"""Append core information to the set of inputs."""
# This is used during compilation to identify the current TPU core and enable
# concatenation operations across cores.
@@ -825,8 +826,6 @@ def _inject_tpu_inputs_for_infeed(tpu_assignment, mode, input_tensors, inputs):
return input_tensors, inputs
# Puts a place holder in input spec.
- core_id_place_holder = array_ops.placeholder(
- dtype=dtypes.int32, shape=[1], name='core_id')
input_tensors = [core_id_place_holder] + input_tensors
# Now fill the core id. For `num_cores` = 2, `batch_size` = 8, we fill the
@@ -874,6 +873,10 @@ class TPUFunction(object):
self._compilation_cache = {}
self._cloned_model = None
self._cloned_optimizer = None
+ # Create a placeholder for the TPU core ID. Cache the placeholder to avoid
+ # modifying the graph for every batch.
+ self._core_id_place_holder = array_ops.placeholder(
+ dtype=dtypes.int32, shape=[1], name='core_id')
def _specialize_model(self, input_specs, infeed_manager):
"""Specialize `self.model` (a Keras model) for the given input shapes."""
@@ -1141,7 +1144,8 @@ class TPUFunction(object):
inputs = inputs[:len(input_tensors)]
input_tensors, inputs = (
_inject_tpu_inputs_for_infeed(
- self._tpu_assignment, self.execution_mode, input_tensors, inputs))
+ self._tpu_assignment, self.execution_mode,
+ self._core_id_place_holder, input_tensors, inputs))
return input_tensors, inputs
def _process_outputs(self, outfeed_outputs):