aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/gan
diff options
context:
space:
mode:
authorGravatar Lukas Geiger <lukas.geiger94@gmail.com>2018-05-25 01:06:33 +0200
committerGravatar Lukas Geiger <lukas.geiger94@gmail.com>2018-05-25 01:06:33 +0200
commit416bac50aaa684049bb3270d379316efc5b960c2 (patch)
tree6e2f8bce18826770e4985a321714759e26919b07 /tensorflow/contrib/gan
parenta9761960e282cdcf0822951dec86372181f0e88e (diff)
[tfgan] Add possibility to export GANEstimator saved model
Diffstat (limited to 'tensorflow/contrib/gan')
-rw-r--r--tensorflow/contrib/gan/python/estimator/python/head_impl.py6
-rw-r--r--tensorflow/contrib/gan/python/estimator/python/head_test.py5
2 files changed, 8 insertions, 3 deletions
diff --git a/tensorflow/contrib/gan/python/estimator/python/head_impl.py b/tensorflow/contrib/gan/python/estimator/python/head_impl.py
index ff903a78cc..5b5557bd8f 100644
--- a/tensorflow/contrib/gan/python/estimator/python/head_impl.py
+++ b/tensorflow/contrib/gan/python/estimator/python/head_impl.py
@@ -24,6 +24,7 @@ from tensorflow.contrib.gan.python import namedtuples as tfgan_tuples
from tensorflow.contrib.gan.python import train as tfgan_train
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.estimator.canned import head
+from tensorflow.python.estimator.export import export_output
from tensorflow.python.framework import ops
from tensorflow.python.ops import metrics as metrics_lib
@@ -182,7 +183,10 @@ class GANHead(head._Head): # pylint: disable=protected-access
if mode == model_fn_lib.ModeKeys.PREDICT:
return model_fn_lib.EstimatorSpec(
mode=model_fn_lib.ModeKeys.PREDICT,
- predictions=gan_model.generated_data)
+ predictions=gan_model.generated_data,
+ export_outputs={
+ 'predict': export_output.PredictOutput(gan_model.generated_data)
+ })
elif mode == model_fn_lib.ModeKeys.EVAL:
gan_loss = self.create_loss(
features=None, mode=mode, logits=gan_model, labels=None)
diff --git a/tensorflow/contrib/gan/python/estimator/python/head_test.py b/tensorflow/contrib/gan/python/estimator/python/head_test.py
index 6587f1fc60..c121f322b5 100644
--- a/tensorflow/contrib/gan/python/estimator/python/head_test.py
+++ b/tensorflow/contrib/gan/python/estimator/python/head_test.py
@@ -71,13 +71,14 @@ class GANHeadTest(test.TestCase):
return {}
def _test_modes_helper(self, mode):
- self.gan_head.create_estimator_spec(
+ return self.gan_head.create_estimator_spec(
features=None,
mode=mode,
logits=get_gan_model())
def test_modes_predict(self):
- self._test_modes_helper(model_fn_lib.ModeKeys.PREDICT)
+ spec = self._test_modes_helper(model_fn_lib.ModeKeys.PREDICT)
+ self.assertItemsEqual(('predict',), spec.export_outputs.keys())
def test_modes_eval(self):
self._test_modes_helper(model_fn_lib.ModeKeys.EVAL)