diff options
author | Stephan Hoyer <shoyer@google.com> | 2016-09-26 14:09:42 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-09-26 15:17:44 -0700 |
commit | e7b2114e0732a11cf610b388f92a46fa8f9da3c4 (patch) | |
tree | cbeef510a7659e3235895997b0c59ea02190e16a /tensorflow/contrib/metrics/python/ops/metric_ops.py | |
parent | 9751084da88a30b03eda84e17d88379fa16289ff (diff) |
Add tf.contrib.metrics.streaming_concat
This op is useful, on small datasets, for evaluating metrics that otherwise
cannot be updated incrementally.
Change: 134332647
Diffstat (limited to 'tensorflow/contrib/metrics/python/ops/metric_ops.py')
-rw-r--r-- | tensorflow/contrib/metrics/python/ops/metric_ops.py | 129 |
1 file changed, 129 insertions, 0 deletions
def _next_array_size(required_size, growth_factor=1.5):
  """Calculate the next size for reallocating a dynamic array.

  Args:
    required_size: number or tf.Tensor specifying required array capacity.
    growth_factor: optional number or tf.Tensor specifying the growth factor
      between subsequent allocations.

  Returns:
    tf.Tensor with dtype=int32 giving the next array size.
  """
  # Smallest integer exponent such that growth_factor ** exponent covers
  # required_size, i.e. capacities form the geometric series
  # growth_factor ** 0, growth_factor ** 1, ... which gives amortized
  # constant-time appends.
  exponent = math_ops.ceil(
      math_ops.log(math_ops.cast(required_size, dtypes.float32))
      / math_ops.log(math_ops.cast(growth_factor, dtypes.float32)))
  return math_ops.cast(math_ops.ceil(growth_factor ** exponent), dtypes.int32)


def streaming_concat(values,
                     axis=0,
                     max_size=None,
                     metrics_collections=None,
                     updates_collections=None,
                     name=None):
  """Concatenate values along an axis across batches.

  The function `streaming_concat` creates two local variables, `array` and
  `size`, that are used to store concatenated values. Internally, `array` is
  used as storage for a dynamic array (if `max_size` is `None`), which ensures
  that updates can be run in amortized constant time.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that appends the values of a tensor and returns the
  `value` of the concatenated tensors.

  This op allows for evaluating metrics that cannot be updated incrementally
  using the same framework as other streaming metrics.

  Args:
    values: tensor to concatenate. Rank and the shape along all axes other than
      the axis to concatenate along must be statically known.
    axis: optional integer axis to concatenate along.
    max_size: optional integer maximum size of `value` along the given axis.
      Once the maximum size is reached, further updates are no-ops. By default,
      there is no maximum size: the array is resized as necessary.
    metrics_collections: An optional list of collections that `value`
      should be added to.
    updates_collections: An optional list of collections `update_op` should be
      added to.
    name: An optional variable_scope name.

  Returns:
    value: A tensor representing the concatenated values.
    update_op: An operation that concatenates the next values.

  Raises:
    ValueError: if `values` does not have a statically known rank, `axis` is
      not in the valid range or the size of `values` is not statically known
      along any axis other than `axis`.
  """
  with variable_scope.variable_scope(name, 'streaming_concat', [values]):
    # pylint: disable=invalid-slice-index
    values_shape = values.get_shape()
    if values_shape.dims is None:
      raise ValueError('`values` must have a statically known rank')

    ndim = len(values_shape)
    if axis < 0:
      axis += ndim
    if not 0 <= axis < ndim:
      raise ValueError('axis = %r not in [0, %r)' % (axis, ndim))

    fixed_shape = [dim.value for n, dim in enumerate(values_shape)
                   if n != axis]
    if any(value is None for value in fixed_shape):
      raise ValueError('all dimensions of `values` other than the dimension to'
                       ' concatenate along must have statically known size')

    # We move `axis` to the front of the internal array so assign ops can be
    # applied to contiguous slices.
    init_size = 0 if max_size is None else max_size
    init_shape = [init_size] + fixed_shape
    array = _create_local('array', shape=init_shape, dtype=values.dtype)
    size = _create_local('size', shape=[], dtype=dtypes.int32)

    # Inverse permutation that moves the (internal) leading axis back to
    # position `axis` for the externally visible `value`.
    perm = [0 if n == axis else n + 1 if n < axis else n for n in range(ndim)]
    valid_array = array[:size]
    valid_array.set_shape([None] + fixed_shape)
    value = array_ops.transpose(valid_array, perm, name='concat')

    values_size = array_ops.shape(values)[axis]
    if max_size is None:
      batch_size = values_size
    else:
      # Clamp so that, once `size` reaches `max_size`, further updates append
      # zero elements (no-ops).
      batch_size = math_ops.minimum(values_size, max_size - size)

    perm = [axis] + [n for n in range(ndim) if n != axis]
    batch_values = array_ops.transpose(values, perm)[:batch_size]

    def reallocate():
      """Grow `array` geometrically and copy the existing contents over."""
      next_size = _next_array_size(new_size)
      next_shape = array_ops.pack([next_size] + fixed_shape)
      new_value = array_ops.zeros(next_shape, dtype=values.dtype)
      old_value = array.value()
      assign_op = state_ops.assign(array, new_value, validate_shape=False)
      with ops.control_dependencies([assign_op]):
        copy_op = array[:size].assign(old_value[:size])
      # return value needs to be the same dtype as no_op() for cond
      with ops.control_dependencies([copy_op]):
        return control_flow_ops.no_op()

    new_size = size + batch_size
    # shape_internal(optimize=False) so we read the *runtime* capacity of
    # `array`, which changes across reallocations.
    array_size = array_ops.shape_internal(array, optimize=False)[0]
    maybe_reallocate_op = control_flow_ops.cond(
        new_size > array_size, reallocate, control_flow_ops.no_op)
    with ops.control_dependencies([maybe_reallocate_op]):
      append_values_op = array[size:new_size].assign(batch_values)
    with ops.control_dependencies([append_values_op]):
      update_op = size.assign(new_size)

    if metrics_collections:
      ops.add_to_collections(metrics_collections, value)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return value, update_op
    # pylint: enable=invalid-slice-index