aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tpu
diff options
context:
space:
mode:
authorGravatar Jianwei Xie <xiejw@google.com>2018-09-17 13:31:12 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-17 13:39:55 -0700
commit32ed8d488ad8088b63f046cde0c665e3b2aab8e7 (patch)
treebadf4530bbf5f7463276f40600da9b2026dd4ea7 /tensorflow/contrib/tpu
parenta768624f1d0ae3629caf5b9784b4b6911b881c18 (diff)
Add support for predicting models with learning_phase.
PiperOrigin-RevId: 213327633
Diffstat (limited to 'tensorflow/contrib/tpu')
-rw-r--r--tensorflow/contrib/tpu/python/tpu/keras_support.py18
1 files changed, 14 insertions, 4 deletions
diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py
index d8c3872363..776b9bff0f 100644
--- a/tensorflow/contrib/tpu/python/tpu/keras_support.py
+++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py
@@ -970,15 +970,25 @@ class TPUFunction(object):
# Note: this condition is possible during the prologue or epilogue of the
# pipelined loop.
return None, None
- # Strip sample weight from inputs
+
+ if (self.model.uses_learning_phase and
+ not isinstance(K.learning_phase(), int)):
+ # Remove the learning_phase flag at the end. We currently hard code the
+ # learning_phase in TPUFunction.
+ assert isinstance(inputs[-1], int), (
+ 'Expect the final element be learning_phase flag. Got {}'.format(
+ inputs[-1]))
+ inputs = inputs[:-1]
+
if (self.execution_mode == model_fn_lib.ModeKeys.TRAIN or
self.execution_mode == model_fn_lib.ModeKeys.EVAL):
+ # Strip sample weight from inputs.
input_tensors = self.model._feed_inputs + self.model._feed_targets
- inputs = inputs[:len(input_tensors)]
- return input_tensors, inputs
else:
input_tensors = self.model._feed_inputs
- return input_tensors, inputs
+
+ inputs = inputs[:len(input_tensors)]
+ return input_tensors, inputs
def _process_outputs(self, outfeed_outputs):
"""Processes the outputs of a model function execution.