aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/metrics/python/ops/metric_ops.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/metrics/python/ops/metric_ops.py')
-rw-r--r--tensorflow/contrib/metrics/python/ops/metric_ops.py18
1 files changed, 12 insertions, 6 deletions
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py
index b14202ff9e..a328670526 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py
@@ -3715,6 +3715,7 @@ def count(values,
name=None):
"""Computes the number of examples, or sum of `weights`.
+ This metric keeps track of the denominator in `tf.metrics.mean`.
When evaluating some metric (e.g. mean) on one or more subsets of the data,
this auxiliary metric is useful for keeping track of how many examples there
are in each subset.
@@ -3741,15 +3742,21 @@ def count(values,
ValueError: If `weights` is not `None` and its shape doesn't match `values`,
or if either `metrics_collections` or `updates_collections` are not a list
or tuple.
+ RuntimeError: If eager execution is enabled.
"""
+ if context.executing_eagerly():
+ raise RuntimeError('tf.contrib.metrics.count is not supported when eager '
+ 'execution is enabled.')
with variable_scope.variable_scope(name, 'count', (values, weights)):
+
count_ = metrics_impl.metric_variable([], dtypes.float32, name='count')
if weights is None:
num_values = math_ops.to_float(array_ops.size(values))
else:
- _, _, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access
+ values = math_ops.to_float(values)
+ values, _, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access
predictions=values,
labels=None,
weights=weights)
@@ -3758,15 +3765,14 @@ def count(values,
num_values = math_ops.reduce_sum(weights)
with ops.control_dependencies([values]):
- update_op = state_ops.assign_add(count_, num_values)
+ update_count_op = state_ops.assign_add(count_, num_values)
- if metrics_collections:
- ops.add_to_collections(metrics_collections, count_)
+ count_ = metrics_impl._aggregate_variable(count_, metrics_collections) # pylint: disable=protected-access
if updates_collections:
- ops.add_to_collections(updates_collections, update_op)
+ ops.add_to_collections(updates_collections, update_count_op)
- return count_, update_op
+ return count_, update_count_op
def cohen_kappa(labels,