From c133ef468b702c728dc6b74047129eb742fff5c5 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 23 Aug 2018 14:13:12 -0700 Subject: [TF:XLA] Add TensorFlow operators that wrap most HLO operators. PiperOrigin-RevId: 209997425 --- tensorflow/compiler/tests/BUILD | 16 ++ tensorflow/compiler/tests/xla_ops_test.py | 301 ++++++++++++++++++++++++++++++ 2 files changed, 317 insertions(+) create mode 100644 tensorflow/compiler/tests/xla_ops_test.py (limited to 'tensorflow/compiler/tests') diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 235bef07b3..94e08b6efe 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -1191,3 +1191,19 @@ tf_xla_py_test( "//tensorflow/python:platform_test", ], ) + +tf_xla_py_test( + name = "xla_ops_test", + size = "small", + srcs = ["xla_ops_test.py"], + disabled_backends = ["cpu_ondemand"], + deps = [ + ":xla_test", + "//tensorflow/compiler/tf2xla/python:xla", + "//tensorflow/python:array_ops", + "//tensorflow/python:errors", + "//tensorflow/python:framework", + "//tensorflow/python:platform_test", + "@absl_py//absl/testing:parameterized", + ], +) diff --git a/tensorflow/compiler/tests/xla_ops_test.py b/tensorflow/compiler/tests/xla_ops_test.py new file mode 100644 index 0000000000..b2f026df6c --- /dev/null +++ b/tensorflow/compiler/tests/xla_ops_test.py @@ -0,0 +1,301 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for XLA op wrappers.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import numpy as np + +from tensorflow.compiler.tests import xla_test +from tensorflow.compiler.tf2xla.python import xla +from tensorflow.compiler.xla import xla_data_pb2 +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import function +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import googletest + + +class XlaOpsTest(xla_test.XLATestCase, parameterized.TestCase): + + def _assertOpOutputMatchesExpected(self, op, args, expected, + equality_fn=None): + with self.test_session() as session: + with self.test_scope(): + placeholders = [ + array_ops.placeholder(dtypes.as_dtype(arg.dtype), arg.shape) + for arg in args + ] + feeds = {placeholders[i]: args[i] for i in range(0, len(args))} + output = op(*placeholders) + result = session.run(output, feeds) + if not equality_fn: + equality_fn = self.assertAllClose + equality_fn(result, expected, rtol=1e-3) + + def testAdd(self): + for dtype in self.numeric_types: + self._assertOpOutputMatchesExpected( + xla.add, + args=(np.array([1, 2, 3], dtype=dtype), + np.array([4, 5, 6], dtype=dtype)), + expected=np.array([5, 7, 9], dtype=dtype)) + + self._assertOpOutputMatchesExpected( + lambda x, y: xla.add(x, y, broadcast_dims=(0,)), + args=(np.array([[1, 2], [3, 4]], dtype=dtype), + np.array([7, 11], dtype=dtype)), + expected=np.array([[8, 9], [14, 15]], dtype=dtype)) + + self._assertOpOutputMatchesExpected( + lambda x, y: xla.add(x, y, broadcast_dims=(1,)), + args=(np.array([[1, 2], [3, 4]], dtype=dtype), + np.array([7, 11], dtype=dtype)), + expected=np.array([[8, 13], [10, 15]], dtype=dtype)) + + def testBroadcast(self): + for dtype in self.numeric_types: + v = np.arange(4, dtype=np.int32).astype(dtype).reshape([2, 2]) + self._assertOpOutputMatchesExpected( + lambda x: xla.broadcast(x, (7, 42)), + args=(v,), + expected=np.tile(v, (7, 42, 1, 1))) + + def testShiftRightLogical(self): + self._assertOpOutputMatchesExpected( + xla.shift_right_logical, + args=(np.array([-1, 16], dtype=np.int32), np.int32(4)), + expected=np.array([0x0FFFFFFF, 1], dtype=np.int32)) + + self._assertOpOutputMatchesExpected( + xla.shift_right_logical, + args=(np.array([0xFFFFFFFF, 16], dtype=np.uint32), np.uint32(4)), + expected=np.array([0x0FFFFFFF, 1], dtype=np.uint32)) + + def testShiftRightArithmetic(self): + self._assertOpOutputMatchesExpected( + xla.shift_right_arithmetic, + args=(np.array([-1, 16], dtype=np.int32), np.int32(4)), + expected=np.array([-1, 1], dtype=np.int32)) + + self._assertOpOutputMatchesExpected( + xla.shift_right_arithmetic, + args=(np.array([0xFFFFFFFF, 16], dtype=np.uint32), np.uint32(4)), + expected=np.array([0xFFFFFFFF, 1], dtype=np.uint32)) + + PRECISION_VALUES = (None, xla_data_pb2.PrecisionConfigProto.DEFAULT, + xla_data_pb2.PrecisionConfigProto.HIGH, + xla_data_pb2.PrecisionConfigProto.HIGHEST) + + @parameterized.parameters(*PRECISION_VALUES) + def testConv(self, precision): + for dtype in set(self.float_types).intersection( + set([dtypes.bfloat16.as_numpy_dtype, np.float32])): + + def conv_1d_fn(lhs, rhs): + dnums = xla_data_pb2.ConvolutionDimensionNumbers() + num_spatial_dims = 1 + dnums.input_batch_dimension = 0 + dnums.input_feature_dimension = 1 + dnums.output_batch_dimension = 0 + dnums.output_feature_dimension = 1 + dnums.kernel_output_feature_dimension = 0 + dnums.kernel_input_feature_dimension = 1 + dnums.input_spatial_dimensions.extend(range(2, 2 + num_spatial_dims)) + dnums.kernel_spatial_dimensions.extend(range(2, 2 + num_spatial_dims)) + dnums.output_spatial_dimensions.extend(range(2, 2 + num_spatial_dims)) + precision_config = None + if precision: + precision_config = xla_data_pb2.PrecisionConfigProto() + precision_config.operand_precision.extend([precision, precision]) + return xla.conv( + lhs, + rhs, + window_strides=(1,), + padding=((2, 1),), + lhs_dilation=(1,), + rhs_dilation=(2,), + dimension_numbers=dnums) + + self._assertOpOutputMatchesExpected( + conv_1d_fn, + args=( + np.array([[[3, 4, 5, 6]]], dtype=dtype), + np.array([[[-2, -3]]], dtype=dtype), + ), + expected=np.array([[[-9, -12, -21, -26, -10]]], dtype=dtype)) + + @parameterized.parameters(*PRECISION_VALUES) + def testDotGeneral(self, precision): + for dtype in self.float_types: + + def dot_fn(lhs, rhs): + dnums = xla_data_pb2.DotDimensionNumbers() + dnums.lhs_contracting_dimensions.append(2) + dnums.rhs_contracting_dimensions.append(1) + dnums.lhs_batch_dimensions.append(0) + dnums.rhs_batch_dimensions.append(0) + precision_config = None + if precision: + precision_config = xla_data_pb2.PrecisionConfigProto() + precision_config.operand_precision.extend([precision, precision]) + return xla.dot_general( + lhs, + rhs, + dimension_numbers=dnums, + precision_config=precision_config) + + lhs = np.array( + [ + [[1, 2], [3, 4]], + [[5, 6], [7, 8]], + ], dtype=dtype) + rhs = np.array( + [ + [[1, 2, 3], [4, 5, 6]], + [[7, 8, 9], [10, 11, 12]], + ], dtype=dtype) + self._assertOpOutputMatchesExpected( + dot_fn, + args=(lhs, rhs), + expected=np.array( + [ + [[9, 12, 15], [19, 26, 33]], + [[95, 106, 117], [129, 144, 159]], + ], + dtype=dtype)) + + def testNeg(self): + for dtype in self.numeric_types: + self._assertOpOutputMatchesExpected( + xla.neg, + args=(np.array([1, 2, 3], dtype=dtype),), + expected=np.array([-1, -2, -3], dtype=dtype)) + + def testPad(self): + for dtype in self.numeric_types: + + def pad_fn(x): + return xla.pad( + x, + padding_value=7, + padding_low=[2, 1], + padding_high=[1, 2], + padding_interior=[1, 0]) + + self._assertOpOutputMatchesExpected( + pad_fn, + args=(np.arange(4, dtype=np.int32).astype(dtype).reshape([2, 2]),), + expected=np.array( + [[7, 7, 7, 7, 7], [7, 7, 7, 7, 7], [7, 0, 1, 7, 7], + [7, 7, 7, 7, 7], [7, 2, 3, 7, 7], [7, 7, 7, 7, 7]], + dtype=dtype)) + + def testReduce(self): + for dtype in set(self.numeric_types).intersection( + set([dtypes.bfloat16.as_numpy_dtype, np.float32])): + + @function.Defun(dtype, dtype) + def sum_reducer(x, y): + return x + y + + def sum_reduction(dims): + + def fn(x): + return xla.reduce( + x, init_value=0, dimensions_to_reduce=dims, reducer=sum_reducer) + + return fn + + self._assertOpOutputMatchesExpected( + sum_reduction(dims=[]), + args=(np.arange(12, dtype=np.int32).astype(dtype).reshape([3, 4]),), + expected=np.arange(12, dtype=np.int32).astype(dtype).reshape([3, 4])) + self._assertOpOutputMatchesExpected( + sum_reduction(dims=[0]), + args=(np.arange(12, dtype=np.int32).astype(dtype).reshape([3, 4]),), + expected=np.array([12, 15, 18, 21], dtype=dtype)) + self._assertOpOutputMatchesExpected( + sum_reduction(dims=[1]), + args=(np.arange(12, dtype=np.int32).astype(dtype).reshape([3, 4]),), + expected=np.array([6, 22, 38], dtype=dtype)) + self._assertOpOutputMatchesExpected( + sum_reduction(dims=[0, 1]), + args=(np.arange(12, dtype=np.int32).astype(dtype).reshape([3, 4]),), + expected=dtype(66)) + + @function.Defun(dtype, dtype) + def mul_reducer(x, y): + return x * y + + def mul_reduction(dims): + + def fn(x): + return xla.reduce( + x, init_value=1, dimensions_to_reduce=dims, reducer=mul_reducer) + + return fn + + self._assertOpOutputMatchesExpected( + mul_reduction(dims=[0]), + args=(np.arange(12, dtype=np.int32).astype(dtype).reshape([3, 4]),), + expected=np.array([0, 45, 120, 231], dtype=dtype)) + + def testSelectAndScatter(self): + for dtype in set(self.numeric_types).intersection( + set([dtypes.bfloat16.as_numpy_dtype, np.float32])): + + @function.Defun(dtype, dtype) + def add_scatter(x, y): + return x + y + + @function.Defun(dtype, dtype) + def ge_select(x, y): + return x >= y + + def test_fn(operand, source): + return xla.select_and_scatter( + operand, + window_dimensions=[2, 3, 1, 1], + window_strides=[2, 2, 1, 1], + padding=[[0, 0]] * 4, + source=source, + init_value=0, + select=ge_select, + scatter=add_scatter) + + self._assertOpOutputMatchesExpected( + test_fn, + args=(np.array( + [[7, 2, 5, 3, 8], [3, 8, 9, 3, 4], [1, 5, 7, 5, 6], + [0, 6, 2, 10, 2]], + dtype=dtype).reshape((4, 5, 1, 1)), + np.array([[2, 6], [3, 1]], dtype=dtype).reshape((2, 2, 1, 1))), + expected=np.array( + [[0, 0, 0, 0, 0], [0, 0, 8, 0, 0], [0, 0, 3, 0, 0], + [0, 0, 0, 1, 0]], + dtype=dtype).reshape((4, 5, 1, 1))) + + def testTranspose(self): + for dtype in self.numeric_types: + v = np.arange(4, dtype=np.int32).astype(dtype).reshape([2, 2]) + self._assertOpOutputMatchesExpected( + lambda x: xla.transpose(x, [1, 0]), args=(v,), expected=v.T) + + +if __name__ == '__main__': + googletest.main() -- cgit v1.2.3