author    Stephan Hoyer <shoyer@google.com>              2016-09-26 14:09:42 -0800
committer TensorFlower Gardener <gardener@tensorflow.org> 2016-09-26 15:17:44 -0700
commit    e7b2114e0732a11cf610b388f92a46fa8f9da3c4 (patch)
tree      cbeef510a7659e3235895997b0c59ea02190e16a
parent    9751084da88a30b03eda84e17d88379fa16289ff (diff)
Add tf.contrib.metrics.streaming_concat
This op is useful on small datasets for evaluating metrics that otherwise cannot be updated incrementally.

Change: 134332647
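A minimal usage sketch, mirroring the tests added below (TF 0.x-era session API; the placeholder shape and the batch values are made up for illustration):

    import tensorflow as tf

    # Accumulate every batch so a metric that cannot be updated incrementally
    # (e.g. an exact median) can be computed once over all concatenated values.
    values = tf.placeholder(tf.float32, [None])
    concatenated, update_op = tf.contrib.metrics.streaming_concat(values)

    with tf.Session() as sess:
      sess.run(tf.initialize_local_variables())
      for batch in ([0.5, 1.5], [2.5, 3.5]):
        sess.run(update_op, feed_dict={values: batch})
      print(sess.run(concatenated))  # -> [0.5 1.5 2.5 3.5]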
-rw-r--r--  tensorflow/contrib/metrics/__init__.py                   |   2
-rw-r--r--  tensorflow/contrib/metrics/python/ops/metric_ops.py      | 129
-rw-r--r--  tensorflow/contrib/metrics/python/ops/metric_ops_test.py |  86
3 files changed, 217 insertions(+), 0 deletions(-)
diff --git a/tensorflow/contrib/metrics/__init__.py b/tensorflow/contrib/metrics/__init__.py
index 396f85d434..9926d98046 100644
--- a/tensorflow/contrib/metrics/__init__.py
+++ b/tensorflow/contrib/metrics/__init__.py
@@ -122,6 +122,7 @@ time.
@@streaming_sparse_precision_at_k
@@streaming_sparse_recall_at_k
@@streaming_specificity_at_sensitivity
+@@streaming_concat
@@auc_using_histogram
@@ -151,6 +152,7 @@ from tensorflow.contrib.metrics.python.ops.metric_ops import aggregate_metric_map
from tensorflow.contrib.metrics.python.ops.metric_ops import aggregate_metrics
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_accuracy
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_auc
+from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_concat
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_covariance
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_mean
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_mean_absolute_error
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py
index 2c56898fb5..6c52f6820b 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py
@@ -2745,6 +2745,135 @@ def streaming_mean_iou(predictions,
return mean_iou, update_op
+def _next_array_size(required_size, growth_factor=1.5):
+ """Calculate the next size for reallocating a dynamic array.
+
+ Args:
+ required_size: number or tf.Tensor specifying required array capacity.
+ growth_factor: optional number or tf.Tensor specifying the growth factor
+ between subsequent allocations.
+
+ Returns:
+ tf.Tensor with dtype=int32 giving the next array size.
+ """
+ exponent = math_ops.ceil(
+ math_ops.log(math_ops.cast(required_size, dtypes.float32))
+ / math_ops.log(math_ops.cast(growth_factor, dtypes.float32)))
+ return math_ops.cast(math_ops.ceil(growth_factor ** exponent), dtypes.int32)
+
+
+def streaming_concat(values,
+ axis=0,
+ max_size=None,
+ metrics_collections=None,
+ updates_collections=None,
+ name=None):
+ """Concatenate values along an axis across batches.
+
+ The function `streaming_concat` creates two local variables, `array` and
+ `size`, that are used to store concatenated values. Internally, `array` is
+  used as storage for a dynamic array (if `max_size` is `None`), which ensures
+ that updates can be run in amortized constant time.
+
+ For estimation of the metric over a stream of data, the function creates an
+ `update_op` operation that appends the values of a tensor and returns the
+ `value` of the concatenated tensors.
+
+ This op allows for evaluating metrics that cannot be updated incrementally
+ using the same framework as other streaming metrics.
+
+ Args:
+ values: tensor to concatenate. Rank and the shape along all axes other than
+ the axis to concatenate along must be statically known.
+ axis: optional integer axis to concatenate along.
+ max_size: optional integer maximum size of `value` along the given axis.
+ Once the maximum size is reached, further updates are no-ops. By default,
+ there is no maximum size: the array is resized as necessary.
+ metrics_collections: An optional list of collections that `value`
+ should be added to.
+ updates_collections: An optional list of collections `update_op` should be
+ added to.
+ name: An optional variable_scope name.
+
+ Returns:
+ value: A tensor representing the concatenated values.
+ update_op: An operation that concatenates the next values.
+
+ Raises:
+ ValueError: if `values` does not have a statically known rank, `axis` is
+ not in the valid range or the size of `values` is not statically known
+ along any axis other than `axis`.
+ """
+ with variable_scope.variable_scope(name, 'streaming_concat', [values]):
+ # pylint: disable=invalid-slice-index
+ values_shape = values.get_shape()
+ if values_shape.dims is None:
+      raise ValueError('`values` must have a statically known rank')
+
+ ndim = len(values_shape)
+ if axis < 0:
+ axis += ndim
+ if not 0 <= axis < ndim:
+ raise ValueError('axis = %r not in [0, %r)' % (axis, ndim))
+
+ fixed_shape = [dim.value for n, dim in enumerate(values_shape)
+ if n != axis]
+ if any(value is None for value in fixed_shape):
+ raise ValueError('all dimensions of `values` other than the dimension to '
+ 'concatenate along must have statically known size')
+
+ # We move `axis` to the front of the internal array so assign ops can be
+ # applied to contiguous slices
+ init_size = 0 if max_size is None else max_size
+ init_shape = [init_size] + fixed_shape
+ array = _create_local('array', shape=init_shape, dtype=values.dtype)
+ size = _create_local('size', shape=[], dtype=dtypes.int32)
+
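+    # For example, with ndim=3 and axis=1 this perm is [1, 0, 2]: the inverse
+    # of the forward perm [axis] + [all other axes] used below, so `value` is
+    # returned in the caller's original axis order.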
+ perm = [0 if n == axis else n + 1 if n < axis else n for n in range(ndim)]
+ valid_array = array[:size]
+ valid_array.set_shape([None] + fixed_shape)
+ value = array_ops.transpose(valid_array, perm, name='concat')
+
+ values_size = array_ops.shape(values)[axis]
+ if max_size is None:
+ batch_size = values_size
+ else:
+ batch_size = math_ops.minimum(values_size, max_size - size)
+
+ perm = [axis] + [n for n in range(ndim) if n != axis]
+ batch_values = array_ops.transpose(values, perm)[:batch_size]
+
+ def reallocate():
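+      # Note: `new_size` is a closure over the tensor defined below; cond()
+      # calls this function while building the graph, after `new_size` exists.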
+ next_size = _next_array_size(new_size)
+ next_shape = array_ops.pack([next_size] + fixed_shape)
+ new_value = array_ops.zeros(next_shape, dtype=values.dtype)
+ old_value = array.value()
+ assign_op = state_ops.assign(array, new_value, validate_shape=False)
+ with ops.control_dependencies([assign_op]):
+ copy_op = array[:size].assign(old_value[:size])
+ # return value needs to be the same dtype as no_op() for cond
+ with ops.control_dependencies([copy_op]):
+ return control_flow_ops.no_op()
+
+ new_size = size + batch_size
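+    # optimize=False keeps the runtime shape: the variable is resized via the
+    # validate_shape=False assign in reallocate(), so its static shape is stale.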
+ array_size = array_ops.shape_internal(array, optimize=False)[0]
+ maybe_reallocate_op = control_flow_ops.cond(
+ new_size > array_size, reallocate, control_flow_ops.no_op)
+ with ops.control_dependencies([maybe_reallocate_op]):
+ append_values_op = array[size:new_size].assign(batch_values)
+ with ops.control_dependencies([append_values_op]):
+ update_op = size.assign(new_size)
+
+ if metrics_collections:
+ ops.add_to_collections(metrics_collections, value)
+
+ if updates_collections:
+ ops.add_to_collections(updates_collections, update_op)
+
+ return value, update_op
+ # pylint: enable=invalid-slice-index
+
+
def aggregate_metrics(*value_update_tuples):
"""Aggregates the metric value tensors and update ops into two lists.
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
index 2f33f4b648..efcd1de4fe 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
@@ -3593,6 +3593,92 @@ class StreamingMeanIOUTest(tf.test.TestCase):
self.assertAlmostEqual(desired_miou, miou.eval())
+class StreamingConcatTest(tf.test.TestCase):
+
+ def setUp(self):
+ tf.reset_default_graph()
+
+ def testMetricsCollection(self):
+ my_collection_name = '__metrics__'
+ value, _ = metrics.streaming_concat(
+ values=tf.ones((10,)),
+ metrics_collections=[my_collection_name])
+ self.assertListEqual(tf.get_collection(my_collection_name), [value])
+
+ def testUpdatesCollection(self):
+ my_collection_name = '__updates__'
+ _, update_op = metrics.streaming_concat(
+ values=tf.ones((10,)),
+ updates_collections=[my_collection_name])
+ self.assertListEqual(tf.get_collection(my_collection_name), [update_op])
+
+ def testNextArraySize(self):
+ next_array_size = metrics.python.ops.metric_ops._next_array_size
+ with self.test_session():
+ self.assertEqual(next_array_size(2, growth_factor=2).eval(), 2)
+ self.assertEqual(next_array_size(3, growth_factor=2).eval(), 4)
+ self.assertEqual(next_array_size(4, growth_factor=2).eval(), 4)
+ self.assertEqual(next_array_size(5, growth_factor=2).eval(), 8)
+ self.assertEqual(next_array_size(6, growth_factor=2).eval(), 8)
+
+ def testStreamingConcat(self):
+ with self.test_session() as sess:
+ values = tf.placeholder(tf.int32, [None])
+ concatenated, update_op = metrics.streaming_concat(values)
+ sess.run(tf.initialize_local_variables())
+
+ self.assertAllEqual([], concatenated.eval())
+
+ sess.run([update_op], feed_dict={values: [0, 1, 2]})
+ self.assertAllEqual([0, 1, 2], concatenated.eval())
+
+ sess.run([update_op], feed_dict={values: [3, 4]})
+ self.assertAllEqual([0, 1, 2, 3, 4], concatenated.eval())
+
+ sess.run([update_op], feed_dict={values: [5, 6, 7, 8, 9]})
+ self.assertAllEqual(np.arange(10), concatenated.eval())
+
+ def testStreamingConcatMaxSize(self):
+ with self.test_session() as sess:
+ values = tf.range(3)
+ concatenated, update_op = metrics.streaming_concat(values, max_size=5)
+ sess.run(tf.initialize_local_variables())
+
+ self.assertAllEqual([], concatenated.eval())
+
+ sess.run([update_op])
+ self.assertAllEqual([0, 1, 2], concatenated.eval())
+
+ sess.run([update_op])
+ self.assertAllEqual([0, 1, 2, 0, 1], concatenated.eval())
+
+ sess.run([update_op])
+ self.assertAllEqual([0, 1, 2, 0, 1], concatenated.eval())
+
+ def testStreamingConcat2D(self):
+ with self.test_session() as sess:
+ values = tf.reshape(tf.range(3), (3, 1))
+ concatenated, update_op = metrics.streaming_concat(values, axis=-1)
+ sess.run(tf.initialize_local_variables())
+ for _ in range(10):
+ sess.run([update_op])
+ self.assertAllEqual([[0] * 10, [1] * 10, [2] * 10],
+ concatenated.eval())
+
+ def testStreamingConcatErrors(self):
+ with self.assertRaises(ValueError):
+ metrics.streaming_concat(tf.placeholder(tf.float32))
+
+ values = tf.zeros((2, 3))
+ with self.assertRaises(ValueError):
+ metrics.streaming_concat(values, axis=-3, max_size=3)
+ with self.assertRaises(ValueError):
+ metrics.streaming_concat(values, axis=2, max_size=3)
+
+ with self.assertRaises(ValueError):
+ metrics.streaming_concat(tf.placeholder(tf.float32, [None, None]))
+
+
class AggregateMetricsTest(tf.test.TestCase):
def testAggregateNoMetricsRaisesValueError(self):