aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/gan
diff options
context:
space:
mode:
authorGravatar Lukas Geiger <lukas.geiger94@gmail.com>2018-06-15 00:55:56 +0200
committerGravatar Lukas Geiger <lukas.geiger94@gmail.com>2018-06-15 00:55:56 +0200
commit0a6a85a7b720b4ae41d6029d2a5293ae01f66090 (patch)
treebacd3b9f63a0ab47f6cc06492c5fe9a4bff259b7 /tensorflow/contrib/gan
parent416bac50aaa684049bb3270d379316efc5b960c2 (diff)
[tfgan] Add default serving key to unittest
Diffstat (limited to 'tensorflow/contrib/gan')
-rw-r--r--tensorflow/contrib/gan/python/estimator/python/head_test.py6
1 files changed, 5 insertions, 1 deletions
diff --git a/tensorflow/contrib/gan/python/estimator/python/head_test.py b/tensorflow/contrib/gan/python/estimator/python/head_test.py
index c121f322b5..5309d87765 100644
--- a/tensorflow/contrib/gan/python/estimator/python/head_test.py
+++ b/tensorflow/contrib/gan/python/estimator/python/head_test.py
@@ -26,8 +26,11 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import test
+from tensorflow.python.saved_model import signature_constants
from tensorflow.python.training import training
+_DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
+
def dummy_loss(gan_model, add_summaries=True): # pylint:disable=unused-argument
return math_ops.reduce_sum(gan_model.discriminator_real_outputs -
@@ -78,7 +81,8 @@ class GANHeadTest(test.TestCase):
def test_modes_predict(self):
spec = self._test_modes_helper(model_fn_lib.ModeKeys.PREDICT)
- self.assertItemsEqual(('predict',), spec.export_outputs.keys())
+ self.assertItemsEqual((_DEFAULT_SERVING_KEY, 'predict'),
+ spec.export_outputs.keys())
def test_modes_eval(self):
self._test_modes_helper(model_fn_lib.ModeKeys.EVAL)