diff options
author | 2016-02-11 13:23:00 -0800 | |
---|---|---|
committer | 2016-02-11 17:13:15 -0800 | |
commit | c4801e3624dec02091009b40dd9c7e28aed526b2 (patch) | |
tree | 6ed06452543db7e1d4046dac8728f5e634e4ad4a | |
parent | 88b0cb44a468ca8c26b20f008d3414a1474a3f8e (diff) |
Clean up saturate_cast, test, and move to tf.saturate_cast
Change: 114470777
m--------- | google/protobuf | 0 | ||||
-rw-r--r-- | tensorflow/contrib/layers/BUILD | 5 | ||||
-rw-r--r-- | tensorflow/core/BUILD | 2 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/cast_op_test.py | 18 | ||||
-rw-r--r-- | tensorflow/python/ops/array_ops.py | 1 | ||||
-rw-r--r-- | tensorflow/python/ops/image_ops.py | 36 | ||||
-rw-r--r-- | tensorflow/python/ops/math_ops.py | 29 | ||||
-rw-r--r-- | tensorflow/python/platform/__init__.py | 20 |
8 files changed, 77 insertions, 34 deletions
diff --git a/google/protobuf b/google/protobuf -Subproject d2c7fe6bc5d28b225f6202684574fe4ef9e3a3a +Subproject caf1fb7197ee94c07108fc7cfbca07432b185a2 diff --git a/tensorflow/contrib/layers/BUILD b/tensorflow/contrib/layers/BUILD index 1a09c550bb..f97bc87d04 100644 --- a/tensorflow/contrib/layers/BUILD +++ b/tensorflow/contrib/layers/BUILD @@ -30,6 +30,7 @@ py_test( "//tensorflow:tensorflow_py", "//tensorflow/python:framework_test_lib", "//tensorflow/python:platform_test", + "//third_party/py/numpy", ], ) @@ -42,6 +43,7 @@ py_test( "//tensorflow:tensorflow_py", "//tensorflow/python:framework_test_lib", "//tensorflow/python:platform_test", + "//third_party/py/numpy", ], ) @@ -54,6 +56,7 @@ py_test( "//tensorflow:tensorflow_py", "//tensorflow/python:framework_test_lib", "//tensorflow/python:platform_test", + "//third_party/py/numpy", ], ) @@ -66,6 +69,7 @@ py_test( "//tensorflow:tensorflow_py", "//tensorflow/python:framework_test_lib", "//tensorflow/python:platform_test", + "//third_party/py/numpy", ], ) @@ -78,6 +82,7 @@ py_test( "//tensorflow:tensorflow_py", "//tensorflow/python:framework_test_lib", "//tensorflow/python:platform_test", + "//third_party/py/numpy", ], ) diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index efcc7f11b5..47745815cb 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -753,7 +753,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ "@re2//:re2", - ":protos_cc", + ":android_proto_lib", "//third_party/eigen3", ], ) diff --git a/tensorflow/python/kernel_tests/cast_op_test.py b/tensorflow/python/kernel_tests/cast_op_test.py index 3d9e9c2eaf..3a0fe60344 100644 --- a/tensorflow/python/kernel_tests/cast_op_test.py +++ b/tensorflow/python/kernel_tests/cast_op_test.py @@ -175,5 +175,23 @@ class SparseTensorCastTest(tf.test.TestCase): self.assertAllEqual(st_cast.shape.eval(), [3]) +class SaturateCastTest(tf.test.TestCase): + + def testSaturate(self): + in_types = tf.float32, + out_types = tf.int8, tf.uint8, tf.int16, tf.float32 + with self.test_session() as sess: + for in_type in in_types: + for out_type in out_types: + lo, hi = in_type.min, in_type.max + x = tf.constant([lo, lo + 1, lo // 2, hi // 2, hi - 1, hi], + dtype=in_type) + y = tf.saturate_cast(x, dtype=out_type) + self.assertEqual(y.dtype, out_type) + x, y = sess.run([x, y]) + correct = np.maximum(out_type.min, np.minimum(out_type.max, x)) + self.assertAllEqual(correct, y) + + if __name__ == "__main__": tf.test.main() diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index 21deac53f7..a6ab03b3d6 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -25,6 +25,7 @@ types in your graph. @@to_int32 @@to_int64 @@cast +@@saturate_cast ## Shapes and Shaping diff --git a/tensorflow/python/ops/image_ops.py b/tensorflow/python/ops/image_ops.py index d88dbb80f2..8e7e9a7ce0 100644 --- a/tensorflow/python/ops/image_ops.py +++ b/tensorflow/python/ops/image_ops.py @@ -828,36 +828,6 @@ def _ImageEncodeShape(op): return [tensor_shape.scalar()] -def saturate_cast(image, dtype): - """Performs a safe cast of image data to `dtype`. - - This function casts the data in image to `dtype`, without applying any - scaling. If there is a danger that image data would over or underflow in the - cast, this op applies the appropriate clamping before the cast. - - Args: - image: An image to cast to a different data type. - dtype: A `DType` to cast `image` to. - - Returns: - `image`, safely cast to `dtype`. - """ - clamped = image - - # When casting to a type with smaller representable range, clamp. - # Note that this covers casting to unsigned types as well. - if image.dtype.min < dtype.min and image.dtype.max > dtype.max: - clamped = clip_ops.clip_by_value(clamped, - math_ops.cast(dtype.min, image.dtype), - math_ops.cast(dtype.max, image.dtype)) - elif image.dtype.min < dtype.min: - clamped = math_ops.maximum(clamped, math_ops.cast(dtype.min, image.dtype)) - elif image.dtype.max > dtype.max: - clamped = math_ops.minimum(clamped, math_ops.cast(dtype.max, image.dtype)) - - return math_ops.cast(clamped, dtype) - - def convert_image_dtype(image, dtype, saturate=False, name=None): """Convert `image` to `dtype`, scaling its values if needed. @@ -903,14 +873,14 @@ def convert_image_dtype(image, dtype, saturate=False, name=None): scaled = math_ops.div(image, scale) if saturate: - return saturate_cast(scaled, dtype) + return math_ops.saturate_cast(scaled, dtype) else: return math_ops.cast(scaled, dtype) else: # Scaling up, cast first, then scale. The scale will not map in.max to # out.max, but converting back and forth should result in no change. if saturate: - cast = saturate_cast(scaled, dtype) + cast = math_ops.saturate_cast(scaled, dtype) else: cast = math_ops.cast(image, dtype) scale = (scale_out + 1) // (scale_in + 1) @@ -931,7 +901,7 @@ def convert_image_dtype(image, dtype, saturate=False, name=None): scale = dtype.max + 0.5 # avoid rounding problems in the cast scaled = math_ops.mul(image, scale) if saturate: - return saturate_cast(scaled, dtype) + return math_ops.saturate_cast(scaled, dtype) else: return math_ops.cast(scaled, dtype) diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index 13158dde3b..e8337fad07 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -365,6 +365,35 @@ def cast(x, dtype, name=None): return gen_math_ops.cast(x, dtype, name=name) +def saturate_cast(value, dtype, name=None): + """Performs a safe saturating cast of `value` to `dtype`. + + This function casts the input to `dtype` without applying any scaling. If + there is a danger that values would over or underflow in the cast, this op + applies the appropriate clamping before the cast. + + Args: + value: A `Tensor`. + dtype: The desired output `DType`. + name: A name for the operation (optional). + + Returns: + `value` safely cast to `dtype`. + """ + # When casting to a type with smaller representable range, clamp. + # Note that this covers casting to unsigned types as well. + with ops.op_scope([value], name, "saturate_cast") as name: + value = ops.convert_to_tensor(value, name="value") + dtype = dtypes.as_dtype(dtype).base_dtype + if value.dtype.min < dtype.min: + value = maximum(value, ops.convert_to_tensor( + dtype.min, dtype=value.dtype, name="min")) + if value.dtype.max > dtype.max: + value = minimum(value, ops.convert_to_tensor( + dtype.max, dtype=value.dtype, name="max")) + return cast(value, dtype, name=name) + + def to_float(x, name="ToFloat"): """Casts a tensor to type `float32`. diff --git a/tensorflow/python/platform/__init__.py b/tensorflow/python/platform/__init__.py new file mode 100644 index 0000000000..aee1acdd46 --- /dev/null +++ b/tensorflow/python/platform/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2015 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""DEPRECATED: Setup system-specific platform environment for TensorFlow.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + |