aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/estimator/export/export_output.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/estimator/export/export_output.py')
-rw-r--r--tensorflow/python/estimator/export/export_output.py11
1 files changed, 10 insertions, 1 deletions
diff --git a/tensorflow/python/estimator/export/export_output.py b/tensorflow/python/estimator/export/export_output.py
index 6c26d29985..20382a58d8 100644
--- a/tensorflow/python/estimator/export/export_output.py
+++ b/tensorflow/python/estimator/export/export_output.py
@@ -23,6 +23,7 @@ import abc
import six
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.saved_model import signature_def_utils
@@ -338,8 +339,16 @@ class _SupervisedOutput(ExportOutput):
raise ValueError(
'{} update_op must be a Tensor or Operation; got {}.'.format(
key, metric_op))
+
+ # We must wrap any ops in a Tensor before export, as the SignatureDef
+ # proto expects tensors only. See b/109740581
+ metric_op_tensor = metric_op
+ if isinstance(metric_op, ops.Operation):
+ with ops.control_dependencies([metric_op]):
+ metric_op_tensor = constant_op.constant([], name='metric_op_wrapper')
+
outputs[val_name] = metric_val
- outputs[op_name] = metric_op
+ outputs[op_name] = metric_op_tensor
return outputs