aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/metrics
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-07-19 14:34:20 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-19 14:38:06 -0700
commit6e02d79ba0179a23679e65b31405c591726bc552 (patch)
tree0eccc5a682485ec25a3f12cd1d9a0158f17d2472 /tensorflow/contrib/metrics
parent334a8ad32253a1194991d244ca821ceabc69dd71 (diff)
Make count metric consistent with other metrics by converting variable to tensor (_aggregate_variable() returns Tensor).
PiperOrigin-RevId: 205303531
Diffstat (limited to 'tensorflow/contrib/metrics')
-rw-r--r--tensorflow/contrib/metrics/python/ops/metric_ops.py18
-rw-r--r--tensorflow/contrib/metrics/python/ops/metric_ops_test.py5
2 files changed, 17 insertions, 6 deletions
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py
index b14202ff9e..a328670526 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py
@@ -3715,6 +3715,7 @@ def count(values,
name=None):
"""Computes the number of examples, or sum of `weights`.
+ This metric keeps track of the denominator in `tf.metrics.mean`.
When evaluating some metric (e.g. mean) on one or more subsets of the data,
this auxiliary metric is useful for keeping track of how many examples there
are in each subset.
@@ -3741,15 +3742,21 @@ def count(values,
ValueError: If `weights` is not `None` and its shape doesn't match `values`,
or if either `metrics_collections` or `updates_collections` are not a list
or tuple.
+ RuntimeError: If eager execution is enabled.
"""
+ if context.executing_eagerly():
+ raise RuntimeError('tf.contrib.metrics.count is not supported when eager '
+ 'execution is enabled.')
with variable_scope.variable_scope(name, 'count', (values, weights)):
+
count_ = metrics_impl.metric_variable([], dtypes.float32, name='count')
if weights is None:
num_values = math_ops.to_float(array_ops.size(values))
else:
- _, _, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access
+ values = math_ops.to_float(values)
+ values, _, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access
predictions=values,
labels=None,
weights=weights)
@@ -3758,15 +3765,14 @@ def count(values,
num_values = math_ops.reduce_sum(weights)
with ops.control_dependencies([values]):
- update_op = state_ops.assign_add(count_, num_values)
+ update_count_op = state_ops.assign_add(count_, num_values)
- if metrics_collections:
- ops.add_to_collections(metrics_collections, count_)
+ count_ = metrics_impl._aggregate_variable(count_, metrics_collections) # pylint: disable=protected-access
if updates_collections:
- ops.add_to_collections(updates_collections, update_op)
+ ops.add_to_collections(updates_collections, update_count_op)
- return count_, update_op
+ return count_, update_count_op
def cohen_kappa(labels,
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
index a09fc4abd4..401fedcbed 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
@@ -6854,6 +6854,11 @@ class CountTest(test.TestCase):
array_ops.ones([4, 3]), updates_collections=[my_collection_name])
self.assertListEqual(ops.get_collection(my_collection_name), [update_op])
+ def testReturnType(self):
+ c, op = metrics.count(array_ops.ones([4, 3]))
+ self.assertTrue(isinstance(c, ops.Tensor))
+ self.assertTrue(isinstance(op, ops.Operation) or isinstance(op, ops.Tensor))
+
def testBasic(self):
with self.test_session() as sess:
values_queue = data_flow_ops.FIFOQueue(