diff options
-rw-r--r-- | tensorflow/contrib/metrics/__init__.py | 2 | ||||
-rw-r--r-- | tensorflow/contrib/metrics/python/ops/metric_ops.py | 129 | ||||
-rw-r--r-- | tensorflow/contrib/metrics/python/ops/metric_ops_test.py | 86 |
3 files changed, 217 insertions, 0 deletions
diff --git a/tensorflow/contrib/metrics/__init__.py b/tensorflow/contrib/metrics/__init__.py index 396f85d434..9926d98046 100644 --- a/tensorflow/contrib/metrics/__init__.py +++ b/tensorflow/contrib/metrics/__init__.py @@ -122,6 +122,7 @@ time. @@streaming_sparse_precision_at_k @@streaming_sparse_recall_at_k @@streaming_specificity_at_sensitivity +@@streaming_concat @@auc_using_histogram @@ -151,6 +152,7 @@ from tensorflow.contrib.metrics.python.ops.metric_ops import aggregate_metric_ma from tensorflow.contrib.metrics.python.ops.metric_ops import aggregate_metrics from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_accuracy from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_auc +from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_concat from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_covariance from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_mean from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_mean_absolute_error diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py index 2c56898fb5..6c52f6820b 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py @@ -2745,6 +2745,135 @@ def streaming_mean_iou(predictions, return mean_iou, update_op +def _next_array_size(required_size, growth_factor=1.5): + """Calculate the next size for reallocating a dynamic array. + + Args: + required_size: number or tf.Tensor specifying required array capacity. + growth_factor: optional number or tf.Tensor specifying the growth factor + between subsequent allocations. + + Returns: + tf.Tensor with dtype=int32 giving the next array size. + """ + exponent = math_ops.ceil( + math_ops.log(math_ops.cast(required_size, dtypes.float32)) + / math_ops.log(math_ops.cast(growth_factor, dtypes.float32))) + return math_ops.cast(math_ops.ceil(growth_factor ** exponent), dtypes.int32) + + +def streaming_concat(values, + axis=0, + max_size=None, + metrics_collections=None, + updates_collections=None, + name=None): + """Concatenate values along an axis across batches. + + The function `streaming_concat` creates two local variables, `array` and + `size`, that are used to store concatenated values. Internally, `array` is + used as storage for a dynamic array (if `maxsize` is `None`), which ensures + that updates can be run in amortized constant time. + + For estimation of the metric over a stream of data, the function creates an + `update_op` operation that appends the values of a tensor and returns the + `value` of the concatenated tensors. + + This op allows for evaluating metrics that cannot be updated incrementally + using the same framework as other streaming metrics. + + Args: + values: tensor to concatenate. Rank and the shape along all axes other than + the axis to concatenate along must be statically known. + axis: optional integer axis to concatenate along. + max_size: optional integer maximum size of `value` along the given axis. + Once the maximum size is reached, further updates are no-ops. By default, + there is no maximum size: the array is resized as necessary. + metrics_collections: An optional list of collections that `value` + should be added to. + updates_collections: An optional list of collections `update_op` should be + added to. + name: An optional variable_scope name. + + Returns: + value: A tensor representing the concatenated values. + update_op: An operation that concatenates the next values. + + Raises: + ValueError: if `values` does not have a statically known rank, `axis` is + not in the valid range or the size of `values` is not statically known + along any axis other than `axis`. + """ + with variable_scope.variable_scope(name, 'streaming_concat', [values]): + # pylint: disable=invalid-slice-index + values_shape = values.get_shape() + if values_shape.dims is None: + raise ValueError('`values` must have known statically known rank') + + ndim = len(values_shape) + if axis < 0: + axis += ndim + if not 0 <= axis < ndim: + raise ValueError('axis = %r not in [0, %r)' % (axis, ndim)) + + fixed_shape = [dim.value for n, dim in enumerate(values_shape) + if n != axis] + if any(value is None for value in fixed_shape): + raise ValueError('all dimensions of `values` other than the dimension to ' + 'concatenate along must have statically known size') + + # We move `axis` to the front of the internal array so assign ops can be + # applied to contiguous slices + init_size = 0 if max_size is None else max_size + init_shape = [init_size] + fixed_shape + array = _create_local('array', shape=init_shape, dtype=values.dtype) + size = _create_local('size', shape=[], dtype=dtypes.int32) + + perm = [0 if n == axis else n + 1 if n < axis else n for n in range(ndim)] + valid_array = array[:size] + valid_array.set_shape([None] + fixed_shape) + value = array_ops.transpose(valid_array, perm, name='concat') + + values_size = array_ops.shape(values)[axis] + if max_size is None: + batch_size = values_size + else: + batch_size = math_ops.minimum(values_size, max_size - size) + + perm = [axis] + [n for n in range(ndim) if n != axis] + batch_values = array_ops.transpose(values, perm)[:batch_size] + + def reallocate(): + next_size = _next_array_size(new_size) + next_shape = array_ops.pack([next_size] + fixed_shape) + new_value = array_ops.zeros(next_shape, dtype=values.dtype) + old_value = array.value() + assign_op = state_ops.assign(array, new_value, validate_shape=False) + with ops.control_dependencies([assign_op]): + copy_op = array[:size].assign(old_value[:size]) + # return value needs to be the same dtype as no_op() for cond + with ops.control_dependencies([copy_op]): + return control_flow_ops.no_op() + + new_size = size + batch_size + array_size = array_ops.shape_internal(array, optimize=False)[0] + maybe_reallocate_op = control_flow_ops.cond( + new_size > array_size, reallocate, control_flow_ops.no_op) + with ops.control_dependencies([maybe_reallocate_op]): + append_values_op = array[size:new_size].assign(batch_values) + with ops.control_dependencies([append_values_op]): + update_op = size.assign(new_size) + + if metrics_collections: + ops.add_to_collections(metrics_collections, value) + + if updates_collections: + ops.add_to_collections(updates_collections, update_op) + + return value, update_op + # pylint: enable=invalid-slice-index + + def aggregate_metrics(*value_update_tuples): """Aggregates the metric value tensors and update ops into two lists. diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py index 2f33f4b648..efcd1de4fe 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py @@ -3593,6 +3593,92 @@ class StreamingMeanIOUTest(tf.test.TestCase): self.assertAlmostEqual(desired_miou, miou.eval()) +class StreamingConcatTest(tf.test.TestCase): + + def setUp(self): + tf.reset_default_graph() + + def testMetricsCollection(self): + my_collection_name = '__metrics__' + value, _ = metrics.streaming_concat( + values=tf.ones((10,)), + metrics_collections=[my_collection_name]) + self.assertListEqual(tf.get_collection(my_collection_name), [value]) + + def testUpdatesCollection(self): + my_collection_name = '__updates__' + _, update_op = metrics.streaming_concat( + values=tf.ones((10,)), + updates_collections=[my_collection_name]) + self.assertListEqual(tf.get_collection(my_collection_name), [update_op]) + + def testNextArraySize(self): + next_array_size = metrics.python.ops.metric_ops._next_array_size + with self.test_session(): + self.assertEqual(next_array_size(2, growth_factor=2).eval(), 2) + self.assertEqual(next_array_size(3, growth_factor=2).eval(), 4) + self.assertEqual(next_array_size(4, growth_factor=2).eval(), 4) + self.assertEqual(next_array_size(5, growth_factor=2).eval(), 8) + self.assertEqual(next_array_size(6, growth_factor=2).eval(), 8) + + def testStreamingConcat(self): + with self.test_session() as sess: + values = tf.placeholder(tf.int32, [None]) + concatenated, update_op = metrics.streaming_concat(values) + sess.run(tf.initialize_local_variables()) + + self.assertAllEqual([], concatenated.eval()) + + sess.run([update_op], feed_dict={values: [0, 1, 2]}) + self.assertAllEqual([0, 1, 2], concatenated.eval()) + + sess.run([update_op], feed_dict={values: [3, 4]}) + self.assertAllEqual([0, 1, 2, 3, 4], concatenated.eval()) + + sess.run([update_op], feed_dict={values: [5, 6, 7, 8, 9]}) + self.assertAllEqual(np.arange(10), concatenated.eval()) + + def testStreamingConcatMaxSize(self): + with self.test_session() as sess: + values = tf.range(3) + concatenated, update_op = metrics.streaming_concat(values, max_size=5) + sess.run(tf.initialize_local_variables()) + + self.assertAllEqual([], concatenated.eval()) + + sess.run([update_op]) + self.assertAllEqual([0, 1, 2], concatenated.eval()) + + sess.run([update_op]) + self.assertAllEqual([0, 1, 2, 0, 1], concatenated.eval()) + + sess.run([update_op]) + self.assertAllEqual([0, 1, 2, 0, 1], concatenated.eval()) + + def testStreamingConcat2D(self): + with self.test_session() as sess: + values = tf.reshape(tf.range(3), (3, 1)) + concatenated, update_op = metrics.streaming_concat(values, axis=-1) + sess.run(tf.initialize_local_variables()) + for _ in range(10): + sess.run([update_op]) + self.assertAllEqual([[0] * 10, [1] * 10, [2] * 10], + concatenated.eval()) + + def testStreamingConcatErrors(self): + with self.assertRaises(ValueError): + metrics.streaming_concat(tf.placeholder(tf.float32)) + + values = tf.zeros((2, 3)) + with self.assertRaises(ValueError): + metrics.streaming_concat(values, axis=-3, max_size=3) + with self.assertRaises(ValueError): + metrics.streaming_concat(values, axis=2, max_size=3) + + with self.assertRaises(ValueError): + metrics.streaming_concat(tf.placeholder(tf.float32, [None, None])) + + class AggregateMetricsTest(tf.test.TestCase): def testAggregateNoMetricsRaisesValueError(self): |