diff options
author | 2018-08-23 14:13:12 -0700 | |
---|---|---|
committer | 2018-08-23 14:21:59 -0700 | |
commit | c133ef468b702c728dc6b74047129eb742fff5c5 (patch) | |
tree | d6fa6ee9f3c9f6750496b1e4cb3a180a03d0130f /tensorflow | |
parent | b91f904112914c7ca89f4d3c2839bed258776e78 (diff) |
[TF:XLA] Add TensorFlow operators that wrap most HLO operators.
PiperOrigin-RevId: 209997425
Diffstat (limited to 'tensorflow')
-rw-r--r-- | tensorflow/compiler/tests/BUILD | 16 | ||||
-rw-r--r-- | tensorflow/compiler/tests/xla_ops_test.py | 301 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/BUILD | 7 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc | 92 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/xla_broadcast_helper_op.cc | 115 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc | 101 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc | 65 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/xla_pad_op.cc | 105 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/xla_reduce_op.cc | 102 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/xla_select_and_scatter_op.cc | 147 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/ops/BUILD | 2 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/ops/xla_ops.cc | 192 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/python/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/python/xla.py | 336 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/xla_op_kernel.cc | 37 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/xla_op_kernel.h | 4 |
16 files changed, 1547 insertions, 76 deletions
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 235bef07b3..94e08b6efe 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -1191,3 +1191,19 @@ tf_xla_py_test( "//tensorflow/python:platform_test", ], ) + +tf_xla_py_test( + name = "xla_ops_test", + size = "small", + srcs = ["xla_ops_test.py"], + disabled_backends = ["cpu_ondemand"], + deps = [ + ":xla_test", + "//tensorflow/compiler/tf2xla/python:xla", + "//tensorflow/python:array_ops", + "//tensorflow/python:errors", + "//tensorflow/python:framework", + "//tensorflow/python:platform_test", + "@absl_py//absl/testing:parameterized", + ], +) diff --git a/tensorflow/compiler/tests/xla_ops_test.py b/tensorflow/compiler/tests/xla_ops_test.py new file mode 100644 index 0000000000..b2f026df6c --- /dev/null +++ b/tensorflow/compiler/tests/xla_ops_test.py @@ -0,0 +1,301 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for XLA op wrappers.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import numpy as np + +from tensorflow.compiler.tests import xla_test +from tensorflow.compiler.tf2xla.python import xla +from tensorflow.compiler.xla import xla_data_pb2 +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import function +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import googletest + + +class XlaOpsTest(xla_test.XLATestCase, parameterized.TestCase): + + def _assertOpOutputMatchesExpected(self, op, args, expected, + equality_fn=None): + with self.test_session() as session: + with self.test_scope(): + placeholders = [ + array_ops.placeholder(dtypes.as_dtype(arg.dtype), arg.shape) + for arg in args + ] + feeds = {placeholders[i]: args[i] for i in range(0, len(args))} + output = op(*placeholders) + result = session.run(output, feeds) + if not equality_fn: + equality_fn = self.assertAllClose + equality_fn(result, expected, rtol=1e-3) + + def testAdd(self): + for dtype in self.numeric_types: + self._assertOpOutputMatchesExpected( + xla.add, + args=(np.array([1, 2, 3], dtype=dtype), + np.array([4, 5, 6], dtype=dtype)), + expected=np.array([5, 7, 9], dtype=dtype)) + + self._assertOpOutputMatchesExpected( + lambda x, y: xla.add(x, y, broadcast_dims=(0,)), + args=(np.array([[1, 2], [3, 4]], dtype=dtype), + np.array([7, 11], dtype=dtype)), + expected=np.array([[8, 9], [14, 15]], dtype=dtype)) + + self._assertOpOutputMatchesExpected( + lambda x, y: xla.add(x, y, broadcast_dims=(1,)), + args=(np.array([[1, 2], [3, 4]], dtype=dtype), + np.array([7, 11], dtype=dtype)), + expected=np.array([[8, 13], [10, 15]], dtype=dtype)) + + def testBroadcast(self): + for dtype in self.numeric_types: + v = np.arange(4, dtype=np.int32).astype(dtype).reshape([2, 2]) + self._assertOpOutputMatchesExpected( + lambda x: xla.broadcast(x, (7, 42)), + args=(v,), + expected=np.tile(v, (7, 42, 1, 1))) + + def testShiftRightLogical(self): + self._assertOpOutputMatchesExpected( + xla.shift_right_logical, + args=(np.array([-1, 16], dtype=np.int32), np.int32(4)), + expected=np.array([0x0FFFFFFF, 1], dtype=np.int32)) + + self._assertOpOutputMatchesExpected( + xla.shift_right_logical, + args=(np.array([0xFFFFFFFF, 16], dtype=np.uint32), np.uint32(4)), + expected=np.array([0x0FFFFFFF, 1], dtype=np.uint32)) + + def testShiftRightArithmetic(self): + self._assertOpOutputMatchesExpected( + xla.shift_right_arithmetic, + args=(np.array([-1, 16], dtype=np.int32), np.int32(4)), + expected=np.array([-1, 1], dtype=np.int32)) + + self._assertOpOutputMatchesExpected( + xla.shift_right_arithmetic, + args=(np.array([0xFFFFFFFF, 16], dtype=np.uint32), np.uint32(4)), + expected=np.array([0xFFFFFFFF, 1], dtype=np.uint32)) + + PRECISION_VALUES = (None, xla_data_pb2.PrecisionConfigProto.DEFAULT, + xla_data_pb2.PrecisionConfigProto.HIGH, + xla_data_pb2.PrecisionConfigProto.HIGHEST) + + @parameterized.parameters(*PRECISION_VALUES) + def testConv(self, precision): + for dtype in set(self.float_types).intersection( + set([dtypes.bfloat16.as_numpy_dtype, np.float32])): + + def conv_1d_fn(lhs, rhs): + dnums = xla_data_pb2.ConvolutionDimensionNumbers() + num_spatial_dims = 1 + dnums.input_batch_dimension = 0 + dnums.input_feature_dimension = 1 + dnums.output_batch_dimension = 0 + dnums.output_feature_dimension = 1 + dnums.kernel_output_feature_dimension = 0 + dnums.kernel_input_feature_dimension = 1 + dnums.input_spatial_dimensions.extend(range(2, 2 + num_spatial_dims)) + dnums.kernel_spatial_dimensions.extend(range(2, 2 + num_spatial_dims)) + dnums.output_spatial_dimensions.extend(range(2, 2 + num_spatial_dims)) + precision_config = None + if precision: + precision_config = xla_data_pb2.PrecisionConfigProto() + precision_config.operand_precision.extend([precision, precision]) + return xla.conv( + lhs, + rhs, + window_strides=(1,), + padding=((2, 1),), + lhs_dilation=(1,), + rhs_dilation=(2,), + dimension_numbers=dnums) + + self._assertOpOutputMatchesExpected( + conv_1d_fn, + args=( + np.array([[[3, 4, 5, 6]]], dtype=dtype), + np.array([[[-2, -3]]], dtype=dtype), + ), + expected=np.array([[[-9, -12, -21, -26, -10]]], dtype=dtype)) + + @parameterized.parameters(*PRECISION_VALUES) + def testDotGeneral(self, precision): + for dtype in self.float_types: + + def dot_fn(lhs, rhs): + dnums = xla_data_pb2.DotDimensionNumbers() + dnums.lhs_contracting_dimensions.append(2) + dnums.rhs_contracting_dimensions.append(1) + dnums.lhs_batch_dimensions.append(0) + dnums.rhs_batch_dimensions.append(0) + precision_config = None + if precision: + precision_config = xla_data_pb2.PrecisionConfigProto() + precision_config.operand_precision.extend([precision, precision]) + return xla.dot_general( + lhs, + rhs, + dimension_numbers=dnums, + precision_config=precision_config) + + lhs = np.array( + [ + [[1, 2], [3, 4]], + [[5, 6], [7, 8]], + ], dtype=dtype) + rhs = np.array( + [ + [[1, 2, 3], [4, 5, 6]], + [[7, 8, 9], [10, 11, 12]], + ], dtype=dtype) + self._assertOpOutputMatchesExpected( + dot_fn, + args=(lhs, rhs), + expected=np.array( + [ + [[9, 12, 15], [19, 26, 33]], + [[95, 106, 117], [129, 144, 159]], + ], + dtype=dtype)) + + def testNeg(self): + for dtype in self.numeric_types: + self._assertOpOutputMatchesExpected( + xla.neg, + args=(np.array([1, 2, 3], dtype=dtype),), + expected=np.array([-1, -2, -3], dtype=dtype)) + + def testPad(self): + for dtype in self.numeric_types: + + def pad_fn(x): + return xla.pad( + x, + padding_value=7, + padding_low=[2, 1], + padding_high=[1, 2], + padding_interior=[1, 0]) + + self._assertOpOutputMatchesExpected( + pad_fn, + args=(np.arange(4, dtype=np.int32).astype(dtype).reshape([2, 2]),), + expected=np.array( + [[7, 7, 7, 7, 7], [7, 7, 7, 7, 7], [7, 0, 1, 7, 7], + [7, 7, 7, 7, 7], [7, 2, 3, 7, 7], [7, 7, 7, 7, 7]], + dtype=dtype)) + + def testReduce(self): + for dtype in set(self.numeric_types).intersection( + set([dtypes.bfloat16.as_numpy_dtype, np.float32])): + + @function.Defun(dtype, dtype) + def sum_reducer(x, y): + return x + y + + def sum_reduction(dims): + + def fn(x): + return xla.reduce( + x, init_value=0, dimensions_to_reduce=dims, reducer=sum_reducer) + + return fn + + self._assertOpOutputMatchesExpected( + sum_reduction(dims=[]), + args=(np.arange(12, dtype=np.int32).astype(dtype).reshape([3, 4]),), + expected=np.arange(12, dtype=np.int32).astype(dtype).reshape([3, 4])) + self._assertOpOutputMatchesExpected( + sum_reduction(dims=[0]), + args=(np.arange(12, dtype=np.int32).astype(dtype).reshape([3, 4]),), + expected=np.array([12, 15, 18, 21], dtype=dtype)) + self._assertOpOutputMatchesExpected( + sum_reduction(dims=[1]), + args=(np.arange(12, dtype=np.int32).astype(dtype).reshape([3, 4]),), + expected=np.array([6, 22, 38], dtype=dtype)) + self._assertOpOutputMatchesExpected( + sum_reduction(dims=[0, 1]), + args=(np.arange(12, dtype=np.int32).astype(dtype).reshape([3, 4]),), + expected=dtype(66)) + + @function.Defun(dtype, dtype) + def mul_reducer(x, y): + return x * y + + def mul_reduction(dims): + + def fn(x): + return xla.reduce( + x, init_value=1, dimensions_to_reduce=dims, reducer=mul_reducer) + + return fn + + self._assertOpOutputMatchesExpected( + mul_reduction(dims=[0]), + args=(np.arange(12, dtype=np.int32).astype(dtype).reshape([3, 4]),), + expected=np.array([0, 45, 120, 231], dtype=dtype)) + + def testSelectAndScatter(self): + for dtype in set(self.numeric_types).intersection( + set([dtypes.bfloat16.as_numpy_dtype, np.float32])): + + @function.Defun(dtype, dtype) + def add_scatter(x, y): + return x + y + + @function.Defun(dtype, dtype) + def ge_select(x, y): + return x >= y + + def test_fn(operand, source): + return xla.select_and_scatter( + operand, + window_dimensions=[2, 3, 1, 1], + window_strides=[2, 2, 1, 1], + padding=[[0, 0]] * 4, + source=source, + init_value=0, + select=ge_select, + scatter=add_scatter) + + self._assertOpOutputMatchesExpected( + test_fn, + args=(np.array( + [[7, 2, 5, 3, 8], [3, 8, 9, 3, 4], [1, 5, 7, 5, 6], + [0, 6, 2, 10, 2]], + dtype=dtype).reshape((4, 5, 1, 1)), + np.array([[2, 6], [3, 1]], dtype=dtype).reshape((2, 2, 1, 1))), + expected=np.array( + [[0, 0, 0, 0, 0], [0, 0, 8, 0, 0], [0, 0, 3, 0, 0], + [0, 0, 0, 1, 0]], + dtype=dtype).reshape((4, 5, 1, 1))) + + def testTranspose(self): + for dtype in self.numeric_types: + v = np.arange(4, dtype=np.int32).astype(dtype).reshape([2, 2]) + self._assertOpOutputMatchesExpected( + lambda x: xla.transpose(x, [1, 0]), args=(v,), expected=v.T) + + +if __name__ == '__main__': + googletest.main() diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index 8bbcff5f58..c1438f893f 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -101,6 +101,12 @@ tf_kernel_library( "unary_ops.cc", "unpack_op.cc", "variable_ops.cc", + "xla_broadcast_helper_op.cc", + "xla_conv_op.cc", + "xla_dot_op.cc", + "xla_pad_op.cc", + "xla_reduce_op.cc", + "xla_select_and_scatter_op.cc", ], hdrs = [ "index_ops.h", @@ -110,6 +116,7 @@ tf_kernel_library( ":if_op", ":while_op", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/lib:batch_dot", diff --git a/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc b/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc index b11a4ce36d..8102faad28 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc @@ -32,41 +32,30 @@ class ReduceWindowOp : public XlaOpKernel { explicit ReduceWindowOp(OpKernelConstruction* context) : XlaOpKernel(context) { OP_REQUIRES_OK(context, context->GetAttr("computation", &computation_)); - OP_REQUIRES_OK(context, - context->GetAttr("window_dimensions", &window_dimensions_)); - OP_REQUIRES_OK(context, - context->GetAttr("window_strides", &window_strides_)); - OP_REQUIRES_OK(context, context->GetAttr("padding_low", &padding_low_)); - OP_REQUIRES_OK(context, context->GetAttr("padding_high", &padding_high_)); } void Compile(XlaOpKernelContext* context) override { const TensorShape input_shape = context->InputShape(0); const DataType dtype = context->input_type(0); + std::vector<int64> window_dimensions; + std::vector<int64> window_strides; + OP_REQUIRES_OK(context, context->ConstantInputAsIntVector( + "window_dimensions", &window_dimensions)); + OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("window_strides", + &window_strides)); + const int rank = input_shape.dims(); - OP_REQUIRES(context, rank == window_dimensions_.size(), + OP_REQUIRES(context, rank == window_dimensions.size(), errors::InvalidArgument( "The size of window_dimensions must be equal to the input " "rank (", - window_dimensions_.size(), " vs. ", rank, ")")); - OP_REQUIRES(context, rank == window_strides_.size(), + window_dimensions.size(), " vs. ", rank, ")")); + OP_REQUIRES(context, rank == window_strides.size(), errors::InvalidArgument( "The size of window_strides must be equal to the input " "rank (", - window_strides_.size(), " vs. ", rank, ")")); - OP_REQUIRES(context, rank == padding_low_.size(), - errors::InvalidArgument( - "The size of padding_low must be equal to the input " - "rank (", - padding_low_.size(), " vs. ", rank, ")")); - OP_REQUIRES(context, rank == padding_high_.size(), - errors::InvalidArgument( - "The size of padding_high must be equal to the input " - "rank (", - padding_high_.size(), " vs. ", rank, ")")); - - xla::XlaBuilder* builder = context->builder(); + window_strides.size(), " vs. ", rank, ")")); // Build the reducer function. XlaCompiler::Argument reducer_arg; @@ -78,6 +67,7 @@ class ReduceWindowOp : public XlaOpKernel { compile_options.use_tuple_arg = false; compile_options.resolve_compile_time_constants = false; compile_options.is_entry_computation = false; + compile_options.always_return_tuple = false; XlaCompiler::CompilationResult reducer; OP_REQUIRES_OK(context, context->compiler()->CompileFunction( compile_options, *computation_, @@ -86,51 +76,47 @@ class ReduceWindowOp : public XlaOpKernel { xla::Shape scalar_shape; OP_REQUIRES_OK(context, TensorShapeToXLAShape(dtype, TensorShape(), &scalar_shape)); + OP_REQUIRES( + context, + xla::ShapeUtil::Compatible(reducer.xla_output_shape, scalar_shape), + errors::InvalidArgument( + "Invalid output shape of ReduceWindow reducer. Expected ", + xla::ShapeUtil::HumanString(scalar_shape), " got ", + xla::ShapeUtil::HumanString(reducer.xla_output_shape))); + + const TensorShape padding_shape = context->InputShape("padding"); OP_REQUIRES(context, - xla::ShapeUtil::Compatible( - reducer.xla_output_shape, - xla::ShapeUtil::MakeTupleShape({scalar_shape})), + TensorShapeUtils::IsMatrix(padding_shape) && + padding_shape.dim_size(1) == 2, errors::InvalidArgument( - "Invalid output shape of ReduceWindow reducer. Expected ", - xla::ShapeUtil::HumanString(scalar_shape), " got ", - xla::ShapeUtil::HumanString(reducer.xla_output_shape))); - - // Wraps the reducer in a computation that unpacks the output tuple. - xla::XlaComputation wrapper; - { - std::unique_ptr<xla::XlaBuilder> cb = - builder->CreateSubBuilder("wrapper"); - auto x = xla::Parameter(cb.get(), 0, scalar_shape, "x"); - auto y = xla::Parameter(cb.get(), 1, scalar_shape, "y"); - auto outputs = xla::Call(cb.get(), *reducer.computation, {x, y}); - xla::GetTupleElement(outputs, 0); - xla::StatusOr<xla::XlaComputation> result = cb->Build(); - OP_REQUIRES_OK(context, result.status()); - wrapper = std::move(result.ValueOrDie()); - } - - std::vector<std::pair<int64, int64>> padding(rank); - for (int i = 0; i < rank; ++i) { - padding[i] = {padding_low_[i], padding_high_[i]}; + "padding must be a matrix with minor dimension 2, got ", + padding_shape.DebugString())); + xla::Literal padding_literal; + OP_REQUIRES_OK(context, context->ConstantInputAsInt64Literal( + "padding", &padding_literal)); + std::vector<std::pair<int64, int64>> padding(padding_shape.dim_size(0)); + for (int i = 0; i < padding.size(); ++i) { + padding[i] = {padding_literal.Get<int64>({i, 0}), + padding_literal.Get<int64>({i, 1})}; } xla::XlaOp output = xla::ReduceWindowWithGeneralPadding( - context->Input(0), context->Input(1), wrapper, window_dimensions_, - window_strides_, padding); + context->Input(0), context->Input(1), *reducer.computation, + window_dimensions, window_strides, padding); context->SetOutput(0, output); } private: const NameAttrList* computation_; - std::vector<int64> window_dimensions_; - std::vector<int64> window_strides_; - std::vector<int64> padding_low_; - std::vector<int64> padding_high_; TF_DISALLOW_COPY_AND_ASSIGN(ReduceWindowOp); }; -REGISTER_XLA_OP(Name("XlaReduceWindow"), ReduceWindowOp); +REGISTER_XLA_OP(Name("XlaReduceWindow") + .CompileTimeConstInput("window_dimensions") + .CompileTimeConstInput("window_strides") + .CompileTimeConstInput("padding"), + ReduceWindowOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/xla_broadcast_helper_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_broadcast_helper_op.cc new file mode 100644 index 0000000000..412afeaaad --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/xla_broadcast_helper_op.cc @@ -0,0 +1,115 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "absl/algorithm/container.h" +#include "absl/strings/str_join.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { +namespace { + +class XlaBroadcastHelperOp : public XlaOpKernel { + public: + explicit XlaBroadcastHelperOp(OpKernelConstruction* context) + : XlaOpKernel(context) {} + + void Compile(XlaOpKernelContext* context) override { + xla::XlaOp lhs = context->Input(0); + xla::XlaOp rhs = context->Input(1); + const TensorShape lhs_shape = context->InputShape(0); + const TensorShape rhs_shape = context->InputShape(1); + + const bool broadcast_lhs = lhs_shape.dims() < rhs_shape.dims(); + const TensorShape* min_rank_shape = broadcast_lhs ? &lhs_shape : &rhs_shape; + const TensorShape* max_rank_shape = broadcast_lhs ? &rhs_shape : &lhs_shape; + + std::vector<int64> broadcast_dims; + OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("broadcast_dims", + &broadcast_dims)); + if (broadcast_dims.empty()) { + OP_REQUIRES( + context, + lhs_shape.dims() == rhs_shape.dims() || lhs_shape.dims() == 0 || + rhs_shape.dims() == 0, + errors::InvalidArgument( + "If broadcast_dims is empty, both " + "arguments must have equal rank; " + "argument shapes, or at least one argument must be a scalar: ", + lhs_shape.DebugString(), " and ", rhs_shape.DebugString())); + context->SetOutput(0, lhs); + context->SetOutput(1, rhs); + return; + } + + OP_REQUIRES( + context, broadcast_dims.size() == min_rank_shape->dims(), + errors::InvalidArgument( + "broadcast_dims must have size equal to the smaller argument rank; " + "broadcast_dims: [", + absl::StrJoin(broadcast_dims, ","), "]; argument shapes: ", + lhs_shape.DebugString(), " and ", rhs_shape.DebugString())); + std::vector<int64> sorted_broadcast_dims = broadcast_dims; + absl::c_sort(sorted_broadcast_dims); + std::set<int64> dims_set(broadcast_dims.begin(), broadcast_dims.end()); + OP_REQUIRES(context, + dims_set.size() == broadcast_dims.size() && + broadcast_dims == sorted_broadcast_dims, + errors::InvalidArgument( + "Duplicate or nonmonotonic dimension in broadcast_dims; " + "broadcast_dims: [", + absl::StrJoin(broadcast_dims, ","), "]")); + + std::vector<int64> broadcast_shape(max_rank_shape->dims(), 1LL); + for (int i = 0; i < broadcast_dims.size(); ++i) { + const int dim = broadcast_dims[i]; + OP_REQUIRES( + context, dim >= 0 && dim < broadcast_shape.size(), + errors::InvalidArgument( + "Invalid broadcast dimension (", dim, "); broadcast_dims: [", + absl::StrJoin(broadcast_dims, ","), "]; argument shapes: ", + lhs_shape.DebugString(), " and ", rhs_shape.DebugString())); + broadcast_shape[dim] = min_rank_shape->dim_size(i); + } + xla::PrimitiveType type = context->input_xla_type(0); + xla::Shape broadcast_xla_shape = + xla::ShapeUtil::MakeShape(type, broadcast_shape); + if (broadcast_lhs) { + lhs = xla::BroadcastInDim(lhs, broadcast_xla_shape, broadcast_dims); + } else { + rhs = xla::BroadcastInDim(rhs, broadcast_xla_shape, broadcast_dims); + } + context->SetOutput(0, lhs); + context->SetOutput(1, rhs); + } + + private: + xla::DotDimensionNumbers dnums_; + + TF_DISALLOW_COPY_AND_ASSIGN(XlaBroadcastHelperOp); +}; + +REGISTER_XLA_OP( + Name("XlaBroadcastHelper").CompileTimeConstInput("broadcast_dims"), + XlaBroadcastHelperOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc new file mode 100644 index 0000000000..8848623868 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc @@ -0,0 +1,101 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace { + +class XlaConvOp : public XlaOpKernel { + public: + explicit XlaConvOp(OpKernelConstruction* context) : XlaOpKernel(context) { + string dnums_attr; + OP_REQUIRES_OK(context, context->GetAttr("dimension_numbers", &dnums_attr)); + OP_REQUIRES( + context, dnums_.ParsePartialFromString(dnums_attr), + errors::InvalidArgument("Error parsing convolution dimension numbers")); + string precision_config_attr; + OP_REQUIRES_OK( + context, context->GetAttr("precision_config", &precision_config_attr)); + OP_REQUIRES( + context, + precision_config_.ParsePartialFromString(precision_config_attr), + errors::InvalidArgument("Error parsing convolution dimension numbers")); + } + + void Compile(XlaOpKernelContext* context) override { + const TensorShape lhs_shape = context->InputShape(0); + const TensorShape rhs_shape = context->InputShape(1); + const TensorShape padding_shape = context->InputShape("padding"); + std::vector<int64> window_strides; + std::vector<int64> lhs_dilation; + std::vector<int64> rhs_dilation; + int64 feature_group_count; + OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("window_strides", + &window_strides)); + OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("lhs_dilation", + &lhs_dilation)); + OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("rhs_dilation", + &rhs_dilation)); + OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar( + "feature_group_count", &feature_group_count)); + + OP_REQUIRES(context, + TensorShapeUtils::IsMatrix(padding_shape) && + padding_shape.dim_size(1) == 2, + errors::InvalidArgument( + "padding must be a matrix with minor dimension 2, got ", + padding_shape.DebugString())); + xla::Literal padding_literal; + OP_REQUIRES_OK(context, context->ConstantInputAsInt64Literal( + "padding", &padding_literal)); + std::vector<std::pair<int64, int64>> padding(padding_shape.dim_size(0)); + for (int i = 0; i < padding.size(); ++i) { + padding[i] = {padding_literal.Get<int64>({i, 0}), + padding_literal.Get<int64>({i, 1})}; + } + + // We do only minimal checking, relying on XLA to check the shape + // invariants. + xla::XlaOp output = xla::ConvGeneralDilated( + context->Input(0), context->Input(1), window_strides, padding, + lhs_dilation, rhs_dilation, dnums_, feature_group_count, + &precision_config_); + context->SetOutput(0, output); + } + + private: + xla::ConvolutionDimensionNumbers dnums_; + xla::PrecisionConfigProto precision_config_; + + TF_DISALLOW_COPY_AND_ASSIGN(XlaConvOp); +}; + +REGISTER_XLA_OP(Name("XlaConv") + .CompileTimeConstInput("window_strides") + .CompileTimeConstInput("lhs_dilation") + .CompileTimeConstInput("rhs_dilation") + .CompileTimeConstInput("feature_group_count") + .CompileTimeConstInput("padding"), + XlaConvOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc new file mode 100644 index 0000000000..2fed53e5c0 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc @@ -0,0 +1,65 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace { + +class XlaDotOp : public XlaOpKernel { + public: + explicit XlaDotOp(OpKernelConstruction* context) : XlaOpKernel(context) { + string dnums_attr; + OP_REQUIRES_OK(context, context->GetAttr("dimension_numbers", &dnums_attr)); + OP_REQUIRES( + context, dnums_.ParsePartialFromString(dnums_attr), + errors::InvalidArgument("Error parsing convolution dimension numbers")); + string precision_config_attr; + OP_REQUIRES_OK( + context, context->GetAttr("precision_config", &precision_config_attr)); + OP_REQUIRES( + context, + precision_config_.ParsePartialFromString(precision_config_attr), + errors::InvalidArgument("Error parsing convolution dimension numbers")); + } + + void Compile(XlaOpKernelContext* context) override { + const TensorShape lhs_shape = context->InputShape(0); + const TensorShape rhs_shape = context->InputShape(1); + + // We do only minimal checking, relying on XLA to check the shape + // invariants. + xla::XlaOp output = xla::DotGeneral(context->Input(0), context->Input(1), + dnums_, &precision_config_); + context->SetOutput(0, output); + } + + private: + xla::DotDimensionNumbers dnums_; + xla::PrecisionConfigProto precision_config_; + + TF_DISALLOW_COPY_AND_ASSIGN(XlaDotOp); +}; + +REGISTER_XLA_OP(Name("XlaDot"), XlaDotOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/xla_pad_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_pad_op.cc new file mode 100644 index 0000000000..59502d83c7 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/xla_pad_op.cc @@ -0,0 +1,105 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "absl/algorithm/container.h" +#include "absl/strings/str_join.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace { + +class XlaPadOp : public XlaOpKernel { + public: + explicit XlaPadOp(OpKernelConstruction* context) : XlaOpKernel(context) {} + + void Compile(XlaOpKernelContext* context) override { + const TensorShape input_shape = context->InputShape("input"); + const TensorShape padding_value_shape = + context->InputShape("padding_value"); + + std::vector<int64> padding_low; + std::vector<int64> padding_high; + std::vector<int64> padding_interior; + OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("padding_low", + &padding_low)); + OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("padding_high", + &padding_high)); + OP_REQUIRES_OK(context, context->ConstantInputAsIntVector( + "padding_interior", &padding_interior)); + + OP_REQUIRES(context, TensorShapeUtils::IsScalar(padding_value_shape), + errors::InvalidArgument("padding_value must be a scalar")); + const int rank = input_shape.dims(); + OP_REQUIRES(context, rank == padding_low.size(), + errors::InvalidArgument( + "The size of padding_low must be equal to the input " + "rank (", + padding_low.size(), " vs. ", rank, ")")); + OP_REQUIRES(context, rank == padding_high.size(), + errors::InvalidArgument( + "The size of padding_high must be equal to the input " + "rank (", + padding_high.size(), " vs. ", rank, ")")); + OP_REQUIRES(context, rank == padding_interior.size(), + errors::InvalidArgument( + "The size of padding_interior must be equal to the input " + "rank (", + padding_interior.size(), " vs. ", rank, ")")); + + auto non_negative = [](int64 x) { return x >= 0; }; + OP_REQUIRES( + context, absl::c_all_of(padding_low, non_negative), + errors::InvalidArgument("padding_low must be non-negative, got [", + absl::StrJoin(padding_low, ","), "]")); + OP_REQUIRES( + context, absl::c_all_of(padding_high, non_negative), + errors::InvalidArgument("padding_high must be non-negative, got [", + absl::StrJoin(padding_high, ","), "]")); + OP_REQUIRES( + context, absl::c_all_of(padding_interior, non_negative), + errors::InvalidArgument("padding_interior must be non-negative, got [", + absl::StrJoin(padding_interior, ","), "]")); + + xla::PaddingConfig padding_config; + for (int i = 0; i < rank; ++i) { + auto* dim = padding_config.add_dimensions(); + dim->set_edge_padding_low(padding_low[i]); + dim->set_edge_padding_high(padding_high[i]); + dim->set_interior_padding(padding_interior[i]); + } + + xla::XlaOp output = + xla::Pad(context->Input("input"), context->Input("padding_value"), + padding_config); + context->SetOutput(0, output); + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(XlaPadOp); +}; + +REGISTER_XLA_OP(Name("XlaPad") + .CompileTimeConstInput("padding_low") + .CompileTimeConstInput("padding_high") + .CompileTimeConstInput("padding_interior"), + XlaPadOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/xla_reduce_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_reduce_op.cc new file mode 100644 index 0000000000..fc2425f37b --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/xla_reduce_op.cc @@ -0,0 +1,102 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "absl/algorithm/container.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace { + +class XlaReduceOp : public XlaOpKernel { + public: + explicit XlaReduceOp(OpKernelConstruction* context) : XlaOpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("reducer", &reducer_)); + OP_REQUIRES_OK(context, context->GetAttr("dimensions_to_reduce", + &dimensions_to_reduce_)); + std::set<int64> dims_set(dimensions_to_reduce_.begin(), + dimensions_to_reduce_.end()); + OP_REQUIRES( + context, dims_set.size() == dimensions_to_reduce_.size(), + errors::InvalidArgument("Duplicate dimension in dimensions_to_reduce " + "argument to XlaReduce")); + } + + void Compile(XlaOpKernelContext* context) override { + const TensorShape input_shape = context->InputShape("input"); + const TensorShape init_value_shape = context->InputShape("init_value"); + const DataType dtype = context->input_type(0); + + const int rank = input_shape.dims(); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(init_value_shape), + errors::InvalidArgument("init_value must be a scalar")); + + auto dim_in_range = [rank](int64 dim) { return dim >= 0 && dim < rank; }; + OP_REQUIRES(context, + rank >= dimensions_to_reduce_.size() && + absl::c_all_of(dimensions_to_reduce_, dim_in_range), + errors::InvalidArgument( + "Invalid dimensions_to_reduce argument to XlaReduce")); + + // Build the reducer function. + XlaCompiler::Argument reducer_arg; + reducer_arg.kind = XlaCompiler::Argument::kParameter; + reducer_arg.type = dtype; + reducer_arg.shape = TensorShape(); + + XlaCompiler::CompileOptions compile_options; + compile_options.use_tuple_arg = false; + compile_options.always_return_tuple = false; + compile_options.resolve_compile_time_constants = false; + compile_options.is_entry_computation = false; + XlaCompiler::CompilationResult reducer; + OP_REQUIRES_OK(context, context->compiler()->CompileFunction( + compile_options, *reducer_, + {reducer_arg, reducer_arg}, &reducer)); + + xla::Shape scalar_shape; + OP_REQUIRES_OK(context, + TensorShapeToXLAShape(dtype, TensorShape(), &scalar_shape)); + OP_REQUIRES( + context, + xla::ShapeUtil::Compatible(reducer.xla_output_shape, scalar_shape), + errors::InvalidArgument( + "Invalid output shape of XlaReduce reducer. Expected ", + xla::ShapeUtil::HumanString(scalar_shape), " got ", + xla::ShapeUtil::HumanString(reducer.xla_output_shape))); + + xla::XlaOp output = + xla::Reduce(context->Input("input"), context->Input("init_value"), + *reducer.computation, dimensions_to_reduce_); + context->SetOutput(0, output); + } + + private: + const NameAttrList* reducer_; + std::vector<int64> dimensions_to_reduce_; + + TF_DISALLOW_COPY_AND_ASSIGN(XlaReduceOp); +}; + +REGISTER_XLA_OP(Name("XlaReduce"), XlaReduceOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/xla_select_and_scatter_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_select_and_scatter_op.cc new file mode 100644 index 0000000000..089776fcf7 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/xla_select_and_scatter_op.cc @@ -0,0 +1,147 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/kernels/while_op.h" + +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace { + +class XlaSelectAndScatterOp : public XlaOpKernel { + public: + explicit XlaSelectAndScatterOp(OpKernelConstruction* context) + : XlaOpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("select", &select_computation_)); + OP_REQUIRES_OK(context, context->GetAttr("scatter", &scatter_computation_)); + } + + void Compile(XlaOpKernelContext* context) override { + const TensorShape input_shape = context->InputShape(0); + const DataType dtype = context->input_type(0); + + std::vector<int64> window_dimensions; + std::vector<int64> window_strides; + OP_REQUIRES_OK(context, context->ConstantInputAsIntVector( + "window_dimensions", &window_dimensions)); + OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("window_strides", + &window_strides)); + + const int rank = input_shape.dims(); + OP_REQUIRES(context, rank == window_dimensions.size(), + errors::InvalidArgument( + "The size of window_dimensions must be equal to the input " + "rank (", + window_dimensions.size(), " vs. ", rank, ")")); + OP_REQUIRES(context, rank == window_strides.size(), + errors::InvalidArgument( + "The size of window_strides must be equal to the input " + "rank (", + window_strides.size(), " vs. ", rank, ")")); + + XlaCompiler::CompileOptions compile_options; + compile_options.use_tuple_arg = false; + compile_options.resolve_compile_time_constants = false; + compile_options.is_entry_computation = false; + compile_options.always_return_tuple = false; + + // Build the select function. + XlaCompiler::Argument select_arg; + select_arg.kind = XlaCompiler::Argument::kParameter; + select_arg.type = dtype; + select_arg.shape = TensorShape(); + + XlaCompiler::CompilationResult select; + OP_REQUIRES_OK(context, context->compiler()->CompileFunction( + compile_options, *select_computation_, + {select_arg, select_arg}, &select)); + + xla::Shape select_output_shape = xla::ShapeUtil::MakeShape(xla::PRED, {}); + OP_REQUIRES( + context, + xla::ShapeUtil::Compatible(select.xla_output_shape, + select_output_shape), + errors::InvalidArgument( + "Invalid output shape of XlaSelectAndScatter select. Expected ", + xla::ShapeUtil::HumanString(select_output_shape), " got ", + xla::ShapeUtil::HumanString(select.xla_output_shape))); + + // Build the scatter function. + XlaCompiler::Argument scatter_arg; + scatter_arg.kind = XlaCompiler::Argument::kParameter; + scatter_arg.type = dtype; + scatter_arg.shape = TensorShape(); + + XlaCompiler::CompilationResult scatter; + OP_REQUIRES_OK(context, context->compiler()->CompileFunction( + compile_options, *scatter_computation_, + {scatter_arg, scatter_arg}, &scatter)); + + xla::Shape scalar_shape; + OP_REQUIRES_OK(context, + TensorShapeToXLAShape(dtype, TensorShape(), &scalar_shape)); + OP_REQUIRES( + context, + xla::ShapeUtil::Compatible(scatter.xla_output_shape, scalar_shape), + errors::InvalidArgument( + "Invalid output shape of scatter. Expected ", + xla::ShapeUtil::HumanString(scalar_shape), " got ", + xla::ShapeUtil::HumanString(scatter.xla_output_shape))); + + const TensorShape padding_shape = context->InputShape("padding"); + OP_REQUIRES(context, + TensorShapeUtils::IsMatrix(padding_shape) && + padding_shape.dim_size(1) == 2, + errors::InvalidArgument( + "padding must be a matrix with minor dimension 2, got ", + padding_shape.DebugString())); + xla::Literal padding_literal; + OP_REQUIRES_OK(context, context->ConstantInputAsInt64Literal( + "padding", &padding_literal)); + std::vector<std::pair<int64, int64>> padding(padding_shape.dim_size(0)); + for (int i = 0; i < padding.size(); ++i) { + padding[i] = {padding_literal.Get<int64>({i, 0}), + padding_literal.Get<int64>({i, 1})}; + } + + xla::XlaOp output = xla::SelectAndScatterWithGeneralPadding( + context->Input("operand"), *select.computation, window_dimensions, + window_strides, padding, context->Input("source"), + context->Input("init_value"), *scatter.computation); + context->SetOutput(0, output); + } + + private: + const NameAttrList* select_computation_; + const NameAttrList* scatter_computation_; + + TF_DISALLOW_COPY_AND_ASSIGN(XlaSelectAndScatterOp); +}; + +REGISTER_XLA_OP(Name("XlaSelectAndScatter") + .CompileTimeConstInput("window_dimensions") + .CompileTimeConstInput("window_strides") + .CompileTimeConstInput("padding"), + XlaSelectAndScatterOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/ops/BUILD b/tensorflow/compiler/tf2xla/ops/BUILD index ace6fd1d8e..4dce0a2102 100644 --- a/tensorflow/compiler/tf2xla/ops/BUILD +++ b/tensorflow/compiler/tf2xla/ops/BUILD @@ -11,6 +11,8 @@ cc_library( srcs = ["xla_ops.cc"], deps = [ "//tensorflow/core:framework", + "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", ], alwayslink = 1, ) diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc index a59c77f5c3..2cd9ae799f 100644 --- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc +++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc @@ -13,11 +13,97 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "absl/algorithm/container.h" #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/lib/core/errors.h" namespace tensorflow { +namespace { + +// Helper shape function for operators that return an output with the same rank +// as their first input. +Status UnchangedRank(shape_inference::InferenceContext* c) { + if (c->RankKnown(c->input(0))) { + c->set_output(0, c->UnknownShapeOfRank(c->Rank(c->input(0)))); + } else { + c->set_output(0, c->input(0)); + } + return Status::OK(); +} + +REGISTER_OP("XlaBroadcastHelper") + .Input("lhs: T") + .Input("rhs: T") + .Input("broadcast_dims: Tindices") + .Attr("T: numbertype") + .Attr("Tindices: {int32, int64}") + .Output("lhs_output: T") + .Output("rhs_output: T") + .SetShapeFn(shape_inference::UnknownShape) + .Doc(R"doc( +Helper operator for performing XLA-style broadcasts + +Broadcasts `lhs` and `rhs` to the same rank, by adding size 1 dimensions to +whichever of `lhs` and `rhs` has the lower rank, using XLA's broadcasting rules +for binary operators. + +lhs: the LHS input tensor +rhs: the RHS input tensor +broadcast_dims: an XLA-style broadcast dimension specification +lhs_output: the broadcasted LHS tensor +rhs_output: the broadcasted RHS tensor +)doc"); + +REGISTER_OP("XlaConv") + .Input("lhs: T") + .Input("rhs: T") + .Input("window_strides: Tindices") + .Input("padding: Tindices") + .Input("lhs_dilation: Tindices") + .Input("rhs_dilation: Tindices") + .Input("feature_group_count: Tindices") + .Attr("T: numbertype") + .Attr("Tindices: {int32, int64}") + .Attr("dimension_numbers: string") + .Attr("precision_config: string") + .Output("output: T") + .SetShapeFn(UnchangedRank) + .Doc(R"doc( +Wraps the XLA ConvGeneralDilated operator, documented at + https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution +. + +lhs: the input tensor +rhs: the kernel tensor +window_strides: the inter-window strides +padding: the padding to apply at the start and end of each input dimensions +lhs_dilation: dilation to apply between input elements +rhs_dilation: dilation to apply between kernel elements +feature_group_count: number of feature groups for grouped convolution. +dimension_numbers: a serialized xla::ConvolutionDimensionNumbers proto. +precision_config: a serialized xla::PrecisionConfigProto proto. +)doc"); + +REGISTER_OP("XlaDot") + .Input("lhs: T") + .Input("rhs: T") + .Attr("T: numbertype") + .Attr("dimension_numbers: string") + .Attr("precision_config: string") + .Output("output: T") + .SetShapeFn(shape_inference::UnknownShape) + .Doc(R"doc( +Wraps the XLA ConvGeneralDilated operator, documented at + https://www.tensorflow.org/performance/xla/operation_semantics#dotgeneral +. + +lhs: the LHS tensor +rhs: the RHS tensor +dimension_numbers: a serialized xla::DotDimensionNumbers proto. +precision_config: a serialized xla::PrecisionConfigProto proto. +)doc"); REGISTER_OP("XlaDynamicUpdateSlice") .Input("input: T") @@ -73,6 +159,29 @@ else_branch: A function takes 'inputs' and returns a list of tensors. whose types are the same as what then_branch returns. )doc"); +REGISTER_OP("XlaPad") + .Input("input: T") + .Input("padding_value: T") + .Input("padding_low: Tindices") + .Input("padding_high: Tindices") + .Input("padding_interior: Tindices") + .Output("output: T") + .Attr("T: type") + .Attr("Tindices: {int32, int64}") + .SetShapeFn(UnchangedRank) + .Doc(R"doc( +Wraps the XLA Pad operator, documented at + https://www.tensorflow.org/performance/xla/operation_semantics#pad +. + +input: A `Tensor` of type T. +padding_value: A scalar `Tensor` of type T. +padding_low: the padding to apply at the start of each input dimensions +padding_high: the padding to apply at the end of each input dimension. +padding_interior: the padding to apply between each input element. +output: A `Tensor` of type T. +)doc"); + REGISTER_OP("XlaRecv") .Output("tensor: dtype") .Attr("dtype: type") @@ -98,17 +207,58 @@ tensor_name: A string key that identifies the channel. shape: The shape of the tensor. )doc"); +REGISTER_OP("XlaReduce") + .Input("input: T") + .Input("init_value: T") + .Attr("T: numbertype") + .Attr("dimensions_to_reduce: list(int)") + .Attr("reducer: func") + .Output("output: T") + .SetShapeFn([](shape_inference::InferenceContext* c) { + if (c->RankKnown(c->input(0))) { + int rank = c->Rank(c->input(0)); + std::vector<int64> dimensions_to_reduce; + TF_RETURN_IF_ERROR( + c->GetAttr("dimensions_to_reduce", &dimensions_to_reduce)); + std::set<int64> dims_set(dimensions_to_reduce.begin(), + dimensions_to_reduce.end()); + auto dim_in_range = [rank](int64 dim) { + return dim >= 0 && dim < rank; + }; + if (rank < dimensions_to_reduce.size() || + dims_set.size() != dimensions_to_reduce.size() || + !absl::c_all_of(dimensions_to_reduce, dim_in_range)) { + return errors::InvalidArgument( + "Invalid dimensions_to_reduce argument to XlaReduce"); + } + c->set_output( + 0, c->UnknownShapeOfRank(rank - dimensions_to_reduce.size())); + } else { + c->set_output(0, c->input(0)); + } + return Status::OK(); + }) + .Doc(R"doc( +Wraps the XLA Reduce operator, documented at + https://www.tensorflow.org/performance/xla/operation_semantics#reduce . + +input: the input tensor +init_value: a scalar representing the initial value for the reduction +reducer: a reducer function to apply +dimensions_to_reduce: dimension numbers over which to reduce +)doc"); + REGISTER_OP("XlaReduceWindow") .Input("input: T") .Input("init_value: T") + .Input("window_dimensions: Tindices") + .Input("window_strides: Tindices") + .Input("padding: Tindices") .Attr("T: numbertype") + .Attr("Tindices: {int32, int64}") .Attr("computation: func") - .Attr("window_dimensions: list(int)") - .Attr("window_strides: list(int)") - .Attr("padding_low: list(int)") - .Attr("padding_high: list(int)") .Output("output: T") - .SetShapeFn(shape_inference::UnknownShape) + .SetShapeFn(UnchangedRank) .Doc(R"doc( Wraps the XLA ReduceWindow operator, documented at https://www.tensorflow.org/performance/xla/operation_semantics#reducewindow . @@ -118,8 +268,35 @@ init_value: a scalar representing the initial value for the reduction computation: a reducer function to apply window_dimensions: the shape of the window window_strides: the inter-window strides -padding_low: the padding to apply at the start of each input dimensions -padding_high: the padding to apply at the end of each input dimension. +padding: the padding to apply at the start and end of each input dimensions +)doc"); + +REGISTER_OP("XlaSelectAndScatter") + .Input("operand: T") + .Input("window_dimensions: Tindices") + .Input("window_strides: Tindices") + .Input("padding: Tindices") + .Input("source: T") + .Input("init_value: T") + .Attr("T: numbertype") + .Attr("Tindices: {int32, int64}") + .Attr("select: func") + .Attr("scatter: func") + .Output("output: T") + .SetShapeFn(UnchangedRank) + .Doc(R"doc( +Wraps the XLA SelectAndScatter operator, documented at + https://www.tensorflow.org/performance/xla/operation_semantics#selectandscatter +. + +operand: the input tensor +window_dimensions: the shape of the window +window_strides: the inter-window strides +padding: the padding to apply at the start and end of each input dimensions +source: a tensor of values to scatter +init_value: a scalar representing the initial value for the output tensor +select: a selection function to apply +scatter: a scatter function to apply )doc"); REGISTER_OP("XlaSend") @@ -179,4 +356,5 @@ body: A function that takes a list of tensors and returns another list of tensors. Both lists have the same types as specified by T. )doc"); +} // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/python/BUILD b/tensorflow/compiler/tf2xla/python/BUILD index 42b6292f79..69ca394360 100644 --- a/tensorflow/compiler/tf2xla/python/BUILD +++ b/tensorflow/compiler/tf2xla/python/BUILD @@ -28,5 +28,6 @@ py_library( srcs = ["xla.py"], deps = [ "//tensorflow/compiler/tf2xla/ops:gen_xla_ops", + "//tensorflow/compiler/xla:xla_data_proto_py", ], ) diff --git a/tensorflow/compiler/tf2xla/python/xla.py b/tensorflow/compiler/tf2xla/python/xla.py index 2fc47dffb8..3626de375e 100644 --- a/tensorflow/compiler/tf2xla/python/xla.py +++ b/tensorflow/compiler/tf2xla/python/xla.py @@ -15,11 +15,12 @@ """Experimental library that exposes XLA operations directly in TensorFlow. It is sometimes useful to be able to build HLO programs directly from -TensorFlow. This file provides Tensorflow operators that map as closely as -possible to HLO operators. +TensorFlow. This file provides Tensorflow operators that mirror the semantics of +HLO operators as closely as possible. -There is no promise of backward or forward compatibility for operators defined -in this module. +Note: There is no promise of backward or forward compatibility for operators +defined in this module. This is primarily because the underlying HLO operators +do not promise backward or forward compatibility. """ from __future__ import absolute_import @@ -27,11 +28,298 @@ from __future__ import division from __future__ import print_function from tensorflow.compiler.tf2xla.ops import gen_xla_ops +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import bitwise_ops +from tensorflow.python.ops import gen_math_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops + +# TODO(phawkins): provide wrappers for all XLA operators. Currently the missing +# ops include: +# infeed/outfeed (available via tf.contrib.tpu) +# collectives, e.g., cross-replica-sum (available via tf.contrib.tpu) +# conditional +# gather/scatter +# collapse + +# This file reuses builtin names (following XLA's names, so we can call things +# like xla.max), so we capture the builtin versions here. +# pylint: disable=redefined-builtin +_max = max +_min = min +_slice = slice # pylint: disable=invalid-name + +constant = constant_op.constant + +# Unary operators. + +# For most arithmetic operators there is a TensorFlow operator +# that exactly corresponds to each XLA operator. Rather than defining +# XLA-specific variants, we reuse the corresponding TensorFlow operator. +# TODO(phawkins): It would be even better to have TensorFlow operators that 1:1 +# wrap every HLO operator, because that would allow us to be confident that the +# semantics match. + + +def _unary_op(fn): + """Wrapper that restricts `fn` to have the correct signature.""" + + def unary_op_wrapper(x, name=None): + return fn(x, name=name) + + return unary_op_wrapper + + +abs = _unary_op(math_ops.abs) +# TODO(phawkins): implement clz. +conj = _unary_op(math_ops.conj) +cos = _unary_op(math_ops.cos) +ceil = _unary_op(math_ops.ceil) +digamma = _unary_op(math_ops.digamma) +erf = _unary_op(math_ops.erf) +erfc = _unary_op(math_ops.erfc) +# TODO(phawkins): implement erfinv +exp = _unary_op(math_ops.exp) +expm1 = _unary_op(math_ops.expm1) +floor = _unary_op(math_ops.floor) +imag = _unary_op(math_ops.imag) +is_finite = _unary_op(math_ops.is_finite) +lgamma = _unary_op(math_ops.lgamma) +log = _unary_op(math_ops.log) +log1p = _unary_op(math_ops.log1p) +logical_not = _unary_op(math_ops.logical_not) +neg = _unary_op(math_ops.neg) +real = _unary_op(math_ops.real) +# TODO(phawkins): unlike xla::Round, this rounds to even instead of zero for +# numbers halfway between two integers. +round = _unary_op(math_ops.round) +sin = _unary_op(math_ops.sin) +sign = _unary_op(math_ops.sign) +tanh = _unary_op(math_ops.tanh) + +# Binary operators + +# The main difference between TensorFlow and XLA binary ops is the broadcasting +# semantics. TensorFlow uses Numpy-style broadcasting semantics, whereas XLA +# requires an explicit specification of which dimensions to broadcast if the +# arguments have different ranks. + + +def _broadcasting_binary_op(fn): + """Wraps a binary Tensorflow operator and performs XLA-style broadcasting.""" + + def broadcasting_binary_op_wrapper(x, y, broadcast_dims=None, name=None): + """Inner wrapper function.""" + broadcast_dims = broadcast_dims or [] + broadcast_dims = ops.convert_to_tensor(broadcast_dims, dtypes.int64) + # Rather than relying on having static shape information in the TensorFlow + # graph, we use an XlaBroadcastHelper op that can compute the correct shapes + # at JIT compilation time. + x, y = gen_xla_ops.xla_broadcast_helper(x, y, broadcast_dims) + return fn(x, y, name=name) + + return broadcasting_binary_op_wrapper + + +# Map from TF signed types to TF unsigned types. +_SIGNED_TO_UNSIGNED_TABLE = { + dtypes.int8: dtypes.uint8, + dtypes.int16: dtypes.uint16, + dtypes.int32: dtypes.uint32, + dtypes.int64: dtypes.uint64, +} + +# Map from TF unsigned types to TF signed types. +_UNSIGNED_TO_SIGNED_TABLE = { + dtypes.uint8: dtypes.int8, + dtypes.uint16: dtypes.int16, + dtypes.uint32: dtypes.int32, + dtypes.uint64: dtypes.int64, +} + + +def _shift_right_logical_helper(x, y, name=None): + """Performs an integer right logical shift irrespective of input type.""" + assert y.dtype == x.dtype + dtype = x.dtype + signed = dtype in _SIGNED_TO_UNSIGNED_TABLE + if signed: + unsigned_dtype = _SIGNED_TO_UNSIGNED_TABLE[dtype] + x = math_ops.cast(x, unsigned_dtype) + y = math_ops.cast(y, unsigned_dtype) + output = bitwise_ops.right_shift(x, y, name=name) + if signed: + output = math_ops.cast(output, dtype) + return output + + +def _shift_right_arithmetic_helper(x, y, name=None): + """Performs an integer right arithmetic shift irrespective of input type.""" + assert y.dtype == x.dtype + dtype = x.dtype + unsigned = dtype in _UNSIGNED_TO_SIGNED_TABLE + if unsigned: + signed_dtype = _UNSIGNED_TO_SIGNED_TABLE[dtype] + x = math_ops.cast(x, signed_dtype) + y = math_ops.cast(y, signed_dtype) + output = bitwise_ops.right_shift(x, y, name=name) + if unsigned: + output = math_ops.cast(output, dtype) + return output + + +add = _broadcasting_binary_op(math_ops.add) +sub = _broadcasting_binary_op(math_ops.sub) +mul = _broadcasting_binary_op(math_ops.mul) +div = _broadcasting_binary_op(math_ops.div) +rem = _broadcasting_binary_op(gen_math_ops.mod) +max = _broadcasting_binary_op(math_ops.maximum) +min = _broadcasting_binary_op(math_ops.minimum) +atan2 = _broadcasting_binary_op(math_ops.atan2) +complex = _broadcasting_binary_op(math_ops.complex) +logical_and = _broadcasting_binary_op(math_ops.logical_and) +logical_or = _broadcasting_binary_op(math_ops.logical_or) +logical_xor = _broadcasting_binary_op(math_ops.logical_xor) +eq = _broadcasting_binary_op(math_ops.equal) +ne = _broadcasting_binary_op(math_ops.not_equal) +ge = _broadcasting_binary_op(math_ops.greater_equal) +gt = _broadcasting_binary_op(math_ops.greater) +le = _broadcasting_binary_op(math_ops.less_equal) +lt = _broadcasting_binary_op(math_ops.less) +pow = _broadcasting_binary_op(math_ops.pow) +shift_left = _broadcasting_binary_op(bitwise_ops.left_shift) +shift_right_logical = _broadcasting_binary_op(_shift_right_logical_helper) +shift_right_arithmetic = _broadcasting_binary_op(_shift_right_arithmetic_helper) + + +def _binary_op(fn): + """Wrapper that restricts `fn` to have the correct signature.""" + + def binary_op_wrapper(x, y, name=None): + return fn(x, y, name=name) + + return binary_op_wrapper + + +transpose = _binary_op(array_ops.transpose) +rev = _binary_op(array_ops.reverse) + +bitcast_convert_type = array_ops.bitcast + + +def broadcast(x, dims, name=None): + x = ops.convert_to_tensor(x) + shape = array_ops.concat( + [constant_op.constant(dims), + array_ops.shape(x)], axis=0) + return array_ops.broadcast_to(x, shape, name=name) + + +def clamp(a, x, b, name=None): + return min(max(a, x, name=name), b, name=name) + + +concatenate = array_ops.concat + + +def conv(lhs, + rhs, + window_strides, + padding, + lhs_dilation, + rhs_dilation, + dimension_numbers, + feature_group_count=1, + precision_config=None, + name=None): + """Wraps the XLA ConvGeneralDilated operator. + + ConvGeneralDilated is the most general form of XLA convolution and is + documented at + https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution + + Args: + lhs: the input tensor + rhs: the kernel tensor + window_strides: the inter-window strides + padding: the padding to apply at the start and end of each input dimensions + lhs_dilation: dilation to apply between input elements + rhs_dilation: dilation to apply between kernel elements + dimension_numbers: a `ConvolutionDimensionNumbers` proto. + feature_group_count: number of feature groups for grouped convolution. + precision_config: a `PrecisionConfigProto` proto. + name: an optional name for the operator + + Returns: + A tensor representing the output of the convolution. + """ + precision_config_proto = "" + if precision_config: + precision_config_proto = precision_config.SerializeToString() + return gen_xla_ops.xla_conv( + lhs, + rhs, + window_strides=window_strides, + padding=padding, + lhs_dilation=lhs_dilation, + rhs_dilation=rhs_dilation, + feature_group_count=feature_group_count, + dimension_numbers=dimension_numbers.SerializeToString(), + precision_config=precision_config_proto, + name=name) + + +convert_element_type = math_ops.cast + + +def dot(lhs, rhs, name=None): + return math_ops.tensordot(lhs, rhs, axes=1, name=name) + + +def dot_general(lhs, rhs, dimension_numbers, precision_config=None, name=None): + precision_config_proto = "" + if precision_config: + precision_config_proto = precision_config.SerializeToString() + return gen_xla_ops.xla_dot( + lhs, + rhs, + dimension_numbers=dimension_numbers.SerializeToString(), + precision_config=precision_config_proto, + name=name) + + +def dynamic_slice(x, starts, sizes, name=None): + # TODO(phawkins): the Slice operator lowers to DynamicSlice if `starts` is not + # a compile-time constant. This doesn't exactly mimic the semantics of dynamic + # slice if the slice is out of bounds. + return array_ops.slice(x, starts, sizes, name=name) -# TODO(phawkins): provide wrappers for all XLA operators. dynamic_update_slice = gen_xla_ops.xla_dynamic_update_slice +# TODO(phawkins): generalize tf.pad to support interior padding, and then remove +# the XLA-specific pad operator. +pad = gen_xla_ops.xla_pad + + +def random_normal(mu, sigma, dims, name=None): + mu = ops.convert_to_tensor(mu) + return random_ops.random_normal( + dims, mean=mu, stddev=sigma, dtype=mu.dtype, name=name) + + +def random_uniform(minval, maxval, dims, name=None): + minval = ops.convert_to_tensor(minval) + return random_ops.random_uniform( + dims, minval, maxval, dtype=minval.dtype, name=name) + + +recv = gen_xla_ops.xla_recv +reduce = gen_xla_ops.xla_reduce + def reduce_window(operand, init, @@ -61,22 +349,38 @@ def reduce_window(operand, """ window_strides = window_strides or [1] * len(window_dimensions) padding = padding or [(0, 0)] * len(window_dimensions) - padding_low = [x for (x, _) in padding] - padding_high = [y for (_, y) in padding] return gen_xla_ops.xla_reduce_window( - operand, - init, - reducer, - window_dimensions, - window_strides, - padding_low, - padding_high, + input=operand, + init_value=init, + window_dimensions=window_dimensions, + window_strides=window_strides, + padding=padding, + computation=reducer, name=name) -recv = gen_xla_ops.xla_recv +def reshape(x, new_sizes, dimensions=None, name=None): + if dimensions is not None: + x = array_ops.transpose(x, dimensions) + x = array_ops.reshape(x, new_sizes, name=name) + return x + + +def select(condition, x, y, name=None): + return array_ops.where(condition, x, y, name) + + +select_and_scatter = gen_xla_ops.xla_select_and_scatter send = gen_xla_ops.xla_send -sort = gen_xla_ops.xla_sort +def slice(x, start_dims, limit_dims, strides): + spec = [ + _slice(start, limit, stride) + for (start, limit, stride) in zip(start_dims, limit_dims, strides) + ] + return x[tuple(spec)] + + +sort = gen_xla_ops.xla_sort while_loop = gen_xla_ops.xla_while diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index 31a41f8719..9e8f5f2a1a 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -99,6 +99,25 @@ Status XlaOpKernelContext::ConstantInput(int index, index, context_->input(index).shape().dim_sizes(), constant_literal); } +static xla::StatusOr<int> InputIndex(XlaOpKernelContext* context, + StringPiece name) { + int start, stop; + TF_RETURN_IF_ERROR(context->op_kernel().InputRange(name, &start, &stop)); + if (stop != start + 1) { + return errors::InvalidArgument("OpKernel used list-valued input name '", + name, + "' when single-valued input was " + "expected"); + } + return start; +} + +Status XlaOpKernelContext::ConstantInput(StringPiece name, + xla::Literal* constant_literal) { + TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name)); + return ConstantInput(index, constant_literal); +} + Status XlaOpKernelContext::ConstantInputReshaped( int index, gtl::ArraySlice<int64> new_dims, xla::Literal* constant_literal) { @@ -246,6 +265,12 @@ Status XlaOpKernelContext::ConstantInputAsIntScalar(int index, int64* out) { return LiteralToInt64Scalar(literal, out); } +Status XlaOpKernelContext::ConstantInputAsIntScalar(StringPiece name, + int64* out) { + TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name)); + return ConstantInputAsIntScalar(index, out); +} + Status XlaOpKernelContext::ConstantInputAsFloatScalar(int index, double* out) { xla::Literal literal; TF_RETURN_IF_ERROR(ConstantInput(index, &literal)); @@ -280,6 +305,12 @@ Status XlaOpKernelContext::ConstantInputAsIntVector(int index, return LiteralToInt64Vector(literal, out); } +Status XlaOpKernelContext::ConstantInputAsIntVector(StringPiece name, + std::vector<int64>* out) { + TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name)); + return ConstantInputAsIntVector(index, out); +} + Status XlaOpKernelContext::ConstantInputReshapedToIntVector( int index, std::vector<int64>* out) { xla::Literal literal; @@ -313,6 +344,12 @@ Status XlaOpKernelContext::ConstantInputAsInt64Literal(int index, } } +Status XlaOpKernelContext::ConstantInputAsInt64Literal(StringPiece name, + xla::Literal* out) { + TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name)); + return ConstantInputAsInt64Literal(index, out); +} + // TODO(phawkins): validate that the dimensions form a valid shape, fail // gracefully if they do not. Status XlaOpKernelContext::ConstantInputAsShape(int index, TensorShape* shape) { diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h index 3f21a2bf41..3e26ba4f01 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.h +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h @@ -106,6 +106,7 @@ class XlaOpKernelContext { // expression cannot be evaluated, e.g., because it depends on unbound // parameters, returns a non-OK status. Status ConstantInput(int index, xla::Literal* constant_literal); + Status ConstantInput(StringPiece name, xla::Literal* constant_literal); // Evaluates input `index`, reshapes it to `new_shape` if new_shape != // InputShape(index), and stores it in `*constant_literal`. If the input @@ -117,12 +118,14 @@ class XlaOpKernelContext { // Converts a constant scalar int32 or int64 tensor into an int64. Status ConstantInputAsIntScalar(int index, int64* out); + Status ConstantInputAsIntScalar(StringPiece name, int64* out); // Converts a constant scalar float32 or float64 tensor into a float64. Status ConstantInputAsFloatScalar(int index, double* out); // Converts a constant 1D int32 or int64 tensor into a vector of int64s. Status ConstantInputAsIntVector(int index, std::vector<int64>* out); + Status ConstantInputAsIntVector(StringPiece name, std::vector<int64>* out); // Reshapes and converts a constant int32 or int64 tensor into a vector of // int64s. @@ -130,6 +133,7 @@ class XlaOpKernelContext { // Converts a constant int32 or int64 Tensor into an xla int64 Literal. Status ConstantInputAsInt64Literal(int index, xla::Literal* out); + Status ConstantInputAsInt64Literal(StringPiece name, xla::Literal* out); // Converts a constant 1D int32 or int64 tensor into a TensorShape. Status ConstantInputAsShape(int index, TensorShape* shape); |