aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/estimator
diff options
context:
space:
mode:
authorGravatar Pavithra Vijay <psv@google.com>2018-09-18 17:22:53 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-18 17:27:14 -0700
commit867449616aa43f9306247cebdd1edac85b70852a (patch)
treeaaaf4060982c3386c043042f0602de3565557fe1 /tensorflow/python/estimator
parent08af8cac22af4cc430e092b6218ca77736efb82c (diff)
Convert the new metric instances to (value_op, update_op) tuple in the EstimatorSpec.
PiperOrigin-RevId: 213548081
Diffstat (limited to 'tensorflow/python/estimator')
-rw-r--r--tensorflow/python/estimator/estimator.py14
-rw-r--r--tensorflow/python/estimator/model_fn.py2
2 files changed, 4 insertions, 12 deletions
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index ffe1e30da0..2dc5d099a0 100644
--- a/tensorflow/python/estimator/estimator.py
+++ b/tensorflow/python/estimator/estimator.py
@@ -41,7 +41,6 @@ from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import tensor_util
-from tensorflow.python.keras import metrics
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import metrics as metrics_lib
@@ -1792,18 +1791,9 @@ def _extract_metric_update_ops(eval_dict, distribution=None):
value_ops = {}
# Sort metrics lexicographically so graph is identical every time.
for name, value in sorted(six.iteritems(eval_dict)):
- if isinstance(value, metrics.Metric):
- metric_result = value.result()
- # We expect only one update op for every metric when there is no
- # distribution strategy.
- metric_update = value.updates if distribution else value.updates[0]
- else:
- metric_result = value[0]
- metric_update = value[1]
-
- value_ops[name] = metric_result
+ value_ops[name] = value[0]
update_ops.append(
- distribution.group(metric_update) if distribution else metric_update)
+ distribution.group(value[1]) if distribution else value[1])
update_op = control_flow_ops.group(*update_ops) if update_ops else None
return update_op, value_ops
diff --git a/tensorflow/python/estimator/model_fn.py b/tensorflow/python/estimator/model_fn.py
index 0f26a5bba4..824789467d 100644
--- a/tensorflow/python/estimator/model_fn.py
+++ b/tensorflow/python/estimator/model_fn.py
@@ -308,6 +308,8 @@ class EstimatorSpec(
for key, value in six.iteritems(eval_metric_ops):
if isinstance(value, Metric):
vars_to_add.update(value.variables)
+ # Convert Metric instances to (value_tensor, update_op) tuple.
+ eval_metric_ops[key] = (value.result(), value.updates[0])
# Remove variables that are in the local variables collection already.
vars_to_add = vars_to_add.difference(local_vars)
for v in vars_to_add: