diff options
author | Mustafa Ispir <ispir@google.com> | 2017-10-31 15:20:08 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-10-31 15:23:52 -0700 |
commit | b242a7988ccd3f8f55c7ec494d2d4f76175fb6d8 (patch) | |
tree | 38800d5d4ada2d7821928efd4666868acce69db5 /tensorflow/contrib/metrics | |
parent | 453dd5848f5652f520eb0faf17a732f20779cdb1 (diff) |
Set metric variable initializers as lambda.
PiperOrigin-RevId: 174100686
Diffstat (limited to 'tensorflow/contrib/metrics')
-rw-r--r-- | tensorflow/contrib/metrics/python/ops/metric_ops.py | 39 |
1 files changed, 14 insertions, 25 deletions
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py index c524da4309..c328b03707 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py @@ -92,8 +92,7 @@ def _count_condition(values, or tuple. """ check_ops.assert_type(values, dtypes.bool) - count_ = metrics_impl.metric_variable( - array_ops.zeros([], dtype=dtypes.float32), name='count') + count_ = metrics_impl.metric_variable([], dtypes.float32, name='count') values = math_ops.to_float(values) if weights is not None: @@ -916,8 +915,7 @@ def _streaming_confusion_matrix_at_thresholds(predictions, if 'tp' in includes: true_positives = metrics_impl.metric_variable( - array_ops.zeros([num_thresholds], dtype=dtypes.float32), - name='true_positives') + [num_thresholds], dtypes.float32, name='true_positives') is_true_positive = math_ops.to_float( math_ops.logical_and(label_is_pos, pred_is_pos)) if weights_tiled is not None: @@ -929,8 +927,7 @@ def _streaming_confusion_matrix_at_thresholds(predictions, if 'fn' in includes: false_negatives = metrics_impl.metric_variable( - array_ops.zeros([num_thresholds], dtype=dtypes.float32), - name='false_negatives') + [num_thresholds], dtypes.float32, name='false_negatives') is_false_negative = math_ops.to_float( math_ops.logical_and(label_is_pos, pred_is_neg)) if weights_tiled is not None: @@ -942,8 +939,7 @@ def _streaming_confusion_matrix_at_thresholds(predictions, if 'tn' in includes: true_negatives = metrics_impl.metric_variable( - array_ops.zeros([num_thresholds], dtype=dtypes.float32), - name='true_negatives') + [num_thresholds], dtypes.float32, name='true_negatives') is_true_negative = math_ops.to_float( math_ops.logical_and(label_is_neg, pred_is_neg)) if weights_tiled is not None: @@ -955,8 +951,7 @@ def _streaming_confusion_matrix_at_thresholds(predictions, if 'fp' in includes: false_positives = metrics_impl.metric_variable( - array_ops.zeros([num_thresholds], dtype=dtypes.float32), - name='false_positives') + [num_thresholds], dtypes.float32, name='false_positives') is_false_positive = math_ops.to_float( math_ops.logical_and(label_is_neg, pred_is_pos)) if weights_tiled is not None: @@ -1317,9 +1312,9 @@ def streaming_precision_recall_at_equal_thresholds(predictions, with ops.name_scope('variables'): tp_buckets_v = metrics_impl.metric_variable( - array_ops.zeros([num_thresholds], dtype=dtype), name='tp_buckets') + [num_thresholds], dtype, name='tp_buckets') fp_buckets_v = metrics_impl.metric_variable( - array_ops.zeros([num_thresholds], dtype=dtype), name='fp_buckets') + [num_thresholds], dtype, name='fp_buckets') with ops.name_scope('update_op'): update_tp = state_ops.scatter_add( @@ -2582,15 +2577,13 @@ def streaming_covariance(predictions, predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access predictions, labels, weights) predictions.get_shape().assert_is_compatible_with(labels.get_shape()) - count_ = metrics_impl.metric_variable( - array_ops.zeros([], dtype=dtypes.float32), name='count') + count_ = metrics_impl.metric_variable([], dtypes.float32, name='count') mean_prediction = metrics_impl.metric_variable( - array_ops.zeros([], dtype=dtypes.float32), name='mean_prediction') + [], dtypes.float32, name='mean_prediction') mean_label = metrics_impl.metric_variable( - array_ops.zeros([], dtype=dtypes.float32), name='mean_label') + [], dtypes.float32, name='mean_label') comoment = metrics_impl.metric_variable( # C_A in update equation - array_ops.zeros([], dtype=dtypes.float32), - name='comoment') + [], dtypes.float32, name='comoment') if weights is None: batch_count = math_ops.to_float(array_ops.size(labels)) # n_B in eqn @@ -3011,11 +3004,8 @@ def streaming_concat(values, init_size = 0 if max_size is None else max_size init_shape = [init_size] + fixed_shape array = metrics_impl.metric_variable( - array_ops.zeros(init_shape, dtype=values.dtype), - validate_shape=False, - name='array') - size = metrics_impl.metric_variable( - array_ops.zeros([], dtype=dtypes.int32), name='size') + init_shape, values.dtype, validate_shape=False, name='array') + size = metrics_impl.metric_variable([], dtypes.int32, name='size') perm = [0 if n == axis else n + 1 if n < axis else n for n in range(ndim)] valid_array = array[:size] @@ -3149,8 +3139,7 @@ def count(values, """ with variable_scope.variable_scope(name, 'count', (values, weights)): - count_ = metrics_impl.metric_variable( - array_ops.zeros([], dtype=dtypes.float32), name='count') + count_ = metrics_impl.metric_variable([], dtypes.float32, name='count') if weights is None: num_values = math_ops.to_float(array_ops.size(values)) |