aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/metrics/python/ops/metric_ops.py
diff options
context:
space:
mode:
authorGravatar Stephan Hoyer <shoyer@google.com>2016-09-26 14:09:42 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-09-26 15:17:44 -0700
commite7b2114e0732a11cf610b388f92a46fa8f9da3c4 (patch)
treecbeef510a7659e3235895997b0c59ea02190e16a /tensorflow/contrib/metrics/python/ops/metric_ops.py
parent9751084da88a30b03eda84e17d88379fa16289ff (diff)
Add tf.contrib.metrics.streaming_concat
This op is useful, on small datasets, for evaluating metrics that otherwise cannot be updated incrementally. Change: 134332647
Diffstat (limited to 'tensorflow/contrib/metrics/python/ops/metric_ops.py')
-rw-r--r--tensorflow/contrib/metrics/python/ops/metric_ops.py129
1 files changed, 129 insertions, 0 deletions
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py
index 2c56898fb5..6c52f6820b 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py
@@ -2745,6 +2745,135 @@ def streaming_mean_iou(predictions,
return mean_iou, update_op
+def _next_array_size(required_size, growth_factor=1.5):
+ """Calculate the next size for reallocating a dynamic array.
+
+ Args:
+ required_size: number or tf.Tensor specifying required array capacity.
+ growth_factor: optional number or tf.Tensor specifying the growth factor
+ between subsequent allocations.
+
+ Returns:
+ tf.Tensor with dtype=int32 giving the next array size.
+ """
+ exponent = math_ops.ceil(
+ math_ops.log(math_ops.cast(required_size, dtypes.float32))
+ / math_ops.log(math_ops.cast(growth_factor, dtypes.float32)))
+ return math_ops.cast(math_ops.ceil(growth_factor ** exponent), dtypes.int32)
+
+
+def streaming_concat(values,
+ axis=0,
+ max_size=None,
+ metrics_collections=None,
+ updates_collections=None,
+ name=None):
+ """Concatenate values along an axis across batches.
+
+ The function `streaming_concat` creates two local variables, `array` and
+ `size`, that are used to store concatenated values. Internally, `array` is
+ used as storage for a dynamic array (if `maxsize` is `None`), which ensures
+ that updates can be run in amortized constant time.
+
+ For estimation of the metric over a stream of data, the function creates an
+ `update_op` operation that appends the values of a tensor and returns the
+ `value` of the concatenated tensors.
+
+ This op allows for evaluating metrics that cannot be updated incrementally
+ using the same framework as other streaming metrics.
+
+ Args:
+ values: tensor to concatenate. Rank and the shape along all axes other than
+ the axis to concatenate along must be statically known.
+ axis: optional integer axis to concatenate along.
+ max_size: optional integer maximum size of `value` along the given axis.
+ Once the maximum size is reached, further updates are no-ops. By default,
+ there is no maximum size: the array is resized as necessary.
+ metrics_collections: An optional list of collections that `value`
+ should be added to.
+ updates_collections: An optional list of collections `update_op` should be
+ added to.
+ name: An optional variable_scope name.
+
+ Returns:
+ value: A tensor representing the concatenated values.
+ update_op: An operation that concatenates the next values.
+
+ Raises:
+ ValueError: if `values` does not have a statically known rank, `axis` is
+ not in the valid range or the size of `values` is not statically known
+ along any axis other than `axis`.
+ """
+ with variable_scope.variable_scope(name, 'streaming_concat', [values]):
+ # pylint: disable=invalid-slice-index
+ values_shape = values.get_shape()
+ if values_shape.dims is None:
+ raise ValueError('`values` must have known statically known rank')
+
+ ndim = len(values_shape)
+ if axis < 0:
+ axis += ndim
+ if not 0 <= axis < ndim:
+ raise ValueError('axis = %r not in [0, %r)' % (axis, ndim))
+
+ fixed_shape = [dim.value for n, dim in enumerate(values_shape)
+ if n != axis]
+ if any(value is None for value in fixed_shape):
+ raise ValueError('all dimensions of `values` other than the dimension to '
+ 'concatenate along must have statically known size')
+
+ # We move `axis` to the front of the internal array so assign ops can be
+ # applied to contiguous slices
+ init_size = 0 if max_size is None else max_size
+ init_shape = [init_size] + fixed_shape
+ array = _create_local('array', shape=init_shape, dtype=values.dtype)
+ size = _create_local('size', shape=[], dtype=dtypes.int32)
+
+ perm = [0 if n == axis else n + 1 if n < axis else n for n in range(ndim)]
+ valid_array = array[:size]
+ valid_array.set_shape([None] + fixed_shape)
+ value = array_ops.transpose(valid_array, perm, name='concat')
+
+ values_size = array_ops.shape(values)[axis]
+ if max_size is None:
+ batch_size = values_size
+ else:
+ batch_size = math_ops.minimum(values_size, max_size - size)
+
+ perm = [axis] + [n for n in range(ndim) if n != axis]
+ batch_values = array_ops.transpose(values, perm)[:batch_size]
+
+ def reallocate():
+ next_size = _next_array_size(new_size)
+ next_shape = array_ops.pack([next_size] + fixed_shape)
+ new_value = array_ops.zeros(next_shape, dtype=values.dtype)
+ old_value = array.value()
+ assign_op = state_ops.assign(array, new_value, validate_shape=False)
+ with ops.control_dependencies([assign_op]):
+ copy_op = array[:size].assign(old_value[:size])
+ # return value needs to be the same dtype as no_op() for cond
+ with ops.control_dependencies([copy_op]):
+ return control_flow_ops.no_op()
+
+ new_size = size + batch_size
+ array_size = array_ops.shape_internal(array, optimize=False)[0]
+ maybe_reallocate_op = control_flow_ops.cond(
+ new_size > array_size, reallocate, control_flow_ops.no_op)
+ with ops.control_dependencies([maybe_reallocate_op]):
+ append_values_op = array[size:new_size].assign(batch_values)
+ with ops.control_dependencies([append_values_op]):
+ update_op = size.assign(new_size)
+
+ if metrics_collections:
+ ops.add_to_collections(metrics_collections, value)
+
+ if updates_collections:
+ ops.add_to_collections(updates_collections, update_op)
+
+ return value, update_op
+ # pylint: enable=invalid-slice-index
+
+
def aggregate_metrics(*value_update_tuples):
"""Aggregates the metric value tensors and update ops into two lists.