about summary refs log tree commit diff homepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-10-13 16:01:36 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-10-13 17:17:21 -0700
commitc2cfc048b8080b8f3e7990d07fa99cf4deb7935b (patch)
treeaad1c7ea57be5b3384d183f784efcc8793aca1e0
parent8bc3020c0c0113d3349cb5cdbfc6146ae8c28ce9 (diff)
Allows metrics.streaming_concat() to be used inside evaluation.evaluation_loop().
Change: 136104979
-rw-r--r--  tensorflow/contrib/metrics/python/ops/metric_ops.py | 10
-rw-r--r--  tensorflow/contrib/metrics/python/ops/metric_ops_test.py | 16
2 files changed, 23 insertions(+), 3 deletions(-)
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py
index a0b15c6bd2..c8e106f6c1 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py
@@ -114,13 +114,15 @@ def _safe_scalar_div(numerator, denominator, name):
name=name)
-def _create_local(name, shape=None, collections=None, dtype=dtypes.float32):
+def _create_local(name, shape, collections=None, validate_shape=True,
+ dtype=dtypes.float32):
"""Creates a new local variable.
Args:
name: The name of the new or existing variable.
shape: Shape of the new or existing variable.
collections: A list of collection names to which the Variable will be added.
+ validate_shape: Whether to validate the shape of the variable.
dtype: Data type of the variables.
Returns:
@@ -133,7 +135,8 @@ def _create_local(name, shape=None, collections=None, dtype=dtypes.float32):
initial_value=array_ops.zeros(shape, dtype=dtype),
name=name,
trainable=False,
- collections=collections)
+ collections=collections,
+ validate_shape=validate_shape)
def _count_condition(values, weights=None, metrics_collections=None,
@@ -2831,7 +2834,8 @@ def streaming_concat(values,
# applied to contiguous slices
init_size = 0 if max_size is None else max_size
init_shape = [init_size] + fixed_shape
- array = _create_local('array', shape=init_shape, dtype=values.dtype)
+ array = _create_local(
+ 'array', shape=init_shape, validate_shape=False, dtype=values.dtype)
size = _create_local('size', shape=[], dtype=dtypes.int32)
perm = [0 if n == axis else n + 1 if n < axis else n for n in range(ndim)]
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
index bde3a9aa56..40a6879456 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
@@ -3714,6 +3714,22 @@ class StreamingConcatTest(tf.test.TestCase):
with self.assertRaises(ValueError):
metrics.streaming_concat(tf.placeholder(tf.float32, [None, None]))
+ def testStreamingConcatReset(self):
+ with self.test_session() as sess:
+ values = tf.placeholder(tf.int32, [None])
+ concatenated, update_op = metrics.streaming_concat(values)
+ sess.run(tf.initialize_local_variables())
+
+ self.assertAllEqual([], concatenated.eval())
+
+ sess.run([update_op], feed_dict={values: [0, 1, 2]})
+ self.assertAllEqual([0, 1, 2], concatenated.eval())
+
+ sess.run(tf.initialize_local_variables())
+
+ sess.run([update_op], feed_dict={values: [3, 4]})
+ self.assertAllEqual([3, 4], concatenated.eval())
+
class AggregateMetricsTest(tf.test.TestCase):