diff options
author | Pavithra Vijay <psv@google.com> | 2018-09-18 17:22:53 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-18 17:27:14 -0700 |
commit | 867449616aa43f9306247cebdd1edac85b70852a (patch) | |
tree | aaaf4060982c3386c043042f0602de3565557fe1 /tensorflow/python/estimator | |
parent | 08af8cac22af4cc430e092b6218ca77736efb82c (diff) |
Convert the new metric instances to (value_op, update_op) tuple in the EstimatorSpec.
PiperOrigin-RevId: 213548081
Diffstat (limited to 'tensorflow/python/estimator')
-rw-r--r-- | tensorflow/python/estimator/estimator.py | 14 | ||||
-rw-r--r-- | tensorflow/python/estimator/model_fn.py | 2 |
2 files changed, 4 insertions, 12 deletions
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py index ffe1e30da0..2dc5d099a0 100644 --- a/tensorflow/python/estimator/estimator.py +++ b/tensorflow/python/estimator/estimator.py @@ -41,7 +41,6 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed from tensorflow.python.framework import tensor_util -from tensorflow.python.keras import metrics from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import metrics as metrics_lib @@ -1792,18 +1791,9 @@ def _extract_metric_update_ops(eval_dict, distribution=None): value_ops = {} # Sort metrics lexicographically so graph is identical every time. for name, value in sorted(six.iteritems(eval_dict)): - if isinstance(value, metrics.Metric): - metric_result = value.result() - # We expect only one update op for every metric when there is no - # distribution strategy. - metric_update = value.updates if distribution else value.updates[0] - else: - metric_result = value[0] - metric_update = value[1] - - value_ops[name] = metric_result + value_ops[name] = value[0] update_ops.append( - distribution.group(metric_update) if distribution else metric_update) + distribution.group(value[1]) if distribution else value[1]) update_op = control_flow_ops.group(*update_ops) if update_ops else None return update_op, value_ops diff --git a/tensorflow/python/estimator/model_fn.py b/tensorflow/python/estimator/model_fn.py index 0f26a5bba4..824789467d 100644 --- a/tensorflow/python/estimator/model_fn.py +++ b/tensorflow/python/estimator/model_fn.py @@ -308,6 +308,8 @@ class EstimatorSpec( for key, value in six.iteritems(eval_metric_ops): if isinstance(value, Metric): vars_to_add.update(value.variables) + # Convert Metric instances to (value_tensor, update_op) tuple. + eval_metric_ops[key] = (value.result(), value.updates[0]) # Remove variables that are in the local variables collection already. vars_to_add = vars_to_add.difference(local_vars) for v in vars_to_add: |