aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/estimator/export/export_output_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/estimator/export/export_output_test.py')
-rw-r--r--tensorflow/python/estimator/export/export_output_test.py15
1 files changed, 15 insertions, 0 deletions
diff --git a/tensorflow/python/estimator/export/export_output_test.py b/tensorflow/python/estimator/export/export_output_test.py
index b21ba91b0f..d94c764fd7 100644
--- a/tensorflow/python/estimator/export/export_output_test.py
+++ b/tensorflow/python/estimator/export/export_output_test.py
@@ -24,8 +24,10 @@ from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python.estimator.export import export_output as export_output_lib
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
from tensorflow.python.platform import test
from tensorflow.python.saved_model import signature_constants
@@ -335,5 +337,18 @@ class SupervisedOutputTest(test.TestCase):
self.assertTrue("predictions/output1" in sig_def.outputs)
self.assertTrue("features" in sig_def.inputs)
+ def test_metric_op_is_operation(self):
+ """Tests that ops.Operation is wrapped by a tensor for metric_ops."""
+ loss = {"my_loss": constant_op.constant([0])}
+ predictions = {u"output1": constant_op.constant(["foo"])}
+ metrics = {"metrics": (constant_op.constant([0]), control_flow_ops.no_op())}
+
+ outputter = MockSupervisedOutput(loss, predictions, metrics)
+ self.assertEqual(outputter.metrics["metrics/value"], metrics["metrics"][0])
+ self.assertEqual(
+ outputter.metrics["metrics/update_op"].name, "metric_op_wrapper:0")
+ self.assertTrue(
+ isinstance(outputter.metrics["metrics/update_op"], ops.Tensor))
+
if __name__ == "__main__":
test.main()