diff options
author | A. Unique TensorFlower <nobody@tensorflow.org> | 2016-01-21 15:10:17 -0800 |
---|---|---|
committer | Vijay Vasudevan <vrv@google.com> | 2016-01-21 18:00:23 -0800 |
commit | 6542307303f8b84b13fa20d49633330d22ae164f (patch) | |
tree | 1be607654ca4260012c78a104f6ea68297da9de8 | |
parent | b481783fe0e00a86f6feb20a8dcad5fc4fc936a4 (diff) |
Adding histogram_ops module containing one Op: histogram_fixed_width, which updates a histogram Variable with new_values.
Change: 112728652
-rw-r--r-- | tensorflow/python/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/python/ops/histogram_ops.py | 94 | ||||
-rw-r--r-- | tensorflow/python/ops/histogram_ops_test.py | 80 |
3 files changed, 175 insertions, 0 deletions
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index de7c6d2453..5ade3d8c17 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -699,6 +699,7 @@ py_library( "ops/gen_string_ops.py", "ops/gen_summary_ops.py", "ops/gradients.py", + "ops/histogram_ops.py", "ops/image_grad.py", "ops/image_ops.py", "ops/init_ops.py", diff --git a/tensorflow/python/ops/histogram_ops.py b/tensorflow/python/ops/histogram_ops.py new file mode 100644 index 0000000000..f15d1c5cad --- /dev/null +++ b/tensorflow/python/ops/histogram_ops.py @@ -0,0 +1,94 @@ +# Copyright 2015 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Operations for histograms.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import clip_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import state_ops + + +def histogram_fixed_width(hist, + new_values, + value_range, + use_locking=False, + name='histogram_fixed_width'): + """Update histogram Variable with new values. + + This Op fills histogram with counts of values falling within fixed-width, + half-open bins. + + Args: + hist: 1-D mutable `Tensor`, e.g. a `Variable`. + new_values: Numeric `Tensor`. + value_range: Shape [2] `Tensor`. new_values <= value_range[0] will be + mapped to hist[0], values >= value_range[1] will be mapped to hist[-1]. + Must be same dtype as new_values. + use_locking: Boolean. + If `True`, use locking during the operation (optional). + name: A name for this operation (optional). + + Returns: + An op that updates `hist` with `new_values` when evaluated. + + Examples: + ```python + # Bins will be: (-inf, 1), [1, 2), [2, 3), [3, 4), [4, inf) + nbins = 5 + value_range = [0.0, 5.0] + new_values = [-1.0, 0.0, 1.5, 2.0, 5.0, 15] + + with tf.default_session() as sess: + hist = variables.Variable(array_ops.zeros(nbins, dtype=tf.int32)) + hist_update = histogram_ops.histogram_fixed_width(hist, new_values, + value_range) + variables.initialize_all_variables().run() + sess.run(hist_update) => [2, 1, 1, 0, 2] + ``` + """ + with ops.op_scope([hist, new_values, value_range], name) as scope: + new_values = ops.convert_to_tensor(new_values, name='new_values') + value_range = ops.convert_to_tensor(value_range, name='value_range') + dtype = hist.dtype + + # Map tensor values that fall within value_range to [0, 1]. + scaled_values = math_ops.truediv(new_values - value_range[0], + value_range[1] - value_range[0], + name='scaled_values') + nbins = math_ops.cast(hist.get_shape()[0], scaled_values.dtype) + + # map tensor values within the open interval value_range to {0,.., nbins-1}, + # values outside the open interval will be zero or less, or nbins or more. + indices = math_ops.floor(nbins * scaled_values, name='indices') + + # Clip edge cases (e.g. value = value_range[1]) or "outliers." + indices = math_ops.cast( + clip_ops.clip_by_value(indices, 0, nbins - 1), dtypes.int32) + + # Dummy vector to scatter. + # TODO(langmore) Replace non-ideal creation of large dummy vector once an + # alternative to scatter is available. + updates = array_ops.ones(indices.get_shape()[0], dtype=dtype) + return state_ops.scatter_add(hist, + indices, + updates, + use_locking=use_locking, + name=scope) diff --git a/tensorflow/python/ops/histogram_ops_test.py b/tensorflow/python/ops/histogram_ops_test.py new file mode 100644 index 0000000000..3b5ca3c748 --- /dev/null +++ b/tensorflow/python/ops/histogram_ops_test.py @@ -0,0 +1,80 @@ +# Copyright 2015 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Tests for tensorflow.ops.histogram_ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import histogram_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import googletest + + +class HistogramFixedWidthTest(test_util.TensorFlowTestCase): + + def test_one_update_on_constant_input(self): + # Bins will be: + # (-inf, 1), [1, 2), [2, 3), [3, 4), [4, inf) + nbins = 5 + value_range = [0.0, 5.0] + new_values = [-1.0, 0.0, 1.5, 2.0, 5.0, 15] + expected_bin_counts = [2, 1, 1, 0, 2] + with self.test_session() as sess: + hist = variables.Variable(array_ops.zeros(nbins, dtype=dtypes.int32)) + hist_update = histogram_ops.histogram_fixed_width(hist, new_values, + value_range) + variables.initialize_all_variables().run() + self.assertTrue(hist.dtype.is_compatible_with(hist_update.dtype)) + updated_hist_array = sess.run(hist_update) + + # The new updated_hist_array is returned by the updating op. + self.assertAllClose(expected_bin_counts, updated_hist_array) + + # hist should contain updated values, but eval() should not change it. + self.assertAllClose(expected_bin_counts, hist.eval()) + self.assertAllClose(expected_bin_counts, hist.eval()) + + def test_two_updates_on_constant_input(self): + # Bins will be: + # (-inf, 1), [1, 2), [2, 3), [3, 4), [4, inf) + nbins = 5 + value_range = [0.0, 5.0] + new_values_1 = [-1.0, 0.0, 1.5, 2.0, 5.0, 15] + new_values_2 = [1.5, 4.5, 4.5, 4.5, 0.0, 0.0] + expected_bin_counts_1 = [2, 1, 1, 0, 2] + expected_bin_counts_2 = [4, 2, 1, 0, 5] + with self.test_session() as sess: + hist = variables.Variable(array_ops.zeros(nbins, dtype=dtypes.int32)) + new_values = array_ops.placeholder(dtypes.float32, shape=[6]) + hist_update = histogram_ops.histogram_fixed_width(hist, new_values, + value_range) + variables.initialize_all_variables().run() + updated_hist_array = sess.run(hist_update, + feed_dict={new_values: new_values_1}) + self.assertAllClose(expected_bin_counts_1, updated_hist_array) + self.assertAllClose(expected_bin_counts_1, hist.eval()) + + updated_hist_array = sess.run(hist_update, + feed_dict={new_values: new_values_2}) + self.assertAllClose(expected_bin_counts_2, updated_hist_array) + self.assertAllClose(expected_bin_counts_2, hist.eval()) + + +if __name__ == '__main__': + googletest.main() |