diff options
-rw-r--r-- | tensorflow/contrib/metrics/python/ops/metric_ops.py | 10 | ||||
-rw-r--r-- | tensorflow/contrib/metrics/python/ops/metric_ops_test.py | 16 |
2 files changed, 23 insertions, 3 deletions
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py index a0b15c6bd2..c8e106f6c1 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py @@ -114,13 +114,15 @@ def _safe_scalar_div(numerator, denominator, name): name=name) -def _create_local(name, shape=None, collections=None, dtype=dtypes.float32): +def _create_local(name, shape, collections=None, validate_shape=True, + dtype=dtypes.float32): """Creates a new local variable. Args: name: The name of the new or existing variable. shape: Shape of the new or existing variable. collections: A list of collection names to which the Variable will be added. + validate_shape: Whether to validate the shape of the variable. dtype: Data type of the variables. Returns: @@ -133,7 +135,8 @@ def _create_local(name, shape=None, collections=None, dtype=dtypes.float32): initial_value=array_ops.zeros(shape, dtype=dtype), name=name, trainable=False, - collections=collections) + collections=collections, + validate_shape=validate_shape) def _count_condition(values, weights=None, metrics_collections=None, @@ -2831,7 +2834,8 @@ def streaming_concat(values, # applied to contiguous slices init_size = 0 if max_size is None else max_size init_shape = [init_size] + fixed_shape - array = _create_local('array', shape=init_shape, dtype=values.dtype) + array = _create_local( + 'array', shape=init_shape, validate_shape=False, dtype=values.dtype) size = _create_local('size', shape=[], dtype=dtypes.int32) perm = [0 if n == axis else n + 1 if n < axis else n for n in range(ndim)] diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py index bde3a9aa56..40a6879456 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py @@ -3714,6 +3714,22 @@ class StreamingConcatTest(tf.test.TestCase): with self.assertRaises(ValueError): metrics.streaming_concat(tf.placeholder(tf.float32, [None, None])) + def testStreamingConcatReset(self): + with self.test_session() as sess: + values = tf.placeholder(tf.int32, [None]) + concatenated, update_op = metrics.streaming_concat(values) + sess.run(tf.initialize_local_variables()) + + self.assertAllEqual([], concatenated.eval()) + + sess.run([update_op], feed_dict={values: [0, 1, 2]}) + self.assertAllEqual([0, 1, 2], concatenated.eval()) + + sess.run(tf.initialize_local_variables()) + + sess.run([update_op], feed_dict={values: [3, 4]}) + self.assertAllEqual([3, 4], concatenated.eval()) + class AggregateMetricsTest(tf.test.TestCase): |