diff options
author | Jianwei Xie <xiejw@google.com> | 2018-09-17 13:31:12 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-17 13:39:55 -0700 |
commit | 32ed8d488ad8088b63f046cde0c665e3b2aab8e7 (patch) | |
tree | badf4530bbf5f7463276f40600da9b2026dd4ea7 /tensorflow/contrib/tpu | |
parent | a768624f1d0ae3629caf5b9784b4b6911b881c18 (diff) |
Add support for predicting models with learning_phase.
PiperOrigin-RevId: 213327633
Diffstat (limited to 'tensorflow/contrib/tpu')
-rw-r--r-- | tensorflow/contrib/tpu/python/tpu/keras_support.py | 18 |
1 files changed, 14 insertions, 4 deletions
diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py index d8c3872363..776b9bff0f 100644 --- a/tensorflow/contrib/tpu/python/tpu/keras_support.py +++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py @@ -970,15 +970,25 @@ class TPUFunction(object): # Note: this condition is possible during the prologue or epilogue of the # pipelined loop. return None, None - # Strip sample weight from inputs + + if (self.model.uses_learning_phase and + not isinstance(K.learning_phase(), int)): + # Remove the learning_phase flag at the end. We currently hard code the + # learning_phase in TPUFunction. + assert isinstance(inputs[-1], int), ( + 'Expect the final element be learning_phase flag. Got {}'.format( + inputs[-1])) + inputs = inputs[:-1] + if (self.execution_mode == model_fn_lib.ModeKeys.TRAIN or self.execution_mode == model_fn_lib.ModeKeys.EVAL): + # Strip sample weight from inputs. input_tensors = self.model._feed_inputs + self.model._feed_targets - inputs = inputs[:len(input_tensors)] - return input_tensors, inputs else: input_tensors = self.model._feed_inputs - return input_tensors, inputs + + inputs = inputs[:len(input_tensors)] + return input_tensors, inputs def _process_outputs(self, outfeed_outputs): """Processes the outputs of a model function execution. |