author     A. Unique TensorFlower <gardener@tensorflow.org>   2018-02-13 09:07:12 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>    2018-02-13 09:11:24 -0800
commit     903df63d5556f816685ce6b75c2216fc760d6b47 (patch)
tree       7520fb6376a7aa8d13a9b37d6852cf99e8557e63 /tensorflow/compiler
parent     1fc821602d69e5812b854a61f09f163ce549641b (diff)
Extend the TF-to-XLA compiler to support FakeQuantWithMinMaxVars/Args.
PiperOrigin-RevId: 185538228
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r--  tensorflow/compiler/tests/BUILD                            11
-rw-r--r--  tensorflow/compiler/tests/fake_quant_ops_test.py          452
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/BUILD                    1
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc  289
4 files changed, 753 insertions, 0 deletions
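
The new XLA kernel mirrors the existing CPU implementation of fake quantization: the requested [min, max] range is first nudged so that the real value 0.0 lands exactly on one of the 2^num_bits integer levels, and inputs are then clamped to the nudged range, scaled, rounded to the nearest step, and rescaled. A minimal NumPy sketch of that forward math (illustrative only, not part of this change; it mirrors CpuNudge()/Quantize() in fake_quantize_ops.cc below):

import numpy as np

def fake_quant(inputs, input_min, input_max, num_bits=8, narrow_range=False):
  # Integer quantization range, e.g. [0, 255] for 8 bits ([1, 255] if narrow).
  quant_min = 1.0 if narrow_range else 0.0
  quant_max = float((1 << num_bits) - 1)
  # Nudge [input_min, input_max] so that 0.0 maps exactly to an integer level.
  scale = (input_max - input_min) / (quant_max - quant_min)
  zero_point_from_min = quant_min - input_min / scale
  if zero_point_from_min <= quant_min:
    nudged_zero_point = quant_min
  elif zero_point_from_min >= quant_max:
    nudged_zero_point = quant_max
  else:
    nudged_zero_point = np.floor(zero_point_from_min + 0.5)
  nudged_min = (quant_min - nudged_zero_point) * scale
  nudged_max = (quant_max - nudged_zero_point) * scale
  # Clamp, shift to the nudged origin, round to the nearest step, shift back.
  clamped = np.clip(inputs, nudged_min, nudged_max)
  rounded = np.floor((clamped - nudged_min) / scale + 0.5)
  return rounded * scale + nudged_min

For example, fake_quant(np.array([-0.1, 0.26, 127.6], np.float32), -0.1, 127.4) yields [0.0, 0.5, 127.5]: the range nudges to [0.0, 127.5] with step 0.5, the case exercised by testOp_with8BitsScalingAndNudgingBetween below.
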
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index acd9cf7bee..b49d2ca961 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -815,6 +815,17 @@ tf_library(
tfcompile_flags = ["--xla_cpu_multi_thread_eigen=false"],
)
+tf_xla_py_test(
+ name = "fake_quant_ops_test",
+ size = "medium",
+ srcs = ["fake_quant_ops_test.py"],
+ deps = [
+ ":xla_test",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
# -----------------------------------------------------------------------------
filegroup(
diff --git a/tensorflow/compiler/tests/fake_quant_ops_test.py b/tensorflow/compiler/tests/fake_quant_ops_test.py
new file mode 100644
index 0000000000..dfe9400ef0
--- /dev/null
+++ b/tensorflow/compiler/tests/fake_quant_ops_test.py
@@ -0,0 +1,452 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_array_ops
+from tensorflow.python.platform import googletest
+
+
+class FakeQuantWithMinMaxArgsTest(XLATestCase):
+ """Test cases for FakeQuantWithMinMaxArgs operation."""
+
+ # 8 bits, wide range.
+ def testOp_with8BitsNoScalingNoNudging(self):
+ self._TestOp(0.0, 255.0, 8, False, 0.0, 255.0, 1.0)
+
+ def testOp_with8BitsScalingAndNudgingDown(self):
+ self._TestOp(0.5, 128.0, 8, False, 0.0, 127.5, 0.5)
+
+ def testOp_with8BitsScalingAndNudgingUp(self):
+ self._TestOp(-128.0, -0.5, 8, False, -127.5, 0.0, 0.5)
+
+ def testOp_with8BitsScalingAndNudgingBetween(self):
+ self._TestOp(-0.1, 127.4, 8, False, 0.0, 127.5, 0.5)
+
+ # 8 bits, narrow range.
+ def testOp_with8BitsNarrowRangeNoScalingNoNudging(self):
+ self._TestOp(0.0, 254.0, 8, True, 0.0, 254.0, 1.0)
+
+ def testOp_with8BitsNarrowRangeScalingAndNudgingDown(self):
+ self._TestOp(0.1, 127.1, 8, True, 0.0, 127.0, 0.5)
+
+ def testOp_with8BitsNarrowRangeScalingAndNudgingUp(self):
+ self._TestOp(-127.1, -0.1, 8, True, -127.0, 0.0, 0.5)
+
+ def testOp_with8BitsNarrowRangeScalingAndNudgingBetween(self):
+ self._TestOp(-0.1, 126.9, 8, True, 0.0, 127.0, 0.5)
+
+ # 7 bits, wide range.
+ def testOp_with7BitsNoScalingNoNudging(self):
+ self._TestOp(0.0, 127.0, 7, False, 0.0, 127.0, 1.0)
+
+ def testOp_with7BitsScalingAndNudgingDown(self):
+ self._TestOp(0.5, 64.0, 7, False, 0.0, 63.5, 0.5)
+
+ def testOp_with7BitsScalingAndNudgingUp(self):
+ self._TestOp(-64.0, -0.5, 7, False, -63.5, 0.0, 0.5)
+
+ def testOp_with7BitsScalingAndNudgingBetween(self):
+ self._TestOp(-0.1, 63.4, 7, False, 0.0, 63.5, 0.5)
+
+ # 7 bits, narrow range.
+ def testOp_with7BitsNarrowRangeNoScalingNoNudging(self):
+ self._TestOp(0.0, 126.0, 7, True, 0.0, 126.0, 1.0)
+
+ def testOp_with7BitsNarrowRangeScalingAndNudgingDown(self):
+ self._TestOp(0.1, 63.1, 7, True, 0.0, 63.0, 0.5)
+
+ def testOp_with7BitsNarrowRangeScalingAndNudgingUp(self):
+ self._TestOp(-63.1, -0.1, 7, True, -63.0, 0.0, 0.5)
+
+ def testOp_with7BitsNarrowRangeScalingAndNudgingBetween(self):
+ self._TestOp(-0.1, 62.9, 7, True, 0.0, 63.0, 0.5)
+
+ def _TestOp(self, input_min, input_max, num_bits, narrow_range,
+ expected_nudged_input_min, expected_nudged_input_max,
+ expected_step):
+ inputs = np.array(
+ [
+ expected_nudged_input_min - expected_step,
+ expected_nudged_input_min - 0.01, expected_nudged_input_min,
+ expected_nudged_input_min + 0.01,
+ expected_nudged_input_min + expected_step - 0.01,
+ expected_nudged_input_min + expected_step,
+ expected_nudged_input_min + expected_step + 0.01,
+ expected_nudged_input_max - 0.01, expected_nudged_input_max,
+ expected_nudged_input_max + 0.01,
+ expected_nudged_input_max + expected_step
+ ],
+ dtype=np.float32)
+ expected = np.array(
+ [
+ expected_nudged_input_min, expected_nudged_input_min,
+ expected_nudged_input_min, expected_nudged_input_min,
+ expected_nudged_input_min + expected_step,
+ expected_nudged_input_min + expected_step,
+ expected_nudged_input_min + expected_step,
+ expected_nudged_input_max, expected_nudged_input_max,
+ expected_nudged_input_max, expected_nudged_input_max
+ ],
+ dtype=np.float32)
+
+ with self.test_session() as session:
+ with self.test_scope():
+ input_placeholder = array_ops.placeholder(
+ dtypes.float32, inputs.shape, name="inputs")
+ outputs = array_ops.fake_quant_with_min_max_args(
+ input_placeholder,
+ min=input_min,
+ max=input_max,
+ num_bits=num_bits,
+ narrow_range=narrow_range)
+ result = session.run(outputs, {input_placeholder: inputs})
+ self.assertAllCloseAccordingToType(
+ result, expected, rtol=1e-3, atol=1e-5, bfloat16_rtol=0.03)
+
+
+class FakeQuantWithMinMaxArgsGradientTest(XLATestCase):
+ """Test cases for FakeQuantWithMinMaxArgsGradient operation."""
+
+ # 8 bits, wide range.
+ def testOp_with8BitsNoScalingNoNudging(self):
+ self._TestOp(0.0, 255.0, 8, False, 0.0, 255.0, 1.0)
+
+ def testOp_with8BitsScalingAndNudgingDown(self):
+ self._TestOp(0.5, 128.0, 8, False, 0.0, 127.5, 0.5)
+
+ def testOp_with8BitsScalingAndNudgingUp(self):
+ self._TestOp(-128.0, -0.5, 8, False, -127.5, 0.0, 0.5)
+
+ def testOp_with8BitsScalingAndNudgingBetween(self):
+ self._TestOp(-0.1, 127.4, 8, False, 0.0, 127.5, 0.5)
+
+ # 8 bits, narrow range.
+ def testOp_with8BitsNarrowRangeNoScalingNoNudging(self):
+ self._TestOp(0.0, 254.0, 8, True, 0.0, 254.0, 1.0)
+
+ def testOp_with8BitsNarrowRangeScalingAndNudgingDown(self):
+ self._TestOp(0.1, 127.1, 8, True, 0.0, 127.0, 0.5)
+
+ def testOp_with8BitsNarrowRangeScalingAndNudgingUp(self):
+ self._TestOp(-127.1, -0.1, 8, True, -127.0, 0.0, 0.5)
+
+ def testOp_with8BitsNarrowRangeScalingAndNudgingBetween(self):
+ self._TestOp(-0.1, 126.9, 8, True, 0.0, 127.0, 0.5)
+
+ # 7 bits, wide range.
+ def testOp_with7BitsNoScalingNoNudging(self):
+ self._TestOp(0.0, 127.0, 7, False, 0.0, 127.0, 1.0)
+
+ def testOp_with7BitsScalingAndNudgingDown(self):
+ self._TestOp(0.5, 64.0, 7, False, 0.0, 63.5, 0.5)
+
+ def testOp_with7BitsScalingAndNudgingUp(self):
+ self._TestOp(-64.0, -0.5, 7, False, -63.5, 0.0, 0.5)
+
+ def testOp_with7BitsScalingAndNudgingBetween(self):
+ self._TestOp(-0.1, 63.4, 7, False, 0.0, 63.5, 0.5)
+
+ # 7 bits, narrow range.
+ def testOp_with7BitsNarrowRangeNoScalingNoNudging(self):
+ self._TestOp(0.0, 126.0, 7, True, 0.0, 126.0, 1.0)
+
+ def testOp_with7BitsNarrowRangeScalingAndNudgingDown(self):
+ self._TestOp(0.1, 63.1, 7, True, 0.0, 63.0, 0.5)
+
+ def testOp_with7BitsNarrowRangeScalingAndNudgingUp(self):
+ self._TestOp(-63.1, -0.1, 7, True, -63.0, 0.0, 0.5)
+
+ def testOp_with7BitsNarrowRangeScalingAndNudgingBetween(self):
+ self._TestOp(-0.1, 62.9, 7, True, 0.0, 63.0, 0.5)
+
+ def _TestOp(self, input_min, input_max, num_bits, narrow_range,
+ expected_nudged_input_min, expected_nudged_input_max,
+ expected_step):
+ inputs = np.array(
+ [
+ expected_nudged_input_min - expected_step,
+ expected_nudged_input_min - 0.01, expected_nudged_input_min,
+ expected_nudged_input_min + 0.01,
+ expected_nudged_input_min + expected_step - 0.01,
+ expected_nudged_input_min + expected_step,
+ expected_nudged_input_min + expected_step + 0.01,
+ expected_nudged_input_max - 0.01, expected_nudged_input_max,
+ expected_nudged_input_max + 0.01,
+ expected_nudged_input_max + expected_step
+ ],
+ dtype=np.float32)
+ gradients = np.arange(1, len(inputs) + 1, dtype=np.float32)
+ expected_backprops = np.array(
+ [0.0, 0.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 0.0, 0.0],
+ dtype=np.float32)
+
+ with self.test_session() as session:
+ with self.test_scope():
+ gradient_placeholder = array_ops.placeholder(
+ dtypes.float32, gradients.shape, name="gradients")
+ input_placeholder = array_ops.placeholder(
+ dtypes.float32, inputs.shape, name="inputs")
+ outputs = gen_array_ops.fake_quant_with_min_max_args_gradient(
+ gradient_placeholder,
+ input_placeholder,
+ min=input_min,
+ max=input_max,
+ num_bits=num_bits,
+ narrow_range=narrow_range)
+ backprops = session.run(outputs, {
+ gradient_placeholder: gradients,
+ input_placeholder: inputs
+ })
+ self.assertAllCloseAccordingToType(
+ backprops,
+ expected_backprops,
+ rtol=1e-3,
+ atol=1e-5,
+ bfloat16_rtol=0.03)
+
+
+class FakeQuantWithMinMaxVarsTest(XLATestCase):
+ """Test cases for FakeQuantWithMinMaxVars operation."""
+
+ # 8 bits, wide range.
+ def testOp_with8BitsNoScalingNoNudging(self):
+ self._TestOp(0.0, 255.0, 8, False, 0.0, 255.0, 1.0)
+
+ def testOp_with8BitsScalingAndNudgingDown(self):
+ self._TestOp(0.5, 128.0, 8, False, 0.0, 127.5, 0.5)
+
+ def testOp_with8BitsScalingAndNudgingUp(self):
+ self._TestOp(-128.0, -0.5, 8, False, -127.5, 0.0, 0.5)
+
+ def testOp_with8BitsScalingAndNudgingBetween(self):
+ self._TestOp(-0.1, 127.4, 8, False, 0.0, 127.5, 0.5)
+
+ # 8 bits, narrow range.
+ def testOp_with8BitsNarrowRangeNoScalingNoNudging(self):
+ self._TestOp(0.0, 254.0, 8, True, 0.0, 254.0, 1.0)
+
+ def testOp_with8BitsNarrowRangeScalingAndNudgingDown(self):
+ self._TestOp(0.1, 127.1, 8, True, 0.0, 127.0, 0.5)
+
+ def testOp_with8BitsNarrowRangeScalingAndNudgingUp(self):
+ self._TestOp(-127.1, -0.1, 8, True, -127.0, 0.0, 0.5)
+
+ def testOp_with8BitsNarrowRangeScalingAndNudgingBetween(self):
+ self._TestOp(-0.1, 126.9, 8, True, 0.0, 127.0, 0.5)
+
+ # 7 bits, wide range.
+ def testOp_with7BitsNoScalingNoNudging(self):
+ self._TestOp(0.0, 127.0, 7, False, 0.0, 127.0, 1.0)
+
+ def testOp_with7BitsScalingAndNudgingDown(self):
+ self._TestOp(0.5, 64.0, 7, False, 0.0, 63.5, 0.5)
+
+ def testOp_with7BitsScalingAndNudgingUp(self):
+ self._TestOp(-64.0, -0.5, 7, False, -63.5, 0.0, 0.5)
+
+ def testOp_with7BitsScalingAndNudgingBetween(self):
+ self._TestOp(-0.1, 63.4, 7, False, 0.0, 63.5, 0.5)
+
+ # 7 bits, narrow range.
+ def testOp_with7BitsNarrowRangeNoScalingNoNudging(self):
+ self._TestOp(0.0, 126.0, 7, True, 0.0, 126.0, 1.0)
+
+ def testOp_with7BitsNarrowRangeScalingAndNudgingDown(self):
+ self._TestOp(0.1, 63.1, 7, True, 0.0, 63.0, 0.5)
+
+ def testOp_with7BitsNarrowRangeScalingAndNudgingUp(self):
+ self._TestOp(-63.1, -0.1, 7, True, -63.0, 0.0, 0.5)
+
+ def testOp_with7BitsNarrowRangeScalingAndNudgingBetween(self):
+ self._TestOp(-0.1, 62.9, 7, True, 0.0, 63.0, 0.5)
+
+ def _TestOp(self, input_min, input_max, num_bits, narrow_range,
+ expected_nudged_input_min, expected_nudged_input_max,
+ expected_step):
+ inputs = np.array(
+ [
+ expected_nudged_input_min - expected_step,
+ expected_nudged_input_min - 0.01, expected_nudged_input_min,
+ expected_nudged_input_min + 0.01,
+ expected_nudged_input_min + expected_step - 0.01,
+ expected_nudged_input_min + expected_step,
+ expected_nudged_input_min + expected_step + 0.01,
+ expected_nudged_input_max - 0.01, expected_nudged_input_max,
+ expected_nudged_input_max + 0.01,
+ expected_nudged_input_max + expected_step
+ ],
+ dtype=np.float32)
+ expected = np.array(
+ [
+ expected_nudged_input_min, expected_nudged_input_min,
+ expected_nudged_input_min, expected_nudged_input_min,
+ expected_nudged_input_min + expected_step,
+ expected_nudged_input_min + expected_step,
+ expected_nudged_input_min + expected_step,
+ expected_nudged_input_max, expected_nudged_input_max,
+ expected_nudged_input_max, expected_nudged_input_max
+ ],
+ dtype=np.float32)
+
+ with self.test_session() as session:
+ with self.test_scope():
+ input_placeholder = array_ops.placeholder(
+ dtypes.float32, inputs.shape, name="inputs")
+ min_placeholder = array_ops.placeholder(dtypes.float32, (), name="min")
+ max_placeholder = array_ops.placeholder(dtypes.float32, (), name="max")
+ outputs = array_ops.fake_quant_with_min_max_vars(
+ input_placeholder,
+ min_placeholder,
+ max_placeholder,
+ num_bits=num_bits,
+ narrow_range=narrow_range)
+ result = session.run(
+ outputs, {
+ input_placeholder: inputs,
+ min_placeholder: input_min,
+ max_placeholder: input_max
+ })
+ self.assertAllCloseAccordingToType(
+ result, expected, rtol=1e-3, atol=1e-5, bfloat16_rtol=0.03)
+
+
+class FakeQuantWithMinMaxVarsGradientTest(XLATestCase):
+ """Test cases for FakeQuantWithMinMaxVarsGradient operation."""
+
+ # 8 bits, wide range.
+ def testOp_with8BitsNoScalingNoNudging(self):
+ self._TestOp(0.0, 255.0, 8, False, 0.0, 255.0, 1.0)
+
+ def testOp_with8BitsScalingAndNudgingDown(self):
+ self._TestOp(0.5, 128.0, 8, False, 0.0, 127.5, 0.5)
+
+ def testOp_with8BitsScalingAndNudgingUp(self):
+ self._TestOp(-128.0, -0.5, 8, False, -127.5, 0.0, 0.5)
+
+ def testOp_with8BitsScalingAndNudgingBetween(self):
+ self._TestOp(-0.1, 127.4, 8, False, 0.0, 127.5, 0.5)
+
+ # 8 bits, narrow range.
+ def testOp_with8BitsNarrowRangeNoScalingNoNudging(self):
+ self._TestOp(0.0, 254.0, 8, True, 0.0, 254.0, 1.0)
+
+ def testOp_with8BitsNarrowRangeScalingAndNudgingDown(self):
+ self._TestOp(0.1, 127.1, 8, True, 0.0, 127.0, 0.5)
+
+ def testOp_with8BitsNarrowRangeScalingAndNudgingUp(self):
+ self._TestOp(-127.1, -0.1, 8, True, -127.0, 0.0, 0.5)
+
+ def testOp_with8BitsNarrowRangeScalingAndNudgingBetween(self):
+ self._TestOp(-0.1, 126.9, 8, True, 0.0, 127.0, 0.5)
+
+ # 7 bits, wide range.
+ def testOp_with7BitsNoScalingNoNudging(self):
+ self._TestOp(0.0, 127.0, 7, False, 0.0, 127.0, 1.0)
+
+ def testOp_with7BitsScalingAndNudgingDown(self):
+ self._TestOp(0.5, 64.0, 7, False, 0.0, 63.5, 0.5)
+
+ def testOp_with7BitsScalingAndNudgingUp(self):
+ self._TestOp(-64.0, -0.5, 7, False, -63.5, 0.0, 0.5)
+
+ def testOp_with7BitsScalingAndNudgingBetween(self):
+ self._TestOp(-0.1, 63.4, 7, False, 0.0, 63.5, 0.5)
+
+ # 7 bits, narrow range.
+ def testOp_with7BitsNarrowRangeNoScalingNoNudging(self):
+ self._TestOp(0.0, 126.0, 7, True, 0.0, 126.0, 1.0)
+
+ def testOp_with7BitsNarrowRangeScalingAndNudgingDown(self):
+ self._TestOp(0.1, 63.1, 7, True, 0.0, 63.0, 0.5)
+
+ def testOp_with7BitsNarrowRangeScalingAndNudgingUp(self):
+ self._TestOp(-63.1, -0.1, 7, True, -63.0, 0.0, 0.5)
+
+ def testOp_with7BitsNarrowRangeScalingAndNudgingBetween(self):
+ self._TestOp(-0.1, 62.9, 7, True, 0.0, 63.0, 0.5)
+
+ def _TestOp(self, input_min, input_max, num_bits, narrow_range,
+ expected_nudged_input_min, expected_nudged_input_max,
+ expected_step):
+ inputs = np.array(
+ [
+ expected_nudged_input_min - expected_step,
+ expected_nudged_input_min - 0.01, expected_nudged_input_min,
+ expected_nudged_input_min + 0.01,
+ expected_nudged_input_min + expected_step - 0.01,
+ expected_nudged_input_min + expected_step,
+ expected_nudged_input_min + expected_step + 0.01,
+ expected_nudged_input_max - 0.01, expected_nudged_input_max,
+ expected_nudged_input_max + 0.01,
+ expected_nudged_input_max + expected_step
+ ],
+ dtype=np.float32)
+ gradients = np.arange(1, len(inputs) + 1, dtype=np.float32)
+ expected_backprops_wrt_input = np.array(
+ [0.0, 0.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 0.0, 0.0],
+ dtype=np.float32)
+ expected_backprops_wrt_min = 1.0 + 2.0
+ expected_backprops_wrt_max = 10.0 + 11.0
+
+ with self.test_session() as session:
+ with self.test_scope():
+ gradient_placeholder = array_ops.placeholder(
+ dtypes.float32, gradients.shape, name="gradients")
+ input_placeholder = array_ops.placeholder(
+ dtypes.float32, inputs.shape, name="inputs")
+ min_placeholder = array_ops.placeholder(dtypes.float32, (), name="min")
+ max_placeholder = array_ops.placeholder(dtypes.float32, (), name="max")
+ outputs = array_ops.fake_quant_with_min_max_vars_gradient(
+ gradient_placeholder,
+ input_placeholder,
+ min_placeholder,
+ max_placeholder,
+ num_bits=num_bits,
+ narrow_range=narrow_range)
+ backprops_wrt_input, backprops_wrt_min, backprops_wrt_max = session.run(
+ outputs, {
+ gradient_placeholder: gradients,
+ input_placeholder: inputs,
+ min_placeholder: input_min,
+ max_placeholder: input_max
+ })
+ self.assertAllCloseAccordingToType(
+ backprops_wrt_input,
+ expected_backprops_wrt_input,
+ rtol=1e-3,
+ atol=1e-5,
+ bfloat16_rtol=0.03)
+ self.assertAllCloseAccordingToType(
+ backprops_wrt_min,
+ expected_backprops_wrt_min,
+ rtol=1e-3,
+ atol=1e-5,
+ bfloat16_rtol=0.03)
+ self.assertAllCloseAccordingToType(
+ backprops_wrt_max,
+ expected_backprops_wrt_max,
+ rtol=1e-3,
+ atol=1e-5,
+ bfloat16_rtol=0.03)
+
+
+if __name__ == "__main__":
+ googletest.main()
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD
index d83d576eda..d2fa933cf9 100644
--- a/tensorflow/compiler/tf2xla/kernels/BUILD
+++ b/tensorflow/compiler/tf2xla/kernels/BUILD
@@ -32,6 +32,7 @@ tf_kernel_library(
"dynamic_stitch_op.cc",
"elu_op.cc",
"extract_image_patches_op.cc",
+ "fake_quantize_ops.cc",
"fft_ops.cc",
"fill_op.cc",
"function_ops.cc",
diff --git a/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc b/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc
new file mode 100644
index 0000000000..453a32c494
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc
@@ -0,0 +1,289 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/core/platform/macros.h"
+
+namespace tensorflow {
+namespace {
+
+// Gymnastics with nudged zero point is to ensure that the real zero maps to
+// an integer, which is required for e.g. zero-padding in convolutional layers.
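+// For example, with min = -0.1, max = 127.4 and 8 bits (quantized range
+// [0, 255]), the scale is 0.5 and the zero point nudges from 0.2 to 0,
+// giving a nudged range of [0.0, 127.5].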
+void CpuNudge(const float min, const float max, const float quant_min,
+ const float quant_max, float* nudged_min, float* nudged_max,
+ float* scale) {
+ *scale = (max - min) / (quant_max - quant_min);
+
+ const float zero_point_from_min = quant_min - min / *scale;
+ float nudged_zero_point;
+ if (zero_point_from_min <= quant_min) {
+ nudged_zero_point = quant_min;
+ } else if (zero_point_from_min >= quant_max) {
+ nudged_zero_point = quant_max;
+ } else {
+ nudged_zero_point = std::round(zero_point_from_min);
+ }
+
+ *nudged_min = (quant_min - nudged_zero_point) * (*scale);
+ *nudged_max = (quant_max - nudged_zero_point) * (*scale);
+}
+
+// An XLA version of CpuNudge().
+void XlaNudge(xla::ComputationBuilder* b, const DataType data_type,
+ const xla::ComputationDataHandle& min,
+ const xla::ComputationDataHandle& max,
+ const float quant_min_value, const float quant_max_value,
+ xla::ComputationDataHandle* nudged_min,
+ xla::ComputationDataHandle* nudged_max,
+ xla::ComputationDataHandle* scale) {
+ *scale = b->Div(b->Sub(max, min),
+ XlaHelpers::FloatLiteral(b, data_type,
+ quant_max_value - quant_min_value));
+ xla::ComputationDataHandle quant_min =
+ XlaHelpers::FloatLiteral(b, data_type, quant_min_value);
+ xla::ComputationDataHandle zero_point_from_min =
+ b->Sub(quant_min, b->Div(min, *scale));
+ xla::ComputationDataHandle quant_max =
+ XlaHelpers::FloatLiteral(b, data_type, quant_max_value);
+ xla::ComputationDataHandle nudged_zero_point =
+ b->Select(b->Le(zero_point_from_min, quant_min), quant_min,
+ b->Select(b->Ge(zero_point_from_min, quant_max), quant_max,
+ b->Round(zero_point_from_min)));
+ *nudged_min = b->Mul(b->Sub(quant_min, nudged_zero_point), *scale);
+ *nudged_max = b->Mul(b->Sub(quant_max, nudged_zero_point), *scale);
+}
+
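+// Builds the fake-quantization of |input|: clamp to the nudged range, round
+// to the nearest quantization step, and return the de-quantized float result.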
+xla::ComputationDataHandle Quantize(
+ xla::ComputationBuilder* b, const xla::ComputationDataHandle& input,
+ const DataType data_type,
+ const xla::ComputationDataHandle& nudged_input_min,
+ const xla::ComputationDataHandle& nudged_input_max,
+ const xla::ComputationDataHandle& input_scale) {
+ xla::ComputationDataHandle one = XlaHelpers::FloatLiteral(b, data_type, 1.0f);
+ xla::ComputationDataHandle inv_scale = b->Div(one, input_scale);
+ xla::ComputationDataHandle half =
+ XlaHelpers::FloatLiteral(b, data_type, 0.5f);
+
+ xla::ComputationDataHandle clamped =
+ b->Clamp(nudged_input_min, input, nudged_input_max);
+ xla::ComputationDataHandle clamped_shifted =
+ b->Sub(clamped, nudged_input_min);
+ xla::ComputationDataHandle rounded =
+ b->Floor(b->Add(b->Mul(clamped_shifted, inv_scale), half));
+ return b->Add(b->Mul(rounded, input_scale), nudged_input_min);
+}
+
+class FakeQuantWithMinMaxArgsOp : public XlaOpKernel {
+ public:
+ explicit FakeQuantWithMinMaxArgsOp(OpKernelConstruction* ctx)
+ : XlaOpKernel(ctx) {
+ int num_bits;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits));
+ OP_REQUIRES(ctx, num_bits >= 2 && num_bits <= 16,
+ errors::InvalidArgument("num_bits is out of range, expected "
+ "between 2 and 16, was: ",
+ num_bits));
+ bool narrow_range;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range", &narrow_range));
+ quant_min_ = narrow_range ? 1 : 0;
+ quant_max_ = (1 << num_bits) - 1;
+
+ float input_min, input_max;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("min", &input_min));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("max", &input_max));
+ CpuNudge(input_min, input_max, quant_min_, quant_max_, &nudged_input_min_,
+ &nudged_input_max_, &input_scale_);
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ xla::ComputationDataHandle input = ctx->Input(0);
+ const DataType data_type = ctx->input_type(0);
+
+ xla::ComputationBuilder* b = ctx->builder();
+ xla::ComputationDataHandle nudged_input_min =
+ XlaHelpers::FloatLiteral(b, data_type, nudged_input_min_);
+ xla::ComputationDataHandle nudged_input_max =
+ XlaHelpers::FloatLiteral(b, data_type, nudged_input_max_);
+ xla::ComputationDataHandle input_scale =
+ XlaHelpers::FloatLiteral(b, data_type, input_scale_);
+ xla::ComputationDataHandle output = Quantize(
+ b, input, data_type, nudged_input_min, nudged_input_max, input_scale);
+ ctx->SetOutput(0, output);
+ }
+
+ private:
+ float quant_min_;
+ float quant_max_;
+ float nudged_input_min_;
+ float nudged_input_max_;
+ float input_scale_;
+};
+
+REGISTER_XLA_OP(Name("FakeQuantWithMinMaxArgs"), FakeQuantWithMinMaxArgsOp);
+
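+// Gradient of FakeQuantWithMinMaxArgs: gradients pass through unchanged where
+// the input lies inside the nudged [min, max] range and are zeroed elsewhere.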
+class FakeQuantWithMinMaxArgsGradOp : public XlaOpKernel {
+ public:
+ explicit FakeQuantWithMinMaxArgsGradOp(OpKernelConstruction* ctx)
+ : XlaOpKernel(ctx) {
+ int num_bits;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits));
+ OP_REQUIRES(ctx, num_bits >= 2 && num_bits <= 16,
+ errors::InvalidArgument("num_bits is out of range, expected "
+ "between 2 and 16, was: ",
+ num_bits));
+ bool narrow_range;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range", &narrow_range));
+ const float quant_min = narrow_range ? 1 : 0;
+ const float quant_max = (1 << num_bits) - 1;
+
+ float input_min, input_max, scale;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("min", &input_min));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("max", &input_max));
+ CpuNudge(input_min, input_max, quant_min, quant_max, &nudged_input_min_,
+ &nudged_input_max_, &scale);
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ xla::ComputationDataHandle gradient = ctx->Input(0);
+ const TensorShape gradient_shape = ctx->InputShape(0);
+ xla::ComputationDataHandle input = ctx->Input(1);
+ const DataType data_type = ctx->input_type(1);
+
+ xla::ComputationBuilder* b = ctx->builder();
+ xla::ComputationDataHandle nudged_input_min =
+ XlaHelpers::FloatLiteral(b, data_type, nudged_input_min_);
+ xla::ComputationDataHandle nudged_input_max =
+ XlaHelpers::FloatLiteral(b, data_type, nudged_input_max_);
+
+ xla::ComputationDataHandle between_nudged_min_max =
+ b->And(b->Le(nudged_input_min, input), b->Le(input, nudged_input_max));
+ xla::ComputationDataHandle zeroes = b->Broadcast(
+ XlaHelpers::Zero(b, data_type), gradient_shape.dim_sizes());
+ xla::ComputationDataHandle output =
+ b->Select(between_nudged_min_max, gradient, zeroes);
+ ctx->SetOutput(0, output);
+ }
+
+ private:
+ float nudged_input_min_;
+ float nudged_input_max_;
+};
+
+REGISTER_XLA_OP(Name("FakeQuantWithMinMaxArgsGradient"),
+ FakeQuantWithMinMaxArgsGradOp);
+
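+// Like FakeQuantWithMinMaxArgs, but min and max arrive as runtime tensor
+// inputs rather than attributes, so the nudging is emitted into the XLA
+// computation (XlaNudge) instead of being precomputed at construction time.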
+class FakeQuantWithMinMaxVarsOp : public XlaOpKernel {
+ public:
+ explicit FakeQuantWithMinMaxVarsOp(OpKernelConstruction* ctx)
+ : XlaOpKernel(ctx) {
+ int num_bits;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits));
+ OP_REQUIRES(ctx, num_bits >= 2 && num_bits <= 16,
+ errors::InvalidArgument("num_bits is out of range, expected "
+ "between 2 and 16, was: ",
+ num_bits));
+ bool narrow_range;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range", &narrow_range));
+ quant_min_ = narrow_range ? 1 : 0;
+ quant_max_ = (1 << num_bits) - 1;
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ xla::ComputationDataHandle input = ctx->Input(0);
+ const DataType data_type = ctx->input_type(0);
+ xla::ComputationDataHandle input_min = ctx->Input(1);
+ xla::ComputationDataHandle input_max = ctx->Input(2);
+
+ xla::ComputationBuilder* b = ctx->builder();
+ xla::ComputationDataHandle nudged_input_min, nudged_input_max, input_scale;
+ XlaNudge(b, data_type, input_min, input_max, quant_min_, quant_max_,
+ &nudged_input_min, &nudged_input_max, &input_scale);
+
+ xla::ComputationDataHandle output = Quantize(
+ b, input, data_type, nudged_input_min, nudged_input_max, input_scale);
+ ctx->SetOutput(0, output);
+ }
+
+ private:
+ float quant_min_;
+ float quant_max_;
+};
+
+REGISTER_XLA_OP(Name("FakeQuantWithMinMaxVars"), FakeQuantWithMinMaxVarsOp);
+
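+// Gradient of FakeQuantWithMinMaxVars. Produces the pass-through gradient
+// w.r.t. the input for in-range values, and reduces the gradients of inputs
+// below the nudged min (above the nudged max) into the backprop w.r.t. min
+// (max).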
+class FakeQuantWithMinMaxVarsGradOp : public XlaOpKernel {
+ public:
+ explicit FakeQuantWithMinMaxVarsGradOp(OpKernelConstruction* ctx)
+ : XlaOpKernel(ctx) {
+ int num_bits;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits));
+ OP_REQUIRES(ctx, num_bits >= 2 && num_bits <= 16,
+ errors::InvalidArgument("num_bits is out of range, expected "
+ "between 2 and 16, was: ",
+ num_bits));
+ bool narrow_range;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range", &narrow_range));
+ quant_min_ = narrow_range ? 1 : 0;
+ quant_max_ = (1 << num_bits) - 1;
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ xla::ComputationDataHandle gradient = ctx->Input(0);
+ const TensorShape gradient_shape = ctx->InputShape(0);
+ xla::ComputationDataHandle input = ctx->Input(1);
+ const DataType data_type = ctx->input_type(1);
+ xla::ComputationDataHandle input_min = ctx->Input(2);
+ xla::ComputationDataHandle input_max = ctx->Input(3);
+
+ xla::ComputationBuilder* b = ctx->builder();
+ xla::ComputationDataHandle nudged_input_min, nudged_input_max, input_scale;
+ XlaNudge(b, data_type, input_min, input_max, quant_min_, quant_max_,
+ &nudged_input_min, &nudged_input_max, &input_scale);
+
+ xla::ComputationDataHandle between_nudged_min_max =
+ b->And(b->Le(nudged_input_min, input), b->Le(input, nudged_input_max));
+ xla::ComputationDataHandle zero = XlaHelpers::Zero(b, data_type);
+ xla::ComputationDataHandle zeroes =
+ b->Broadcast(zero, gradient_shape.dim_sizes());
+ xla::ComputationDataHandle output0 =
+ b->Select(between_nudged_min_max, gradient, zeroes);
+ ctx->SetOutput(0, output0);
+
+ xla::ComputationDataHandle below_min = b->Lt(input, nudged_input_min);
+ xla::ComputationDataHandle output1 =
+ b->ReduceAll(b->Select(below_min, gradient, zeroes), zero,
+ *ctx->GetOrCreateAdd(data_type));
+ ctx->SetOutput(1, output1);
+
+ xla::ComputationDataHandle above_max = b->Gt(input, nudged_input_max);
+ xla::ComputationDataHandle output2 =
+ b->ReduceAll(b->Select(above_max, gradient, zeroes), zero,
+ *ctx->GetOrCreateAdd(data_type));
+ ctx->SetOutput(2, output2);
+ }
+
+ private:
+ float quant_min_;
+ float quant_max_;
+};
+
+REGISTER_XLA_OP(Name("FakeQuantWithMinMaxVarsGradient"),
+ FakeQuantWithMinMaxVarsGradOp);
+
+} // namespace
+} // namespace tensorflow
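
For reference, the three outputs produced by the FakeQuantWithMinMaxVarsGradient kernel above can be sketched in NumPy as follows (illustrative only, not part of this change; nudged_min and nudged_max are assumed to have already been computed as in XlaNudge()):

import numpy as np

def fake_quant_vars_grad(gradients, inputs, nudged_min, nudged_max):
  # Gradient w.r.t. the input: pass-through inside the nudged range, zero outside.
  in_range = (nudged_min <= inputs) & (inputs <= nudged_max)
  d_input = np.where(in_range, gradients, 0.0)
  # Gradients of inputs clamped at the boundaries accumulate into min and max.
  d_min = np.sum(np.where(inputs < nudged_min, gradients, 0.0))
  d_max = np.sum(np.where(inputs > nudged_max, gradients, 0.0))
  return d_input, d_min, d_max

With the 11-point input vector and gradients 1..11 used in FakeQuantWithMinMaxVarsGradientTest, this gives the expected backprops [0, 0, 3, 4, 5, 6, 7, 8, 9, 0, 0] w.r.t. the input, 1 + 2 = 3 w.r.t. min, and 10 + 11 = 21 w.r.t. max.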