aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tests
diff options
context:
space:
mode:
authorGravatar Benjamin Kramer <kramerb@google.com>2018-09-20 03:14:04 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-20 03:16:50 -0700
commit90d084e0c42232043c186e66093b67800fb30fba (patch)
tree230d7998ea42af3efd59b0d25312eaa54efce5de /tensorflow/compiler/tests
parent9604413da7a27f5718bb88d407d13476dbef5b82 (diff)
[XLA:TF] Whitelist quantized types for CPU/GPU
These have the same behavior as unquantized types so we can just pass them through to XLA (which converts them to unquantized types). They're supposed to be used with special ops, none of which are currently implemented by XLA. Casting (without quantization) and basic math works fine though. These do not have a corresponding numpy type, so only tests using TF types will see them. PiperOrigin-RevId: 213781650
Diffstat (limited to 'tensorflow/compiler/tests')
-rw-r--r--tensorflow/compiler/tests/BUILD15
-rw-r--r--tensorflow/compiler/tests/build_defs.bzl4
-rw-r--r--tensorflow/compiler/tests/gather_test.py14
-rw-r--r--tensorflow/compiler/tests/quantized_ops_test.py48
-rw-r--r--tensorflow/compiler/tests/xla_test.py13
5 files changed, 83 insertions, 11 deletions
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index 97ed554171..3cf74fa788 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -978,7 +978,7 @@ tf_xla_py_test(
name = "gather_test",
size = "medium",
srcs = ["gather_test.py"],
- tags = ["noasan"], # times out, http://b/78599043
+ tags = ["optonly"],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
@@ -1198,6 +1198,19 @@ tf_xla_py_test(
)
tf_xla_py_test(
+ name = "quantized_ops_test",
+ size = "small",
+ srcs = ["quantized_ops_test.py"],
+ deps = [
+ ":xla_test",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+tf_xla_py_test(
name = "xla_ops_test",
size = "medium",
srcs = ["xla_ops_test.py"],
diff --git a/tensorflow/compiler/tests/build_defs.bzl b/tensorflow/compiler/tests/build_defs.bzl
index b8583c9bdb..1d3979b21b 100644
--- a/tensorflow/compiler/tests/build_defs.bzl
+++ b/tensorflow/compiler/tests/build_defs.bzl
@@ -62,12 +62,12 @@ def tf_xla_py_test(
if backend == "cpu":
backend_args += [
"--test_device=XLA_CPU",
- "--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_UINT8,DT_INT8,DT_INT32,DT_INT64,DT_BOOL,DT_COMPLEX64",
+ "--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_UINT8,DT_QUINT8,DT_INT8,DT_QINT8,DT_INT32,DT_QINT32,DT_INT64,DT_BOOL,DT_COMPLEX64",
]
elif backend == "gpu":
backend_args += [
"--test_device=XLA_GPU",
- "--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_UINT8,DT_INT8,DT_INT32,DT_INT64,DT_BOOL,DT_COMPLEX64,DT_BFLOAT16",
+ "--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_UINT8,DT_QUINT8,DT_INT8,DT_QINT8,DT_INT32,DT_QINT32,DT_INT64,DT_BOOL,DT_COMPLEX64,DT_BFLOAT16",
]
backend_tags += tf_cuda_tests_tags()
elif backend in plugins:
diff --git a/tensorflow/compiler/tests/gather_test.py b/tensorflow/compiler/tests/gather_test.py
index 089d95daab..a38e1edafe 100644
--- a/tensorflow/compiler/tests/gather_test.py
+++ b/tensorflow/compiler/tests/gather_test.py
@@ -51,7 +51,7 @@ class GatherTest(xla_test.XLATestCase):
indices_tf = constant_op.constant(indices)
gather_t = array_ops.gather(params, indices_tf)
gather_val = session.run(gather_t, feed_dict={params: params_np})
- np_val = params_np[indices]
+ np_val = constant_op.constant(params_np[indices])
self.assertAllEqual(np_val, gather_val)
def testScalar2D(self):
@@ -65,7 +65,8 @@ class GatherTest(xla_test.XLATestCase):
indices = constant_op.constant(2)
gather_t = array_ops.gather(params, indices, axis=axis)
gather_val = session.run(gather_t, feed_dict={params: params_np})
- expected = np.take(params_np, 2, axis=axis)
+ expected = constant_op.constant(
+ np.take(params_np, 2, axis=axis), dtype)
self.assertAllEqual(expected, gather_val)
def testSimpleTwoD32(self):
@@ -80,7 +81,8 @@ class GatherTest(xla_test.XLATestCase):
indices = constant_op.constant([0, 1, 0, 2])
gather_t = array_ops.gather(params, indices, axis=axis)
gather_val = session.run(gather_t, feed_dict={params: params_np})
- expected = np.take(params_np, [0, 1, 0, 2], axis=axis)
+ expected = constant_op.constant(
+ np.take(params_np, [0, 1, 0, 2], axis=axis), dtype)
self.assertAllEqual(expected, gather_val)
def testSimpleTwoD32_Int64Indices(self):
@@ -103,7 +105,8 @@ class GatherTest(xla_test.XLATestCase):
params: params_np,
indices: indices_np
})
- expected = np.take(params_np, [0, 1, 0, 2], axis=axis)
+ expected = constant_op.constant(
+ np.take(params_np, [0, 1, 0, 2], axis=axis), dtype)
self.assertAllEqual(expected, gather_val)
def testHigherRank(self):
@@ -119,7 +122,8 @@ class GatherTest(xla_test.XLATestCase):
tf_indices = constant_op.constant(indices, dtype=dtypes.int32)
gather = array_ops.gather(tf_params, tf_indices, axis=axis)
gather_value = sess.run(gather, feed_dict={tf_params: params})
- gather_np = np.take(params, indices, axis=axis)
+ gather_np = constant_op.constant(
+ np.take(params, indices, axis=axis), dtype)
self.assertAllEqual(gather_np, gather_value)
def testIndicesWithDifferentDimensions(self):
diff --git a/tensorflow/compiler/tests/quantized_ops_test.py b/tensorflow/compiler/tests/quantized_ops_test.py
new file mode 100644
index 0000000000..80c338513b
--- /dev/null
+++ b/tensorflow/compiler/tests/quantized_ops_test.py
@@ -0,0 +1,48 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for quantized operations."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.compiler.tests import xla_test
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import googletest
+
+
+class QuantizedOpsTest(xla_test.XLATestCase):
+
+ # Verify that quantized types can be clustered by XLA.
+ def testQuantizedTypeRoundtrip(self):
+ with self.cached_session() as session:
+ for dtype in self.quantized_tf_types:
+ in_values = np.array([1, 2, 3, 4, 5, 6])
+ expected = [[1, 2], [3, 4], [5, 6]]
+ with self.test_scope():
+ p = array_ops.placeholder(dtype=dtypes.int32)
+ x = math_ops.cast(p, dtype)
+ x = array_ops.reshape(x, [3, 2])
+
+ value = session.run(x, {p: in_values})
+ self.assertAllEqual(value, expected)
+
+
+if __name__ == "__main__":
+ googletest.main()
diff --git a/tensorflow/compiler/tests/xla_test.py b/tensorflow/compiler/tests/xla_test.py
index df5c81243a..98a41981cf 100644
--- a/tensorflow/compiler/tests/xla_test.py
+++ b/tensorflow/compiler/tests/xla_test.py
@@ -97,9 +97,16 @@ class XLATestCase(test.TestCase):
])
self._numeric_tf_types = set(
self.int_tf_types | self._float_tf_types | self.complex_tf_types)
-
- self._all_types = set(
- [dtype.as_numpy_dtype for dtype in self._all_tf_types])
+ self.quantized_tf_types = set(
+ dtype for dtype in self._all_tf_types if dtype.is_quantized)
+
+ # Quantized types don't have a numpy equivalent, include them in
+ # all_tf_types but not in all_types.
+ # TODO(b/115960798): Parametrize tests on TF types instead of numpy types
+ # and remove all_types.
+ self._all_types = set(dtype.as_numpy_dtype
+ for dtype in self._all_tf_types
+ if not dtype.is_quantized)
self._int_types = set([dtype.as_numpy_dtype for dtype in self.int_tf_types])
self.signed_int_types = set(dtype.as_numpy_dtype
for dtype in self.int_tf_types