diff options
-rw-r--r-- | tensorflow/contrib/tpu/python/tpu/keras_support.py | 46 |
1 files changed, 46 insertions, 0 deletions
diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py index dbf5c66c9e..e22aeb2ac0 100644 --- a/tensorflow/contrib/tpu/python/tpu/keras_support.py +++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py @@ -1073,6 +1073,52 @@ class KerasTPUModel(models.Model): finally: self._numpy_to_infeed_manager_list = [] + def evaluate(self, + x=None, + y=None, + batch_size=None, + verbose=1, + sample_weight=None, + steps=None): + assert not self._numpy_to_infeed_manager_list # Ensure empty. + + infeed_managers = [] # Managers to clean up at the end of the fit call. + if isinstance(x, dataset_ops.Dataset): + # TODO(b/111413240): Support taking a tf.data.Dataset directly. + raise ValueError( + 'Taking a Dataset directly is not yet supported. Please ' + 'wrap your dataset construction code in a function and ' + 'pass that to fit instead. For examples, see: ' + 'https://github.com/tensorflow/tpu/tree/master/models/experimental' + '/keras') + if callable(x): + with self.tpu_session() as sess: + dataset = x() + if steps is None: + raise ValueError('When using tf.data as input to a model, you ' + 'should specify the steps argument.') + if y is not None: + raise ValueError('When using tf.data as input to a model, y must be ' + 'None') + infeed_manager = TPUDatasetInfeedManager(dataset, self._strategy, sess) + # Use dummy numpy inputs for the rest of Keras' shape checking. We + # intercept them when building the model. + x = infeed_manager.dummy_x + y = infeed_manager.dummy_y + infeed_managers.append((x, infeed_manager)) + + self._numpy_to_infeed_manager_list = infeed_managers + try: + return super(KerasTPUModel, self).evaluate( + x, + y, + batch_size, + verbose, + sample_weight, + steps) + finally: + self._numpy_to_infeed_manager_list = [] + def _make_train_function(self): if not self.train_function: self.train_function = TPUFunction( |