From 416bac50aaa684049bb3270d379316efc5b960c2 Mon Sep 17 00:00:00 2001 From: Lukas Geiger Date: Fri, 25 May 2018 01:06:33 +0200 Subject: [tfgan] Add possibility to export GANEstimator saved model --- tensorflow/contrib/gan/python/estimator/python/head_impl.py | 6 +++++- tensorflow/contrib/gan/python/estimator/python/head_test.py | 5 +++-- 2 files changed, 8 insertions(+), 3 deletions(-) (limited to 'tensorflow/contrib/gan') diff --git a/tensorflow/contrib/gan/python/estimator/python/head_impl.py b/tensorflow/contrib/gan/python/estimator/python/head_impl.py index ff903a78cc..5b5557bd8f 100644 --- a/tensorflow/contrib/gan/python/estimator/python/head_impl.py +++ b/tensorflow/contrib/gan/python/estimator/python/head_impl.py @@ -24,6 +24,7 @@ from tensorflow.contrib.gan.python import namedtuples as tfgan_tuples from tensorflow.contrib.gan.python import train as tfgan_train from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.estimator.canned import head +from tensorflow.python.estimator.export import export_output from tensorflow.python.framework import ops from tensorflow.python.ops import metrics as metrics_lib @@ -182,7 +183,10 @@ class GANHead(head._Head): # pylint: disable=protected-access if mode == model_fn_lib.ModeKeys.PREDICT: return model_fn_lib.EstimatorSpec( mode=model_fn_lib.ModeKeys.PREDICT, - predictions=gan_model.generated_data) + predictions=gan_model.generated_data, + export_outputs={ + 'predict': export_output.PredictOutput(gan_model.generated_data) + }) elif mode == model_fn_lib.ModeKeys.EVAL: gan_loss = self.create_loss( features=None, mode=mode, logits=gan_model, labels=None) diff --git a/tensorflow/contrib/gan/python/estimator/python/head_test.py b/tensorflow/contrib/gan/python/estimator/python/head_test.py index 6587f1fc60..c121f322b5 100644 --- a/tensorflow/contrib/gan/python/estimator/python/head_test.py +++ b/tensorflow/contrib/gan/python/estimator/python/head_test.py @@ -71,13 +71,14 @@ class GANHeadTest(test.TestCase): return {} def _test_modes_helper(self, mode): - self.gan_head.create_estimator_spec( + return self.gan_head.create_estimator_spec( features=None, mode=mode, logits=get_gan_model()) def test_modes_predict(self): - self._test_modes_helper(model_fn_lib.ModeKeys.PREDICT) + spec = self._test_modes_helper(model_fn_lib.ModeKeys.PREDICT) + self.assertItemsEqual(('predict',), spec.export_outputs.keys()) def test_modes_eval(self): self._test_modes_helper(model_fn_lib.ModeKeys.EVAL) -- cgit v1.2.3