aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Geoffrey Irving <geoffreyi@google.com>2016-02-11 13:23:00 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-02-11 17:13:15 -0800
commitc4801e3624dec02091009b40dd9c7e28aed526b2 (patch)
tree6ed06452543db7e1d4046dac8728f5e634e4ad4a
parent88b0cb44a468ca8c26b20f008d3414a1474a3f8e (diff)
Clean up saturate_cast, test, and move to tf.saturate_cast
Change: 114470777
m---------google/protobuf0
-rw-r--r--tensorflow/contrib/layers/BUILD5
-rw-r--r--tensorflow/core/BUILD2
-rw-r--r--tensorflow/python/kernel_tests/cast_op_test.py18
-rw-r--r--tensorflow/python/ops/array_ops.py1
-rw-r--r--tensorflow/python/ops/image_ops.py36
-rw-r--r--tensorflow/python/ops/math_ops.py29
-rw-r--r--tensorflow/python/platform/__init__.py20
8 files changed, 77 insertions, 34 deletions
diff --git a/google/protobuf b/google/protobuf
-Subproject d2c7fe6bc5d28b225f6202684574fe4ef9e3a3a
+Subproject caf1fb7197ee94c07108fc7cfbca07432b185a2
diff --git a/tensorflow/contrib/layers/BUILD b/tensorflow/contrib/layers/BUILD
index 1a09c550bb..f97bc87d04 100644
--- a/tensorflow/contrib/layers/BUILD
+++ b/tensorflow/contrib/layers/BUILD
@@ -30,6 +30,7 @@ py_test(
"//tensorflow:tensorflow_py",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
+ "//third_party/py/numpy",
],
)
@@ -42,6 +43,7 @@ py_test(
"//tensorflow:tensorflow_py",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
+ "//third_party/py/numpy",
],
)
@@ -54,6 +56,7 @@ py_test(
"//tensorflow:tensorflow_py",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
+ "//third_party/py/numpy",
],
)
@@ -66,6 +69,7 @@ py_test(
"//tensorflow:tensorflow_py",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
+ "//third_party/py/numpy",
],
)
@@ -78,6 +82,7 @@ py_test(
"//tensorflow:tensorflow_py",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
+ "//third_party/py/numpy",
],
)
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index efcc7f11b5..47745815cb 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -753,7 +753,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
"@re2//:re2",
- ":protos_cc",
+ ":android_proto_lib",
"//third_party/eigen3",
],
)
diff --git a/tensorflow/python/kernel_tests/cast_op_test.py b/tensorflow/python/kernel_tests/cast_op_test.py
index 3d9e9c2eaf..3a0fe60344 100644
--- a/tensorflow/python/kernel_tests/cast_op_test.py
+++ b/tensorflow/python/kernel_tests/cast_op_test.py
@@ -175,5 +175,23 @@ class SparseTensorCastTest(tf.test.TestCase):
self.assertAllEqual(st_cast.shape.eval(), [3])
+class SaturateCastTest(tf.test.TestCase):
+
+ def testSaturate(self):
+ in_types = tf.float32,
+ out_types = tf.int8, tf.uint8, tf.int16, tf.float32
+ with self.test_session() as sess:
+ for in_type in in_types:
+ for out_type in out_types:
+ lo, hi = in_type.min, in_type.max
+ x = tf.constant([lo, lo + 1, lo // 2, hi // 2, hi - 1, hi],
+ dtype=in_type)
+ y = tf.saturate_cast(x, dtype=out_type)
+ self.assertEqual(y.dtype, out_type)
+ x, y = sess.run([x, y])
+ correct = np.maximum(out_type.min, np.minimum(out_type.max, x))
+ self.assertAllEqual(correct, y)
+
+
if __name__ == "__main__":
tf.test.main()
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index 21deac53f7..a6ab03b3d6 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -25,6 +25,7 @@ types in your graph.
@@to_int32
@@to_int64
@@cast
+@@saturate_cast
## Shapes and Shaping
diff --git a/tensorflow/python/ops/image_ops.py b/tensorflow/python/ops/image_ops.py
index d88dbb80f2..8e7e9a7ce0 100644
--- a/tensorflow/python/ops/image_ops.py
+++ b/tensorflow/python/ops/image_ops.py
@@ -828,36 +828,6 @@ def _ImageEncodeShape(op):
return [tensor_shape.scalar()]
-def saturate_cast(image, dtype):
- """Performs a safe cast of image data to `dtype`.
-
- This function casts the data in image to `dtype`, without applying any
- scaling. If there is a danger that image data would over or underflow in the
- cast, this op applies the appropriate clamping before the cast.
-
- Args:
- image: An image to cast to a different data type.
- dtype: A `DType` to cast `image` to.
-
- Returns:
- `image`, safely cast to `dtype`.
- """
- clamped = image
-
- # When casting to a type with smaller representable range, clamp.
- # Note that this covers casting to unsigned types as well.
- if image.dtype.min < dtype.min and image.dtype.max > dtype.max:
- clamped = clip_ops.clip_by_value(clamped,
- math_ops.cast(dtype.min, image.dtype),
- math_ops.cast(dtype.max, image.dtype))
- elif image.dtype.min < dtype.min:
- clamped = math_ops.maximum(clamped, math_ops.cast(dtype.min, image.dtype))
- elif image.dtype.max > dtype.max:
- clamped = math_ops.minimum(clamped, math_ops.cast(dtype.max, image.dtype))
-
- return math_ops.cast(clamped, dtype)
-
-
def convert_image_dtype(image, dtype, saturate=False, name=None):
"""Convert `image` to `dtype`, scaling its values if needed.
@@ -903,14 +873,14 @@ def convert_image_dtype(image, dtype, saturate=False, name=None):
scaled = math_ops.div(image, scale)
if saturate:
- return saturate_cast(scaled, dtype)
+ return math_ops.saturate_cast(scaled, dtype)
else:
return math_ops.cast(scaled, dtype)
else:
# Scaling up, cast first, then scale. The scale will not map in.max to
# out.max, but converting back and forth should result in no change.
if saturate:
- cast = saturate_cast(scaled, dtype)
+ cast = math_ops.saturate_cast(scaled, dtype)
else:
cast = math_ops.cast(image, dtype)
scale = (scale_out + 1) // (scale_in + 1)
@@ -931,7 +901,7 @@ def convert_image_dtype(image, dtype, saturate=False, name=None):
scale = dtype.max + 0.5 # avoid rounding problems in the cast
scaled = math_ops.mul(image, scale)
if saturate:
- return saturate_cast(scaled, dtype)
+ return math_ops.saturate_cast(scaled, dtype)
else:
return math_ops.cast(scaled, dtype)
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index 13158dde3b..e8337fad07 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -365,6 +365,35 @@ def cast(x, dtype, name=None):
return gen_math_ops.cast(x, dtype, name=name)
+def saturate_cast(value, dtype, name=None):
+ """Performs a safe saturating cast of `value` to `dtype`.
+
+ This function casts the input to `dtype` without applying any scaling. If
+ there is a danger that values would over or underflow in the cast, this op
+ applies the appropriate clamping before the cast.
+
+ Args:
+ value: A `Tensor`.
+ dtype: The desired output `DType`.
+ name: A name for the operation (optional).
+
+ Returns:
+ `value` safely cast to `dtype`.
+ """
+ # When casting to a type with smaller representable range, clamp.
+ # Note that this covers casting to unsigned types as well.
+ with ops.op_scope([value], name, "saturate_cast") as name:
+ value = ops.convert_to_tensor(value, name="value")
+ dtype = dtypes.as_dtype(dtype).base_dtype
+ if value.dtype.min < dtype.min:
+ value = maximum(value, ops.convert_to_tensor(
+ dtype.min, dtype=value.dtype, name="min"))
+ if value.dtype.max > dtype.max:
+ value = minimum(value, ops.convert_to_tensor(
+ dtype.max, dtype=value.dtype, name="max"))
+ return cast(value, dtype, name=name)
+
+
def to_float(x, name="ToFloat"):
"""Casts a tensor to type `float32`.
diff --git a/tensorflow/python/platform/__init__.py b/tensorflow/python/platform/__init__.py
new file mode 100644
index 0000000000..aee1acdd46
--- /dev/null
+++ b/tensorflow/python/platform/__init__.py
@@ -0,0 +1,20 @@
+# Copyright 2015 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""DEPRECATED: Setup system-specific platform environment for TensorFlow."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+