aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/metrics
diff options
context:
space:
mode:
authorGravatar Mustafa Ispir <ispir@google.com>2017-10-31 15:20:08 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-31 15:23:52 -0700
commitb242a7988ccd3f8f55c7ec494d2d4f76175fb6d8 (patch)
tree38800d5d4ada2d7821928efd4666868acce69db5 /tensorflow/contrib/metrics
parent453dd5848f5652f520eb0faf17a732f20779cdb1 (diff)
Set metric variable initializers as lambda.
PiperOrigin-RevId: 174100686
Diffstat (limited to 'tensorflow/contrib/metrics')
-rw-r--r--tensorflow/contrib/metrics/python/ops/metric_ops.py39
1 files changed, 14 insertions, 25 deletions
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py
index c524da4309..c328b03707 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py
@@ -92,8 +92,7 @@ def _count_condition(values,
or tuple.
"""
check_ops.assert_type(values, dtypes.bool)
- count_ = metrics_impl.metric_variable(
- array_ops.zeros([], dtype=dtypes.float32), name='count')
+ count_ = metrics_impl.metric_variable([], dtypes.float32, name='count')
values = math_ops.to_float(values)
if weights is not None:
@@ -916,8 +915,7 @@ def _streaming_confusion_matrix_at_thresholds(predictions,
if 'tp' in includes:
true_positives = metrics_impl.metric_variable(
- array_ops.zeros([num_thresholds], dtype=dtypes.float32),
- name='true_positives')
+ [num_thresholds], dtypes.float32, name='true_positives')
is_true_positive = math_ops.to_float(
math_ops.logical_and(label_is_pos, pred_is_pos))
if weights_tiled is not None:
@@ -929,8 +927,7 @@ def _streaming_confusion_matrix_at_thresholds(predictions,
if 'fn' in includes:
false_negatives = metrics_impl.metric_variable(
- array_ops.zeros([num_thresholds], dtype=dtypes.float32),
- name='false_negatives')
+ [num_thresholds], dtypes.float32, name='false_negatives')
is_false_negative = math_ops.to_float(
math_ops.logical_and(label_is_pos, pred_is_neg))
if weights_tiled is not None:
@@ -942,8 +939,7 @@ def _streaming_confusion_matrix_at_thresholds(predictions,
if 'tn' in includes:
true_negatives = metrics_impl.metric_variable(
- array_ops.zeros([num_thresholds], dtype=dtypes.float32),
- name='true_negatives')
+ [num_thresholds], dtypes.float32, name='true_negatives')
is_true_negative = math_ops.to_float(
math_ops.logical_and(label_is_neg, pred_is_neg))
if weights_tiled is not None:
@@ -955,8 +951,7 @@ def _streaming_confusion_matrix_at_thresholds(predictions,
if 'fp' in includes:
false_positives = metrics_impl.metric_variable(
- array_ops.zeros([num_thresholds], dtype=dtypes.float32),
- name='false_positives')
+ [num_thresholds], dtypes.float32, name='false_positives')
is_false_positive = math_ops.to_float(
math_ops.logical_and(label_is_neg, pred_is_pos))
if weights_tiled is not None:
@@ -1317,9 +1312,9 @@ def streaming_precision_recall_at_equal_thresholds(predictions,
with ops.name_scope('variables'):
tp_buckets_v = metrics_impl.metric_variable(
- array_ops.zeros([num_thresholds], dtype=dtype), name='tp_buckets')
+ [num_thresholds], dtype, name='tp_buckets')
fp_buckets_v = metrics_impl.metric_variable(
- array_ops.zeros([num_thresholds], dtype=dtype), name='fp_buckets')
+ [num_thresholds], dtype, name='fp_buckets')
with ops.name_scope('update_op'):
update_tp = state_ops.scatter_add(
@@ -2582,15 +2577,13 @@ def streaming_covariance(predictions,
predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access
predictions, labels, weights)
predictions.get_shape().assert_is_compatible_with(labels.get_shape())
- count_ = metrics_impl.metric_variable(
- array_ops.zeros([], dtype=dtypes.float32), name='count')
+ count_ = metrics_impl.metric_variable([], dtypes.float32, name='count')
mean_prediction = metrics_impl.metric_variable(
- array_ops.zeros([], dtype=dtypes.float32), name='mean_prediction')
+ [], dtypes.float32, name='mean_prediction')
mean_label = metrics_impl.metric_variable(
- array_ops.zeros([], dtype=dtypes.float32), name='mean_label')
+ [], dtypes.float32, name='mean_label')
comoment = metrics_impl.metric_variable( # C_A in update equation
- array_ops.zeros([], dtype=dtypes.float32),
- name='comoment')
+ [], dtypes.float32, name='comoment')
if weights is None:
batch_count = math_ops.to_float(array_ops.size(labels)) # n_B in eqn
@@ -3011,11 +3004,8 @@ def streaming_concat(values,
init_size = 0 if max_size is None else max_size
init_shape = [init_size] + fixed_shape
array = metrics_impl.metric_variable(
- array_ops.zeros(init_shape, dtype=values.dtype),
- validate_shape=False,
- name='array')
- size = metrics_impl.metric_variable(
- array_ops.zeros([], dtype=dtypes.int32), name='size')
+ init_shape, values.dtype, validate_shape=False, name='array')
+ size = metrics_impl.metric_variable([], dtypes.int32, name='size')
perm = [0 if n == axis else n + 1 if n < axis else n for n in range(ndim)]
valid_array = array[:size]
@@ -3149,8 +3139,7 @@ def count(values,
"""
with variable_scope.variable_scope(name, 'count', (values, weights)):
- count_ = metrics_impl.metric_variable(
- array_ops.zeros([], dtype=dtypes.float32), name='count')
+ count_ = metrics_impl.metric_variable([], dtypes.float32, name='count')
if weights is None:
num_values = math_ops.to_float(array_ops.size(values))