author     Peter Hawkins <phawkins@google.com>            2018-08-23 14:13:12 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org> 2018-08-23 14:21:59 -0700
commit     c133ef468b702c728dc6b74047129eb742fff5c5 (patch)
tree       d6fa6ee9f3c9f6750496b1e4cb3a180a03d0130f /tensorflow
parent     b91f904112914c7ca89f4d3c2839bed258776e78 (diff)
[TF:XLA] Add TensorFlow operators that wrap most HLO operators.
PiperOrigin-RevId: 209997425
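
For orientation, a minimal sketch of how the new wrappers compose (patterned on
xla_ops_test.py below; the API is experimental and offers no compatibility
promise):

import numpy as np
from tensorflow.compiler.tf2xla.python import xla
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops

# XLA-style broadcasting: broadcast_dims maps each dimension of the
# lower-rank operand onto a dimension of the higher-rank operand.
x = array_ops.placeholder(dtypes.float32, [2, 2])
y = array_ops.placeholder(dtypes.float32, [2])
z = xla.add(x, y, broadcast_dims=(1,))  # shape [2, 2]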
Diffstat (limited to 'tensorflow')
-rw-r--r--  tensorflow/compiler/tests/BUILD                                 |  16
-rw-r--r--  tensorflow/compiler/tests/xla_ops_test.py                       | 301
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/BUILD                        |   7
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc          |  92
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/xla_broadcast_helper_op.cc   | 115
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc               | 101
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc                |  65
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/xla_pad_op.cc                | 105
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/xla_reduce_op.cc             | 102
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/xla_select_and_scatter_op.cc | 147
-rw-r--r--  tensorflow/compiler/tf2xla/ops/BUILD                            |   2
-rw-r--r--  tensorflow/compiler/tf2xla/ops/xla_ops.cc                       | 192
-rw-r--r--  tensorflow/compiler/tf2xla/python/BUILD                         |   1
-rw-r--r--  tensorflow/compiler/tf2xla/python/xla.py                        | 336
-rw-r--r--  tensorflow/compiler/tf2xla/xla_op_kernel.cc                     |  37
-rw-r--r--  tensorflow/compiler/tf2xla/xla_op_kernel.h                      |   4
16 files changed, 1547 insertions, 76 deletions
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index 235bef07b3..94e08b6efe 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -1191,3 +1191,19 @@ tf_xla_py_test(
"//tensorflow/python:platform_test",
],
)
+
+tf_xla_py_test(
+ name = "xla_ops_test",
+ size = "small",
+ srcs = ["xla_ops_test.py"],
+ disabled_backends = ["cpu_ondemand"],
+ deps = [
+ ":xla_test",
+ "//tensorflow/compiler/tf2xla/python:xla",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:platform_test",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
diff --git a/tensorflow/compiler/tests/xla_ops_test.py b/tensorflow/compiler/tests/xla_ops_test.py
new file mode 100644
index 0000000000..b2f026df6c
--- /dev/null
+++ b/tensorflow/compiler/tests/xla_ops_test.py
@@ -0,0 +1,301 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for XLA op wrappers."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+import numpy as np
+
+from tensorflow.compiler.tests import xla_test
+from tensorflow.compiler.tf2xla.python import xla
+from tensorflow.compiler.xla import xla_data_pb2
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import function
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import googletest
+
+
+class XlaOpsTest(xla_test.XLATestCase, parameterized.TestCase):
+
+ def _assertOpOutputMatchesExpected(self, op, args, expected,
+ equality_fn=None):
+ with self.test_session() as session:
+ with self.test_scope():
+ placeholders = [
+ array_ops.placeholder(dtypes.as_dtype(arg.dtype), arg.shape)
+ for arg in args
+ ]
+ feeds = {placeholders[i]: args[i] for i in range(0, len(args))}
+ output = op(*placeholders)
+ result = session.run(output, feeds)
+ if not equality_fn:
+ equality_fn = self.assertAllClose
+ equality_fn(result, expected, rtol=1e-3)
+
+ def testAdd(self):
+ for dtype in self.numeric_types:
+ self._assertOpOutputMatchesExpected(
+ xla.add,
+ args=(np.array([1, 2, 3], dtype=dtype),
+ np.array([4, 5, 6], dtype=dtype)),
+ expected=np.array([5, 7, 9], dtype=dtype))
+
+ self._assertOpOutputMatchesExpected(
+ lambda x, y: xla.add(x, y, broadcast_dims=(0,)),
+ args=(np.array([[1, 2], [3, 4]], dtype=dtype),
+ np.array([7, 11], dtype=dtype)),
+ expected=np.array([[8, 9], [14, 15]], dtype=dtype))
+
+ self._assertOpOutputMatchesExpected(
+ lambda x, y: xla.add(x, y, broadcast_dims=(1,)),
+ args=(np.array([[1, 2], [3, 4]], dtype=dtype),
+ np.array([7, 11], dtype=dtype)),
+ expected=np.array([[8, 13], [10, 15]], dtype=dtype))
+
+ def testBroadcast(self):
+ for dtype in self.numeric_types:
+ v = np.arange(4, dtype=np.int32).astype(dtype).reshape([2, 2])
+ self._assertOpOutputMatchesExpected(
+ lambda x: xla.broadcast(x, (7, 42)),
+ args=(v,),
+ expected=np.tile(v, (7, 42, 1, 1)))
+
+ def testShiftRightLogical(self):
+ self._assertOpOutputMatchesExpected(
+ xla.shift_right_logical,
+ args=(np.array([-1, 16], dtype=np.int32), np.int32(4)),
+ expected=np.array([0x0FFFFFFF, 1], dtype=np.int32))
+
+ self._assertOpOutputMatchesExpected(
+ xla.shift_right_logical,
+ args=(np.array([0xFFFFFFFF, 16], dtype=np.uint32), np.uint32(4)),
+ expected=np.array([0x0FFFFFFF, 1], dtype=np.uint32))
+
+ def testShiftRightArithmetic(self):
+ self._assertOpOutputMatchesExpected(
+ xla.shift_right_arithmetic,
+ args=(np.array([-1, 16], dtype=np.int32), np.int32(4)),
+ expected=np.array([-1, 1], dtype=np.int32))
+
+ self._assertOpOutputMatchesExpected(
+ xla.shift_right_arithmetic,
+ args=(np.array([0xFFFFFFFF, 16], dtype=np.uint32), np.uint32(4)),
+ expected=np.array([0xFFFFFFFF, 1], dtype=np.uint32))
+
+ PRECISION_VALUES = (None, xla_data_pb2.PrecisionConfigProto.DEFAULT,
+ xla_data_pb2.PrecisionConfigProto.HIGH,
+ xla_data_pb2.PrecisionConfigProto.HIGHEST)
+
+ @parameterized.parameters(*PRECISION_VALUES)
+ def testConv(self, precision):
+ for dtype in set(self.float_types).intersection(
+ set([dtypes.bfloat16.as_numpy_dtype, np.float32])):
+
+ def conv_1d_fn(lhs, rhs):
+ dnums = xla_data_pb2.ConvolutionDimensionNumbers()
+ num_spatial_dims = 1
+ dnums.input_batch_dimension = 0
+ dnums.input_feature_dimension = 1
+ dnums.output_batch_dimension = 0
+ dnums.output_feature_dimension = 1
+ dnums.kernel_output_feature_dimension = 0
+ dnums.kernel_input_feature_dimension = 1
+ dnums.input_spatial_dimensions.extend(range(2, 2 + num_spatial_dims))
+ dnums.kernel_spatial_dimensions.extend(range(2, 2 + num_spatial_dims))
+ dnums.output_spatial_dimensions.extend(range(2, 2 + num_spatial_dims))
+ precision_config = None
+ if precision:
+ precision_config = xla_data_pb2.PrecisionConfigProto()
+ precision_config.operand_precision.extend([precision, precision])
+      return xla.conv(
+          lhs,
+          rhs,
+          window_strides=(1,),
+          padding=((2, 1),),
+          lhs_dilation=(1,),
+          rhs_dilation=(2,),
+          dimension_numbers=dnums,
+          precision_config=precision_config)
+
+ self._assertOpOutputMatchesExpected(
+ conv_1d_fn,
+ args=(
+ np.array([[[3, 4, 5, 6]]], dtype=dtype),
+ np.array([[[-2, -3]]], dtype=dtype),
+ ),
+ expected=np.array([[[-9, -12, -21, -26, -10]]], dtype=dtype))
+
+ @parameterized.parameters(*PRECISION_VALUES)
+ def testDotGeneral(self, precision):
+ for dtype in self.float_types:
+
+ def dot_fn(lhs, rhs):
+ dnums = xla_data_pb2.DotDimensionNumbers()
+ dnums.lhs_contracting_dimensions.append(2)
+ dnums.rhs_contracting_dimensions.append(1)
+ dnums.lhs_batch_dimensions.append(0)
+ dnums.rhs_batch_dimensions.append(0)
+ precision_config = None
+ if precision:
+ precision_config = xla_data_pb2.PrecisionConfigProto()
+ precision_config.operand_precision.extend([precision, precision])
+ return xla.dot_general(
+ lhs,
+ rhs,
+ dimension_numbers=dnums,
+ precision_config=precision_config)
+
+ lhs = np.array(
+ [
+ [[1, 2], [3, 4]],
+ [[5, 6], [7, 8]],
+ ], dtype=dtype)
+ rhs = np.array(
+ [
+ [[1, 2, 3], [4, 5, 6]],
+ [[7, 8, 9], [10, 11, 12]],
+ ], dtype=dtype)
+ self._assertOpOutputMatchesExpected(
+ dot_fn,
+ args=(lhs, rhs),
+ expected=np.array(
+ [
+ [[9, 12, 15], [19, 26, 33]],
+ [[95, 106, 117], [129, 144, 159]],
+ ],
+ dtype=dtype))
+
+ def testNeg(self):
+ for dtype in self.numeric_types:
+ self._assertOpOutputMatchesExpected(
+ xla.neg,
+ args=(np.array([1, 2, 3], dtype=dtype),),
+ expected=np.array([-1, -2, -3], dtype=dtype))
+
+ def testPad(self):
+ for dtype in self.numeric_types:
+
+ def pad_fn(x):
+ return xla.pad(
+ x,
+ padding_value=7,
+ padding_low=[2, 1],
+ padding_high=[1, 2],
+ padding_interior=[1, 0])
+
+ self._assertOpOutputMatchesExpected(
+ pad_fn,
+ args=(np.arange(4, dtype=np.int32).astype(dtype).reshape([2, 2]),),
+ expected=np.array(
+ [[7, 7, 7, 7, 7], [7, 7, 7, 7, 7], [7, 0, 1, 7, 7],
+ [7, 7, 7, 7, 7], [7, 2, 3, 7, 7], [7, 7, 7, 7, 7]],
+ dtype=dtype))
+
+ def testReduce(self):
+ for dtype in set(self.numeric_types).intersection(
+ set([dtypes.bfloat16.as_numpy_dtype, np.float32])):
+
+ @function.Defun(dtype, dtype)
+ def sum_reducer(x, y):
+ return x + y
+
+ def sum_reduction(dims):
+
+ def fn(x):
+ return xla.reduce(
+ x, init_value=0, dimensions_to_reduce=dims, reducer=sum_reducer)
+
+ return fn
+
+ self._assertOpOutputMatchesExpected(
+ sum_reduction(dims=[]),
+ args=(np.arange(12, dtype=np.int32).astype(dtype).reshape([3, 4]),),
+ expected=np.arange(12, dtype=np.int32).astype(dtype).reshape([3, 4]))
+ self._assertOpOutputMatchesExpected(
+ sum_reduction(dims=[0]),
+ args=(np.arange(12, dtype=np.int32).astype(dtype).reshape([3, 4]),),
+ expected=np.array([12, 15, 18, 21], dtype=dtype))
+ self._assertOpOutputMatchesExpected(
+ sum_reduction(dims=[1]),
+ args=(np.arange(12, dtype=np.int32).astype(dtype).reshape([3, 4]),),
+ expected=np.array([6, 22, 38], dtype=dtype))
+ self._assertOpOutputMatchesExpected(
+ sum_reduction(dims=[0, 1]),
+ args=(np.arange(12, dtype=np.int32).astype(dtype).reshape([3, 4]),),
+ expected=dtype(66))
+
+ @function.Defun(dtype, dtype)
+ def mul_reducer(x, y):
+ return x * y
+
+ def mul_reduction(dims):
+
+ def fn(x):
+ return xla.reduce(
+ x, init_value=1, dimensions_to_reduce=dims, reducer=mul_reducer)
+
+ return fn
+
+ self._assertOpOutputMatchesExpected(
+ mul_reduction(dims=[0]),
+ args=(np.arange(12, dtype=np.int32).astype(dtype).reshape([3, 4]),),
+ expected=np.array([0, 45, 120, 231], dtype=dtype))
+
+ def testSelectAndScatter(self):
+ for dtype in set(self.numeric_types).intersection(
+ set([dtypes.bfloat16.as_numpy_dtype, np.float32])):
+
+ @function.Defun(dtype, dtype)
+ def add_scatter(x, y):
+ return x + y
+
+ @function.Defun(dtype, dtype)
+ def ge_select(x, y):
+ return x >= y
+
+ def test_fn(operand, source):
+ return xla.select_and_scatter(
+ operand,
+ window_dimensions=[2, 3, 1, 1],
+ window_strides=[2, 2, 1, 1],
+ padding=[[0, 0]] * 4,
+ source=source,
+ init_value=0,
+ select=ge_select,
+ scatter=add_scatter)
+
+ self._assertOpOutputMatchesExpected(
+ test_fn,
+ args=(np.array(
+ [[7, 2, 5, 3, 8], [3, 8, 9, 3, 4], [1, 5, 7, 5, 6],
+ [0, 6, 2, 10, 2]],
+ dtype=dtype).reshape((4, 5, 1, 1)),
+ np.array([[2, 6], [3, 1]], dtype=dtype).reshape((2, 2, 1, 1))),
+ expected=np.array(
+ [[0, 0, 0, 0, 0], [0, 0, 8, 0, 0], [0, 0, 3, 0, 0],
+ [0, 0, 0, 1, 0]],
+ dtype=dtype).reshape((4, 5, 1, 1)))
+
+ def testTranspose(self):
+ for dtype in self.numeric_types:
+ v = np.arange(4, dtype=np.int32).astype(dtype).reshape([2, 2])
+ self._assertOpOutputMatchesExpected(
+ lambda x: xla.transpose(x, [1, 0]), args=(v,), expected=v.T)
+
+
+if __name__ == '__main__':
+ googletest.main()
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD
index 8bbcff5f58..c1438f893f 100644
--- a/tensorflow/compiler/tf2xla/kernels/BUILD
+++ b/tensorflow/compiler/tf2xla/kernels/BUILD
@@ -101,6 +101,12 @@ tf_kernel_library(
"unary_ops.cc",
"unpack_op.cc",
"variable_ops.cc",
+ "xla_broadcast_helper_op.cc",
+ "xla_conv_op.cc",
+ "xla_dot_op.cc",
+ "xla_pad_op.cc",
+ "xla_reduce_op.cc",
+ "xla_select_and_scatter_op.cc",
],
hdrs = [
"index_ops.h",
@@ -110,6 +116,7 @@ tf_kernel_library(
":if_op",
":while_op",
"@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/strings",
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/lib:batch_dot",
diff --git a/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc b/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc
index b11a4ce36d..8102faad28 100644
--- a/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc
@@ -32,41 +32,30 @@ class ReduceWindowOp : public XlaOpKernel {
explicit ReduceWindowOp(OpKernelConstruction* context)
: XlaOpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("computation", &computation_));
- OP_REQUIRES_OK(context,
- context->GetAttr("window_dimensions", &window_dimensions_));
- OP_REQUIRES_OK(context,
- context->GetAttr("window_strides", &window_strides_));
- OP_REQUIRES_OK(context, context->GetAttr("padding_low", &padding_low_));
- OP_REQUIRES_OK(context, context->GetAttr("padding_high", &padding_high_));
}
void Compile(XlaOpKernelContext* context) override {
const TensorShape input_shape = context->InputShape(0);
const DataType dtype = context->input_type(0);
+ std::vector<int64> window_dimensions;
+ std::vector<int64> window_strides;
+ OP_REQUIRES_OK(context, context->ConstantInputAsIntVector(
+ "window_dimensions", &window_dimensions));
+ OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("window_strides",
+ &window_strides));
+
const int rank = input_shape.dims();
- OP_REQUIRES(context, rank == window_dimensions_.size(),
+ OP_REQUIRES(context, rank == window_dimensions.size(),
errors::InvalidArgument(
"The size of window_dimensions must be equal to the input "
"rank (",
- window_dimensions_.size(), " vs. ", rank, ")"));
- OP_REQUIRES(context, rank == window_strides_.size(),
+ window_dimensions.size(), " vs. ", rank, ")"));
+ OP_REQUIRES(context, rank == window_strides.size(),
errors::InvalidArgument(
"The size of window_strides must be equal to the input "
"rank (",
- window_strides_.size(), " vs. ", rank, ")"));
- OP_REQUIRES(context, rank == padding_low_.size(),
- errors::InvalidArgument(
- "The size of padding_low must be equal to the input "
- "rank (",
- padding_low_.size(), " vs. ", rank, ")"));
- OP_REQUIRES(context, rank == padding_high_.size(),
- errors::InvalidArgument(
- "The size of padding_high must be equal to the input "
- "rank (",
- padding_high_.size(), " vs. ", rank, ")"));
-
- xla::XlaBuilder* builder = context->builder();
+ window_strides.size(), " vs. ", rank, ")"));
// Build the reducer function.
XlaCompiler::Argument reducer_arg;
@@ -78,6 +67,7 @@ class ReduceWindowOp : public XlaOpKernel {
compile_options.use_tuple_arg = false;
compile_options.resolve_compile_time_constants = false;
compile_options.is_entry_computation = false;
+ compile_options.always_return_tuple = false;
XlaCompiler::CompilationResult reducer;
OP_REQUIRES_OK(context, context->compiler()->CompileFunction(
compile_options, *computation_,
@@ -86,51 +76,47 @@ class ReduceWindowOp : public XlaOpKernel {
xla::Shape scalar_shape;
OP_REQUIRES_OK(context,
TensorShapeToXLAShape(dtype, TensorShape(), &scalar_shape));
+ OP_REQUIRES(
+ context,
+ xla::ShapeUtil::Compatible(reducer.xla_output_shape, scalar_shape),
+ errors::InvalidArgument(
+ "Invalid output shape of ReduceWindow reducer. Expected ",
+ xla::ShapeUtil::HumanString(scalar_shape), " got ",
+ xla::ShapeUtil::HumanString(reducer.xla_output_shape)));
+
+ const TensorShape padding_shape = context->InputShape("padding");
OP_REQUIRES(context,
- xla::ShapeUtil::Compatible(
- reducer.xla_output_shape,
- xla::ShapeUtil::MakeTupleShape({scalar_shape})),
+ TensorShapeUtils::IsMatrix(padding_shape) &&
+ padding_shape.dim_size(1) == 2,
errors::InvalidArgument(
- "Invalid output shape of ReduceWindow reducer. Expected ",
- xla::ShapeUtil::HumanString(scalar_shape), " got ",
- xla::ShapeUtil::HumanString(reducer.xla_output_shape)));
-
- // Wraps the reducer in a computation that unpacks the output tuple.
- xla::XlaComputation wrapper;
- {
- std::unique_ptr<xla::XlaBuilder> cb =
- builder->CreateSubBuilder("wrapper");
- auto x = xla::Parameter(cb.get(), 0, scalar_shape, "x");
- auto y = xla::Parameter(cb.get(), 1, scalar_shape, "y");
- auto outputs = xla::Call(cb.get(), *reducer.computation, {x, y});
- xla::GetTupleElement(outputs, 0);
- xla::StatusOr<xla::XlaComputation> result = cb->Build();
- OP_REQUIRES_OK(context, result.status());
- wrapper = std::move(result.ValueOrDie());
- }
-
- std::vector<std::pair<int64, int64>> padding(rank);
- for (int i = 0; i < rank; ++i) {
- padding[i] = {padding_low_[i], padding_high_[i]};
+ "padding must be a matrix with minor dimension 2, got ",
+ padding_shape.DebugString()));
+ xla::Literal padding_literal;
+ OP_REQUIRES_OK(context, context->ConstantInputAsInt64Literal(
+ "padding", &padding_literal));
+ std::vector<std::pair<int64, int64>> padding(padding_shape.dim_size(0));
+ for (int i = 0; i < padding.size(); ++i) {
+ padding[i] = {padding_literal.Get<int64>({i, 0}),
+ padding_literal.Get<int64>({i, 1})};
}
xla::XlaOp output = xla::ReduceWindowWithGeneralPadding(
- context->Input(0), context->Input(1), wrapper, window_dimensions_,
- window_strides_, padding);
+ context->Input(0), context->Input(1), *reducer.computation,
+ window_dimensions, window_strides, padding);
context->SetOutput(0, output);
}
private:
const NameAttrList* computation_;
- std::vector<int64> window_dimensions_;
- std::vector<int64> window_strides_;
- std::vector<int64> padding_low_;
- std::vector<int64> padding_high_;
TF_DISALLOW_COPY_AND_ASSIGN(ReduceWindowOp);
};
-REGISTER_XLA_OP(Name("XlaReduceWindow"), ReduceWindowOp);
+REGISTER_XLA_OP(Name("XlaReduceWindow")
+ .CompileTimeConstInput("window_dimensions")
+ .CompileTimeConstInput("window_strides")
+ .CompileTimeConstInput("padding"),
+ ReduceWindowOp);
} // namespace
} // namespace tensorflow
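
The window_dimensions, window_strides, and padding attrs above become
compile-time-constant inputs; padding in particular is now a [rank, 2] integer
matrix of (low, high) pairs. A hedged sketch of adapting the old attr-style
lists (hypothetical helper, not part of this change):

# Builds the [rank, 2] padding matrix XlaReduceWindow now expects from the
# old padding_low/padding_high attribute lists.
def make_padding(padding_low, padding_high):
  return [[lo, hi] for lo, hi in zip(padding_low, padding_high)]

make_padding([2, 0], [1, 0])  # => [[2, 1], [0, 0]]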
diff --git a/tensorflow/compiler/tf2xla/kernels/xla_broadcast_helper_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_broadcast_helper_op.cc
new file mode 100644
index 0000000000..412afeaaad
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/xla_broadcast_helper_op.cc
@@ -0,0 +1,115 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "absl/algorithm/container.h"
+#include "absl/strings/str_join.h"
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/xla_compiler.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+namespace {
+
+class XlaBroadcastHelperOp : public XlaOpKernel {
+ public:
+ explicit XlaBroadcastHelperOp(OpKernelConstruction* context)
+ : XlaOpKernel(context) {}
+
+ void Compile(XlaOpKernelContext* context) override {
+ xla::XlaOp lhs = context->Input(0);
+ xla::XlaOp rhs = context->Input(1);
+ const TensorShape lhs_shape = context->InputShape(0);
+ const TensorShape rhs_shape = context->InputShape(1);
+
+ const bool broadcast_lhs = lhs_shape.dims() < rhs_shape.dims();
+ const TensorShape* min_rank_shape = broadcast_lhs ? &lhs_shape : &rhs_shape;
+ const TensorShape* max_rank_shape = broadcast_lhs ? &rhs_shape : &lhs_shape;
+
+ std::vector<int64> broadcast_dims;
+ OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("broadcast_dims",
+ &broadcast_dims));
+ if (broadcast_dims.empty()) {
+ OP_REQUIRES(
+ context,
+ lhs_shape.dims() == rhs_shape.dims() || lhs_shape.dims() == 0 ||
+ rhs_shape.dims() == 0,
+ errors::InvalidArgument(
+ "If broadcast_dims is empty, both "
+ "arguments must have equal rank; "
+ "argument shapes, or at least one argument must be a scalar: ",
+ lhs_shape.DebugString(), " and ", rhs_shape.DebugString()));
+ context->SetOutput(0, lhs);
+ context->SetOutput(1, rhs);
+ return;
+ }
+
+ OP_REQUIRES(
+ context, broadcast_dims.size() == min_rank_shape->dims(),
+ errors::InvalidArgument(
+ "broadcast_dims must have size equal to the smaller argument rank; "
+ "broadcast_dims: [",
+ absl::StrJoin(broadcast_dims, ","), "]; argument shapes: ",
+ lhs_shape.DebugString(), " and ", rhs_shape.DebugString()));
+ std::vector<int64> sorted_broadcast_dims = broadcast_dims;
+ absl::c_sort(sorted_broadcast_dims);
+ std::set<int64> dims_set(broadcast_dims.begin(), broadcast_dims.end());
+ OP_REQUIRES(context,
+ dims_set.size() == broadcast_dims.size() &&
+ broadcast_dims == sorted_broadcast_dims,
+ errors::InvalidArgument(
+ "Duplicate or nonmonotonic dimension in broadcast_dims; "
+ "broadcast_dims: [",
+ absl::StrJoin(broadcast_dims, ","), "]"));
+
+ std::vector<int64> broadcast_shape(max_rank_shape->dims(), 1LL);
+ for (int i = 0; i < broadcast_dims.size(); ++i) {
+ const int dim = broadcast_dims[i];
+ OP_REQUIRES(
+ context, dim >= 0 && dim < broadcast_shape.size(),
+ errors::InvalidArgument(
+ "Invalid broadcast dimension (", dim, "); broadcast_dims: [",
+ absl::StrJoin(broadcast_dims, ","), "]; argument shapes: ",
+ lhs_shape.DebugString(), " and ", rhs_shape.DebugString()));
+ broadcast_shape[dim] = min_rank_shape->dim_size(i);
+ }
+ xla::PrimitiveType type = context->input_xla_type(0);
+ xla::Shape broadcast_xla_shape =
+ xla::ShapeUtil::MakeShape(type, broadcast_shape);
+ if (broadcast_lhs) {
+ lhs = xla::BroadcastInDim(lhs, broadcast_xla_shape, broadcast_dims);
+ } else {
+ rhs = xla::BroadcastInDim(rhs, broadcast_xla_shape, broadcast_dims);
+ }
+ context->SetOutput(0, lhs);
+ context->SetOutput(1, rhs);
+ }
+
+ private:
+ xla::DotDimensionNumbers dnums_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(XlaBroadcastHelperOp);
+};
+
+REGISTER_XLA_OP(
+ Name("XlaBroadcastHelper").CompileTimeConstInput("broadcast_dims"),
+ XlaBroadcastHelperOp);
+
+} // namespace
+} // namespace tensorflow
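
The broadcast-shape computation above, restated in Python for reference
(illustrative only; mirrors the loop over broadcast_dims in
XlaBroadcastHelperOp):

# Each entry of broadcast_dims names the output dimension occupied by the
# corresponding dimension of the lower-rank operand; every other output
# dimension gets size 1.
def broadcast_shape(min_rank_dims, max_rank, broadcast_dims):
  shape = [1] * max_rank
  for i, dim in enumerate(broadcast_dims):
    shape[dim] = min_rank_dims[i]
  return shape

broadcast_shape([2], max_rank=2, broadcast_dims=[1])  # => [1, 2]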
diff --git a/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc
new file mode 100644
index 0000000000..8848623868
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc
@@ -0,0 +1,101 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/xla_compiler.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+namespace {
+
+class XlaConvOp : public XlaOpKernel {
+ public:
+ explicit XlaConvOp(OpKernelConstruction* context) : XlaOpKernel(context) {
+ string dnums_attr;
+ OP_REQUIRES_OK(context, context->GetAttr("dimension_numbers", &dnums_attr));
+ OP_REQUIRES(
+ context, dnums_.ParsePartialFromString(dnums_attr),
+ errors::InvalidArgument("Error parsing convolution dimension numbers"));
+ string precision_config_attr;
+ OP_REQUIRES_OK(
+ context, context->GetAttr("precision_config", &precision_config_attr));
+ OP_REQUIRES(
+ context,
+ precision_config_.ParsePartialFromString(precision_config_attr),
+ errors::InvalidArgument("Error parsing convolution dimension numbers"));
+ }
+
+ void Compile(XlaOpKernelContext* context) override {
+ const TensorShape lhs_shape = context->InputShape(0);
+ const TensorShape rhs_shape = context->InputShape(1);
+ const TensorShape padding_shape = context->InputShape("padding");
+ std::vector<int64> window_strides;
+ std::vector<int64> lhs_dilation;
+ std::vector<int64> rhs_dilation;
+ int64 feature_group_count;
+ OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("window_strides",
+ &window_strides));
+ OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("lhs_dilation",
+ &lhs_dilation));
+ OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("rhs_dilation",
+ &rhs_dilation));
+ OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar(
+ "feature_group_count", &feature_group_count));
+
+ OP_REQUIRES(context,
+ TensorShapeUtils::IsMatrix(padding_shape) &&
+ padding_shape.dim_size(1) == 2,
+ errors::InvalidArgument(
+ "padding must be a matrix with minor dimension 2, got ",
+ padding_shape.DebugString()));
+ xla::Literal padding_literal;
+ OP_REQUIRES_OK(context, context->ConstantInputAsInt64Literal(
+ "padding", &padding_literal));
+ std::vector<std::pair<int64, int64>> padding(padding_shape.dim_size(0));
+ for (int i = 0; i < padding.size(); ++i) {
+ padding[i] = {padding_literal.Get<int64>({i, 0}),
+ padding_literal.Get<int64>({i, 1})};
+ }
+
+ // We do only minimal checking, relying on XLA to check the shape
+ // invariants.
+ xla::XlaOp output = xla::ConvGeneralDilated(
+ context->Input(0), context->Input(1), window_strides, padding,
+ lhs_dilation, rhs_dilation, dnums_, feature_group_count,
+ &precision_config_);
+ context->SetOutput(0, output);
+ }
+
+ private:
+ xla::ConvolutionDimensionNumbers dnums_;
+ xla::PrecisionConfigProto precision_config_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(XlaConvOp);
+};
+
+REGISTER_XLA_OP(Name("XlaConv")
+ .CompileTimeConstInput("window_strides")
+ .CompileTimeConstInput("lhs_dilation")
+ .CompileTimeConstInput("rhs_dilation")
+ .CompileTimeConstInput("feature_group_count")
+ .CompileTimeConstInput("padding"),
+ XlaConvOp);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc
new file mode 100644
index 0000000000..2fed53e5c0
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc
@@ -0,0 +1,65 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/xla_compiler.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+namespace {
+
+class XlaDotOp : public XlaOpKernel {
+ public:
+ explicit XlaDotOp(OpKernelConstruction* context) : XlaOpKernel(context) {
+ string dnums_attr;
+ OP_REQUIRES_OK(context, context->GetAttr("dimension_numbers", &dnums_attr));
+ OP_REQUIRES(
+ context, dnums_.ParsePartialFromString(dnums_attr),
+ errors::InvalidArgument("Error parsing convolution dimension numbers"));
+ string precision_config_attr;
+ OP_REQUIRES_OK(
+ context, context->GetAttr("precision_config", &precision_config_attr));
+ OP_REQUIRES(
+ context,
+ precision_config_.ParsePartialFromString(precision_config_attr),
+ errors::InvalidArgument("Error parsing convolution dimension numbers"));
+ }
+
+ void Compile(XlaOpKernelContext* context) override {
+ const TensorShape lhs_shape = context->InputShape(0);
+ const TensorShape rhs_shape = context->InputShape(1);
+
+ // We do only minimal checking, relying on XLA to check the shape
+ // invariants.
+ xla::XlaOp output = xla::DotGeneral(context->Input(0), context->Input(1),
+ dnums_, &precision_config_);
+ context->SetOutput(0, output);
+ }
+
+ private:
+ xla::DotDimensionNumbers dnums_;
+ xla::PrecisionConfigProto precision_config_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(XlaDotOp);
+};
+
+REGISTER_XLA_OP(Name("XlaDot"), XlaDotOp);
+
+} // namespace
+} // namespace tensorflow
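
The dimension numbers used by testDotGeneral above configure a batched matrix
multiply; the equivalent numpy computation, for reference:

import numpy as np

# lhs_contracting_dimensions=[2], rhs_contracting_dimensions=[1], and
# lhs/rhs_batch_dimensions=[0] is exactly a batch matmul.
lhs = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
rhs = np.array([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]])
out = np.einsum('bij,bjk->bik', lhs, rhs)
# out[0] == [[9, 12, 15], [19, 26, 33]]
# out[1] == [[95, 106, 117], [129, 144, 159]]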
diff --git a/tensorflow/compiler/tf2xla/kernels/xla_pad_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_pad_op.cc
new file mode 100644
index 0000000000..59502d83c7
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/xla_pad_op.cc
@@ -0,0 +1,105 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "absl/algorithm/container.h"
+#include "absl/strings/str_join.h"
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/xla_compiler.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+namespace {
+
+class XlaPadOp : public XlaOpKernel {
+ public:
+ explicit XlaPadOp(OpKernelConstruction* context) : XlaOpKernel(context) {}
+
+ void Compile(XlaOpKernelContext* context) override {
+ const TensorShape input_shape = context->InputShape("input");
+ const TensorShape padding_value_shape =
+ context->InputShape("padding_value");
+
+ std::vector<int64> padding_low;
+ std::vector<int64> padding_high;
+ std::vector<int64> padding_interior;
+ OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("padding_low",
+ &padding_low));
+ OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("padding_high",
+ &padding_high));
+ OP_REQUIRES_OK(context, context->ConstantInputAsIntVector(
+ "padding_interior", &padding_interior));
+
+ OP_REQUIRES(context, TensorShapeUtils::IsScalar(padding_value_shape),
+ errors::InvalidArgument("padding_value must be a scalar"));
+ const int rank = input_shape.dims();
+ OP_REQUIRES(context, rank == padding_low.size(),
+ errors::InvalidArgument(
+ "The size of padding_low must be equal to the input "
+ "rank (",
+ padding_low.size(), " vs. ", rank, ")"));
+ OP_REQUIRES(context, rank == padding_high.size(),
+ errors::InvalidArgument(
+ "The size of padding_high must be equal to the input "
+ "rank (",
+ padding_high.size(), " vs. ", rank, ")"));
+ OP_REQUIRES(context, rank == padding_interior.size(),
+ errors::InvalidArgument(
+ "The size of padding_interior must be equal to the input "
+ "rank (",
+ padding_interior.size(), " vs. ", rank, ")"));
+
+ auto non_negative = [](int64 x) { return x >= 0; };
+ OP_REQUIRES(
+ context, absl::c_all_of(padding_low, non_negative),
+ errors::InvalidArgument("padding_low must be non-negative, got [",
+ absl::StrJoin(padding_low, ","), "]"));
+ OP_REQUIRES(
+ context, absl::c_all_of(padding_high, non_negative),
+ errors::InvalidArgument("padding_high must be non-negative, got [",
+ absl::StrJoin(padding_high, ","), "]"));
+ OP_REQUIRES(
+ context, absl::c_all_of(padding_interior, non_negative),
+ errors::InvalidArgument("padding_interior must be non-negative, got [",
+ absl::StrJoin(padding_interior, ","), "]"));
+
+ xla::PaddingConfig padding_config;
+ for (int i = 0; i < rank; ++i) {
+ auto* dim = padding_config.add_dimensions();
+ dim->set_edge_padding_low(padding_low[i]);
+ dim->set_edge_padding_high(padding_high[i]);
+ dim->set_interior_padding(padding_interior[i]);
+ }
+
+ xla::XlaOp output =
+ xla::Pad(context->Input("input"), context->Input("padding_value"),
+ padding_config);
+ context->SetOutput(0, output);
+ }
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(XlaPadOp);
+};
+
+REGISTER_XLA_OP(Name("XlaPad")
+ .CompileTimeConstInput("padding_low")
+ .CompileTimeConstInput("padding_high")
+ .CompileTimeConstInput("padding_interior"),
+ XlaPadOp);
+
+} // namespace
+} // namespace tensorflow
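
For reference, the output-size arithmetic XlaPad implies for each dimension
(consistent with testPad above, which pads a [2, 2] input to [6, 5]):

# padded size = low + high + size + (size - 1) * interior, for size >= 1.
def padded_dim(size, low, high, interior):
  return low + high + size + max(size - 1, 0) * interior

padded_dim(2, low=2, high=1, interior=1)  # => 6 (first dimension of testPad)
padded_dim(2, low=1, high=2, interior=0)  # => 5 (second dimension of testPad)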
diff --git a/tensorflow/compiler/tf2xla/kernels/xla_reduce_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_reduce_op.cc
new file mode 100644
index 0000000000..fc2425f37b
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/xla_reduce_op.cc
@@ -0,0 +1,102 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "absl/algorithm/container.h"
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/xla_compiler.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_computation.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+namespace {
+
+class XlaReduceOp : public XlaOpKernel {
+ public:
+ explicit XlaReduceOp(OpKernelConstruction* context) : XlaOpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("reducer", &reducer_));
+ OP_REQUIRES_OK(context, context->GetAttr("dimensions_to_reduce",
+ &dimensions_to_reduce_));
+ std::set<int64> dims_set(dimensions_to_reduce_.begin(),
+ dimensions_to_reduce_.end());
+ OP_REQUIRES(
+ context, dims_set.size() == dimensions_to_reduce_.size(),
+ errors::InvalidArgument("Duplicate dimension in dimensions_to_reduce "
+ "argument to XlaReduce"));
+ }
+
+ void Compile(XlaOpKernelContext* context) override {
+ const TensorShape input_shape = context->InputShape("input");
+ const TensorShape init_value_shape = context->InputShape("init_value");
+ const DataType dtype = context->input_type(0);
+
+ const int rank = input_shape.dims();
+ OP_REQUIRES(context, TensorShapeUtils::IsScalar(init_value_shape),
+ errors::InvalidArgument("init_value must be a scalar"));
+
+ auto dim_in_range = [rank](int64 dim) { return dim >= 0 && dim < rank; };
+ OP_REQUIRES(context,
+ rank >= dimensions_to_reduce_.size() &&
+ absl::c_all_of(dimensions_to_reduce_, dim_in_range),
+ errors::InvalidArgument(
+ "Invalid dimensions_to_reduce argument to XlaReduce"));
+
+ // Build the reducer function.
+ XlaCompiler::Argument reducer_arg;
+ reducer_arg.kind = XlaCompiler::Argument::kParameter;
+ reducer_arg.type = dtype;
+ reducer_arg.shape = TensorShape();
+
+ XlaCompiler::CompileOptions compile_options;
+ compile_options.use_tuple_arg = false;
+ compile_options.always_return_tuple = false;
+ compile_options.resolve_compile_time_constants = false;
+ compile_options.is_entry_computation = false;
+ XlaCompiler::CompilationResult reducer;
+ OP_REQUIRES_OK(context, context->compiler()->CompileFunction(
+ compile_options, *reducer_,
+ {reducer_arg, reducer_arg}, &reducer));
+
+ xla::Shape scalar_shape;
+ OP_REQUIRES_OK(context,
+ TensorShapeToXLAShape(dtype, TensorShape(), &scalar_shape));
+ OP_REQUIRES(
+ context,
+ xla::ShapeUtil::Compatible(reducer.xla_output_shape, scalar_shape),
+ errors::InvalidArgument(
+ "Invalid output shape of XlaReduce reducer. Expected ",
+ xla::ShapeUtil::HumanString(scalar_shape), " got ",
+ xla::ShapeUtil::HumanString(reducer.xla_output_shape)));
+
+ xla::XlaOp output =
+ xla::Reduce(context->Input("input"), context->Input("init_value"),
+ *reducer.computation, dimensions_to_reduce_);
+ context->SetOutput(0, output);
+ }
+
+ private:
+ const NameAttrList* reducer_;
+ std::vector<int64> dimensions_to_reduce_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(XlaReduceOp);
+};
+
+REGISTER_XLA_OP(Name("XlaReduce"), XlaReduceOp);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/xla_select_and_scatter_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_select_and_scatter_op.cc
new file mode 100644
index 0000000000..089776fcf7
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/xla_select_and_scatter_op.cc
@@ -0,0 +1,147 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/kernels/while_op.h"
+
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/xla_compiler.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_computation.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+namespace {
+
+class XlaSelectAndScatterOp : public XlaOpKernel {
+ public:
+ explicit XlaSelectAndScatterOp(OpKernelConstruction* context)
+ : XlaOpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("select", &select_computation_));
+ OP_REQUIRES_OK(context, context->GetAttr("scatter", &scatter_computation_));
+ }
+
+ void Compile(XlaOpKernelContext* context) override {
+ const TensorShape input_shape = context->InputShape(0);
+ const DataType dtype = context->input_type(0);
+
+ std::vector<int64> window_dimensions;
+ std::vector<int64> window_strides;
+ OP_REQUIRES_OK(context, context->ConstantInputAsIntVector(
+ "window_dimensions", &window_dimensions));
+ OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("window_strides",
+ &window_strides));
+
+ const int rank = input_shape.dims();
+ OP_REQUIRES(context, rank == window_dimensions.size(),
+ errors::InvalidArgument(
+ "The size of window_dimensions must be equal to the input "
+ "rank (",
+ window_dimensions.size(), " vs. ", rank, ")"));
+ OP_REQUIRES(context, rank == window_strides.size(),
+ errors::InvalidArgument(
+ "The size of window_strides must be equal to the input "
+ "rank (",
+ window_strides.size(), " vs. ", rank, ")"));
+
+ XlaCompiler::CompileOptions compile_options;
+ compile_options.use_tuple_arg = false;
+ compile_options.resolve_compile_time_constants = false;
+ compile_options.is_entry_computation = false;
+ compile_options.always_return_tuple = false;
+
+ // Build the select function.
+ XlaCompiler::Argument select_arg;
+ select_arg.kind = XlaCompiler::Argument::kParameter;
+ select_arg.type = dtype;
+ select_arg.shape = TensorShape();
+
+ XlaCompiler::CompilationResult select;
+ OP_REQUIRES_OK(context, context->compiler()->CompileFunction(
+ compile_options, *select_computation_,
+ {select_arg, select_arg}, &select));
+
+ xla::Shape select_output_shape = xla::ShapeUtil::MakeShape(xla::PRED, {});
+ OP_REQUIRES(
+ context,
+ xla::ShapeUtil::Compatible(select.xla_output_shape,
+ select_output_shape),
+ errors::InvalidArgument(
+ "Invalid output shape of XlaSelectAndScatter select. Expected ",
+ xla::ShapeUtil::HumanString(select_output_shape), " got ",
+ xla::ShapeUtil::HumanString(select.xla_output_shape)));
+
+ // Build the scatter function.
+ XlaCompiler::Argument scatter_arg;
+ scatter_arg.kind = XlaCompiler::Argument::kParameter;
+ scatter_arg.type = dtype;
+ scatter_arg.shape = TensorShape();
+
+ XlaCompiler::CompilationResult scatter;
+ OP_REQUIRES_OK(context, context->compiler()->CompileFunction(
+ compile_options, *scatter_computation_,
+ {scatter_arg, scatter_arg}, &scatter));
+
+ xla::Shape scalar_shape;
+ OP_REQUIRES_OK(context,
+ TensorShapeToXLAShape(dtype, TensorShape(), &scalar_shape));
+ OP_REQUIRES(
+ context,
+ xla::ShapeUtil::Compatible(scatter.xla_output_shape, scalar_shape),
+ errors::InvalidArgument(
+ "Invalid output shape of scatter. Expected ",
+ xla::ShapeUtil::HumanString(scalar_shape), " got ",
+ xla::ShapeUtil::HumanString(scatter.xla_output_shape)));
+
+ const TensorShape padding_shape = context->InputShape("padding");
+ OP_REQUIRES(context,
+ TensorShapeUtils::IsMatrix(padding_shape) &&
+ padding_shape.dim_size(1) == 2,
+ errors::InvalidArgument(
+ "padding must be a matrix with minor dimension 2, got ",
+ padding_shape.DebugString()));
+ xla::Literal padding_literal;
+ OP_REQUIRES_OK(context, context->ConstantInputAsInt64Literal(
+ "padding", &padding_literal));
+ std::vector<std::pair<int64, int64>> padding(padding_shape.dim_size(0));
+ for (int i = 0; i < padding.size(); ++i) {
+ padding[i] = {padding_literal.Get<int64>({i, 0}),
+ padding_literal.Get<int64>({i, 1})};
+ }
+
+ xla::XlaOp output = xla::SelectAndScatterWithGeneralPadding(
+ context->Input("operand"), *select.computation, window_dimensions,
+ window_strides, padding, context->Input("source"),
+ context->Input("init_value"), *scatter.computation);
+ context->SetOutput(0, output);
+ }
+
+ private:
+ const NameAttrList* select_computation_;
+ const NameAttrList* scatter_computation_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(XlaSelectAndScatterOp);
+};
+
+REGISTER_XLA_OP(Name("XlaSelectAndScatter")
+ .CompileTimeConstInput("window_dimensions")
+ .CompileTimeConstInput("window_strides")
+ .CompileTimeConstInput("padding"),
+ XlaSelectAndScatterOp);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/ops/BUILD b/tensorflow/compiler/tf2xla/ops/BUILD
index ace6fd1d8e..4dce0a2102 100644
--- a/tensorflow/compiler/tf2xla/ops/BUILD
+++ b/tensorflow/compiler/tf2xla/ops/BUILD
@@ -11,6 +11,8 @@ cc_library(
srcs = ["xla_ops.cc"],
deps = [
"//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "@com_google_absl//absl/algorithm:container",
],
alwayslink = 1,
)
diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc
index a59c77f5c3..2cd9ae799f 100644
--- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc
+++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc
@@ -13,11 +13,97 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include "absl/algorithm/container.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/lib/core/errors.h"
namespace tensorflow {
+namespace {
+
+// Helper shape function for operators that return an output with the same rank
+// as their first input.
+Status UnchangedRank(shape_inference::InferenceContext* c) {
+ if (c->RankKnown(c->input(0))) {
+ c->set_output(0, c->UnknownShapeOfRank(c->Rank(c->input(0))));
+ } else {
+ c->set_output(0, c->input(0));
+ }
+ return Status::OK();
+}
+
+REGISTER_OP("XlaBroadcastHelper")
+ .Input("lhs: T")
+ .Input("rhs: T")
+ .Input("broadcast_dims: Tindices")
+ .Attr("T: numbertype")
+ .Attr("Tindices: {int32, int64}")
+ .Output("lhs_output: T")
+ .Output("rhs_output: T")
+ .SetShapeFn(shape_inference::UnknownShape)
+ .Doc(R"doc(
+Helper operator for performing XLA-style broadcasts.
+
+Broadcasts `lhs` and `rhs` to the same rank, by adding size 1 dimensions to
+whichever of `lhs` and `rhs` has the lower rank, using XLA's broadcasting rules
+for binary operators.
+
+lhs: the LHS input tensor
+rhs: the RHS input tensor
+broadcast_dims: an XLA-style broadcast dimension specification
+lhs_output: the broadcasted LHS tensor
+rhs_output: the broadcasted RHS tensor
+)doc");
+
+REGISTER_OP("XlaConv")
+ .Input("lhs: T")
+ .Input("rhs: T")
+ .Input("window_strides: Tindices")
+ .Input("padding: Tindices")
+ .Input("lhs_dilation: Tindices")
+ .Input("rhs_dilation: Tindices")
+ .Input("feature_group_count: Tindices")
+ .Attr("T: numbertype")
+ .Attr("Tindices: {int32, int64}")
+ .Attr("dimension_numbers: string")
+ .Attr("precision_config: string")
+ .Output("output: T")
+ .SetShapeFn(UnchangedRank)
+ .Doc(R"doc(
+Wraps the XLA ConvGeneralDilated operator, documented at
+ https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution
+.
+
+lhs: the input tensor
+rhs: the kernel tensor
+window_strides: the inter-window strides
+padding: the padding to apply at the start and end of each input dimension
+lhs_dilation: dilation to apply between input elements
+rhs_dilation: dilation to apply between kernel elements
+feature_group_count: number of feature groups for grouped convolution.
+dimension_numbers: a serialized xla::ConvolutionDimensionNumbers proto.
+precision_config: a serialized xla::PrecisionConfigProto proto.
+)doc");
+
+REGISTER_OP("XlaDot")
+ .Input("lhs: T")
+ .Input("rhs: T")
+ .Attr("T: numbertype")
+ .Attr("dimension_numbers: string")
+ .Attr("precision_config: string")
+ .Output("output: T")
+ .SetShapeFn(shape_inference::UnknownShape)
+ .Doc(R"doc(
+Wraps the XLA DotGeneral operator, documented at
+ https://www.tensorflow.org/performance/xla/operation_semantics#dotgeneral
+.
+
+lhs: the LHS tensor
+rhs: the RHS tensor
+dimension_numbers: a serialized xla::DotDimensionNumbers proto.
+precision_config: a serialized xla::PrecisionConfigProto proto.
+)doc");
REGISTER_OP("XlaDynamicUpdateSlice")
.Input("input: T")
@@ -73,6 +159,29 @@ else_branch: A function takes 'inputs' and returns a list of tensors.
whose types are the same as what then_branch returns.
)doc");
+REGISTER_OP("XlaPad")
+ .Input("input: T")
+ .Input("padding_value: T")
+ .Input("padding_low: Tindices")
+ .Input("padding_high: Tindices")
+ .Input("padding_interior: Tindices")
+ .Output("output: T")
+ .Attr("T: type")
+ .Attr("Tindices: {int32, int64}")
+ .SetShapeFn(UnchangedRank)
+ .Doc(R"doc(
+Wraps the XLA Pad operator, documented at
+ https://www.tensorflow.org/performance/xla/operation_semantics#pad
+.
+
+input: A `Tensor` of type T.
+padding_value: A scalar `Tensor` of type T.
+padding_low: the padding to apply at the start of each input dimension
+padding_high: the padding to apply at the end of each input dimension.
+padding_interior: the padding to apply between each input element.
+output: A `Tensor` of type T.
+)doc");
+
REGISTER_OP("XlaRecv")
.Output("tensor: dtype")
.Attr("dtype: type")
@@ -98,17 +207,58 @@ tensor_name: A string key that identifies the channel.
shape: The shape of the tensor.
)doc");
+REGISTER_OP("XlaReduce")
+ .Input("input: T")
+ .Input("init_value: T")
+ .Attr("T: numbertype")
+ .Attr("dimensions_to_reduce: list(int)")
+ .Attr("reducer: func")
+ .Output("output: T")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ if (c->RankKnown(c->input(0))) {
+ int rank = c->Rank(c->input(0));
+ std::vector<int64> dimensions_to_reduce;
+ TF_RETURN_IF_ERROR(
+ c->GetAttr("dimensions_to_reduce", &dimensions_to_reduce));
+ std::set<int64> dims_set(dimensions_to_reduce.begin(),
+ dimensions_to_reduce.end());
+ auto dim_in_range = [rank](int64 dim) {
+ return dim >= 0 && dim < rank;
+ };
+ if (rank < dimensions_to_reduce.size() ||
+ dims_set.size() != dimensions_to_reduce.size() ||
+ !absl::c_all_of(dimensions_to_reduce, dim_in_range)) {
+ return errors::InvalidArgument(
+ "Invalid dimensions_to_reduce argument to XlaReduce");
+ }
+ c->set_output(
+ 0, c->UnknownShapeOfRank(rank - dimensions_to_reduce.size()));
+ } else {
+ c->set_output(0, c->input(0));
+ }
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Wraps the XLA Reduce operator, documented at
+ https://www.tensorflow.org/performance/xla/operation_semantics#reduce .
+
+input: the input tensor
+init_value: a scalar representing the initial value for the reduction
+reducer: a reducer function to apply
+dimensions_to_reduce: dimension numbers over which to reduce
+)doc");
+
REGISTER_OP("XlaReduceWindow")
.Input("input: T")
.Input("init_value: T")
+ .Input("window_dimensions: Tindices")
+ .Input("window_strides: Tindices")
+ .Input("padding: Tindices")
.Attr("T: numbertype")
+ .Attr("Tindices: {int32, int64}")
.Attr("computation: func")
- .Attr("window_dimensions: list(int)")
- .Attr("window_strides: list(int)")
- .Attr("padding_low: list(int)")
- .Attr("padding_high: list(int)")
.Output("output: T")
- .SetShapeFn(shape_inference::UnknownShape)
+ .SetShapeFn(UnchangedRank)
.Doc(R"doc(
Wraps the XLA ReduceWindow operator, documented at
https://www.tensorflow.org/performance/xla/operation_semantics#reducewindow .
@@ -118,8 +268,35 @@ init_value: a scalar representing the initial value for the reduction
computation: a reducer function to apply
window_dimensions: the shape of the window
window_strides: the inter-window strides
-padding_low: the padding to apply at the start of each input dimensions
-padding_high: the padding to apply at the end of each input dimension.
+padding: the padding to apply at the start and end of each input dimension
+)doc");
+
+REGISTER_OP("XlaSelectAndScatter")
+ .Input("operand: T")
+ .Input("window_dimensions: Tindices")
+ .Input("window_strides: Tindices")
+ .Input("padding: Tindices")
+ .Input("source: T")
+ .Input("init_value: T")
+ .Attr("T: numbertype")
+ .Attr("Tindices: {int32, int64}")
+ .Attr("select: func")
+ .Attr("scatter: func")
+ .Output("output: T")
+ .SetShapeFn(UnchangedRank)
+ .Doc(R"doc(
+Wraps the XLA SelectAndScatter operator, documented at
+ https://www.tensorflow.org/performance/xla/operation_semantics#selectandscatter
+.
+
+operand: the input tensor
+window_dimensions: the shape of the window
+window_strides: the inter-window strides
+padding: the padding to apply at the start and end of each input dimension
+source: a tensor of values to scatter
+init_value: a scalar representing the initial value for the output tensor
+select: a selection function to apply
+scatter: a scatter function to apply
)doc");
REGISTER_OP("XlaSend")
@@ -179,4 +356,5 @@ body: A function that takes a list of tensors and returns another
list of tensors. Both lists have the same types as specified by T.
)doc");
+} // namespace
} // namespace tensorflow
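
The XlaReduce shape function above only infers rank; restated as a small
Python check (illustrative):

# Output rank = input rank - number of (distinct, in-range) reduced dims.
def xla_reduce_output_rank(input_rank, dimensions_to_reduce):
  dims = set(dimensions_to_reduce)
  if len(dims) != len(dimensions_to_reduce) or not all(
      0 <= d < input_rank for d in dims):
    raise ValueError('Invalid dimensions_to_reduce argument to XlaReduce')
  return input_rank - len(dims)

xla_reduce_output_rank(2, [1])  # => 1, e.g. row sums of a [3, 4] input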
diff --git a/tensorflow/compiler/tf2xla/python/BUILD b/tensorflow/compiler/tf2xla/python/BUILD
index 42b6292f79..69ca394360 100644
--- a/tensorflow/compiler/tf2xla/python/BUILD
+++ b/tensorflow/compiler/tf2xla/python/BUILD
@@ -28,5 +28,6 @@ py_library(
srcs = ["xla.py"],
deps = [
"//tensorflow/compiler/tf2xla/ops:gen_xla_ops",
+ "//tensorflow/compiler/xla:xla_data_proto_py",
],
)
diff --git a/tensorflow/compiler/tf2xla/python/xla.py b/tensorflow/compiler/tf2xla/python/xla.py
index 2fc47dffb8..3626de375e 100644
--- a/tensorflow/compiler/tf2xla/python/xla.py
+++ b/tensorflow/compiler/tf2xla/python/xla.py
@@ -15,11 +15,12 @@
"""Experimental library that exposes XLA operations directly in TensorFlow.
It is sometimes useful to be able to build HLO programs directly from
-TensorFlow. This file provides Tensorflow operators that map as closely as
-possible to HLO operators.
+TensorFlow. This file provides TensorFlow operators that mirror the semantics of
+HLO operators as closely as possible.
-There is no promise of backward or forward compatibility for operators defined
-in this module.
+Note: There is no promise of backward or forward compatibility for operators
+defined in this module. This is primarily because the underlying HLO operators
+do not promise backward or forward compatibility.
"""
from __future__ import absolute_import
@@ -27,11 +28,298 @@ from __future__ import division
from __future__ import print_function
from tensorflow.compiler.tf2xla.ops import gen_xla_ops
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import bitwise_ops
+from tensorflow.python.ops import gen_math_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+
+# TODO(phawkins): provide wrappers for all XLA operators. Currently the missing
+# ops include:
+# infeed/outfeed (available via tf.contrib.tpu)
+# collectives, e.g., cross-replica-sum (available via tf.contrib.tpu)
+# conditional
+# gather/scatter
+# collapse
+
+# This file reuses builtin names (following XLA's names, so we can call things
+# like xla.max), so we capture the builtin versions here.
+# pylint: disable=redefined-builtin
+_max = max
+_min = min
+_slice = slice # pylint: disable=invalid-name
+
+constant = constant_op.constant
+
+# Unary operators.
+
+# For most arithmetic operators there is a TensorFlow operator
+# that exactly corresponds to each XLA operator. Rather than defining
+# XLA-specific variants, we reuse the corresponding TensorFlow operator.
+# TODO(phawkins): It would be even better to have TensorFlow operators that 1:1
+# wrap every HLO operator, because that would allow us to be confident that the
+# semantics match.
+
+
+def _unary_op(fn):
+ """Wrapper that restricts `fn` to have the correct signature."""
+
+ def unary_op_wrapper(x, name=None):
+ return fn(x, name=name)
+
+ return unary_op_wrapper
+
+
+abs = _unary_op(math_ops.abs)
+# TODO(phawkins): implement clz.
+conj = _unary_op(math_ops.conj)
+cos = _unary_op(math_ops.cos)
+ceil = _unary_op(math_ops.ceil)
+digamma = _unary_op(math_ops.digamma)
+erf = _unary_op(math_ops.erf)
+erfc = _unary_op(math_ops.erfc)
+# TODO(phawkins): implement erfinv
+exp = _unary_op(math_ops.exp)
+expm1 = _unary_op(math_ops.expm1)
+floor = _unary_op(math_ops.floor)
+imag = _unary_op(math_ops.imag)
+is_finite = _unary_op(math_ops.is_finite)
+lgamma = _unary_op(math_ops.lgamma)
+log = _unary_op(math_ops.log)
+log1p = _unary_op(math_ops.log1p)
+logical_not = _unary_op(math_ops.logical_not)
+neg = _unary_op(math_ops.neg)
+real = _unary_op(math_ops.real)
+# TODO(phawkins): unlike xla::Round, this rounds to even instead of zero for
+# numbers halfway between two integers.
+round = _unary_op(math_ops.round)
+sin = _unary_op(math_ops.sin)
+sign = _unary_op(math_ops.sign)
+tanh = _unary_op(math_ops.tanh)
+
+# Binary operators
+
+# The main difference between TensorFlow and XLA binary ops is the broadcasting
+# semantics. TensorFlow uses Numpy-style broadcasting semantics, whereas XLA
+# requires an explicit specification of which dimensions to broadcast if the
+# arguments have different ranks.
+
+
+def _broadcasting_binary_op(fn):
+ """Wraps a binary Tensorflow operator and performs XLA-style broadcasting."""
+
+ def broadcasting_binary_op_wrapper(x, y, broadcast_dims=None, name=None):
+ """Inner wrapper function."""
+ broadcast_dims = broadcast_dims or []
+ broadcast_dims = ops.convert_to_tensor(broadcast_dims, dtypes.int64)
+ # Rather than relying on having static shape information in the TensorFlow
+ # graph, we use an XlaBroadcastHelper op that can compute the correct shapes
+ # at JIT compilation time.
+ x, y = gen_xla_ops.xla_broadcast_helper(x, y, broadcast_dims)
+ return fn(x, y, name=name)
+
+ return broadcasting_binary_op_wrapper
+
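(A minimal usage sketch, not part of the commit: with XLA-style broadcasting, a lower-rank operand must name the dimensions it aligns with, mirroring XLA's explicit broadcast_dimensions.)

    x = xla.constant([[1., 2., 3.], [4., 5., 6.]])  # shape [2, 3]
    y = xla.constant([10., 20., 30.])               # shape [3]
    xla.add(x, y, broadcast_dims=[1])               # y aligns with dimension 1
    # => [[11., 22., 33.], [14., 25., 36.]]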
+
+# Map from TF signed types to TF unsigned types.
+_SIGNED_TO_UNSIGNED_TABLE = {
+ dtypes.int8: dtypes.uint8,
+ dtypes.int16: dtypes.uint16,
+ dtypes.int32: dtypes.uint32,
+ dtypes.int64: dtypes.uint64,
+}
+
+# Map from TF unsigned types to TF signed types.
+_UNSIGNED_TO_SIGNED_TABLE = {
+ dtypes.uint8: dtypes.int8,
+ dtypes.uint16: dtypes.int16,
+ dtypes.uint32: dtypes.int32,
+ dtypes.uint64: dtypes.int64,
+}
+
+
+def _shift_right_logical_helper(x, y, name=None):
+ """Performs an integer right logical shift irrespective of input type."""
+ assert y.dtype == x.dtype
+ dtype = x.dtype
+ signed = dtype in _SIGNED_TO_UNSIGNED_TABLE
+ if signed:
+ unsigned_dtype = _SIGNED_TO_UNSIGNED_TABLE[dtype]
+ x = math_ops.cast(x, unsigned_dtype)
+ y = math_ops.cast(y, unsigned_dtype)
+ output = bitwise_ops.right_shift(x, y, name=name)
+ if signed:
+ output = math_ops.cast(output, dtype)
+ return output
+
+
+def _shift_right_arithmetic_helper(x, y, name=None):
+ """Performs an integer right arithmetic shift irrespective of input type."""
+ assert y.dtype == x.dtype
+ dtype = x.dtype
+ unsigned = dtype in _UNSIGNED_TO_SIGNED_TABLE
+ if unsigned:
+ signed_dtype = _UNSIGNED_TO_SIGNED_TABLE[dtype]
+ x = math_ops.cast(x, signed_dtype)
+ y = math_ops.cast(y, signed_dtype)
+ output = bitwise_ops.right_shift(x, y, name=name)
+ if unsigned:
+ output = math_ops.cast(output, dtype)
+ return output
+
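(A sketch of why the casts matter, not part of the commit: bitwise_ops.right_shift follows the signedness of its inputs, so on a negative int32 the two helpers produce different results.)

    x = xla.constant([-8], dtype=dtypes.int32)
    n = xla.constant([1], dtype=dtypes.int32)
    xla.shift_right_arithmetic(x, n)  # [-4]: the sign bit is replicated
    xla.shift_right_logical(x, n)     # [2147483644]: zeros are shifted in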
+
+add = _broadcasting_binary_op(math_ops.add)
+sub = _broadcasting_binary_op(math_ops.sub)
+mul = _broadcasting_binary_op(math_ops.mul)
+div = _broadcasting_binary_op(math_ops.div)
+rem = _broadcasting_binary_op(gen_math_ops.mod)
+max = _broadcasting_binary_op(math_ops.maximum)
+min = _broadcasting_binary_op(math_ops.minimum)
+atan2 = _broadcasting_binary_op(math_ops.atan2)
+complex = _broadcasting_binary_op(math_ops.complex)
+logical_and = _broadcasting_binary_op(math_ops.logical_and)
+logical_or = _broadcasting_binary_op(math_ops.logical_or)
+logical_xor = _broadcasting_binary_op(math_ops.logical_xor)
+eq = _broadcasting_binary_op(math_ops.equal)
+ne = _broadcasting_binary_op(math_ops.not_equal)
+ge = _broadcasting_binary_op(math_ops.greater_equal)
+gt = _broadcasting_binary_op(math_ops.greater)
+le = _broadcasting_binary_op(math_ops.less_equal)
+lt = _broadcasting_binary_op(math_ops.less)
+pow = _broadcasting_binary_op(math_ops.pow)
+shift_left = _broadcasting_binary_op(bitwise_ops.left_shift)
+shift_right_logical = _broadcasting_binary_op(_shift_right_logical_helper)
+shift_right_arithmetic = _broadcasting_binary_op(_shift_right_arithmetic_helper)
+
+
+def _binary_op(fn):
+ """Wrapper that restricts `fn` to have the correct signature."""
+
+ def binary_op_wrapper(x, y, name=None):
+ return fn(x, y, name=name)
+
+ return binary_op_wrapper
+
+
+transpose = _binary_op(array_ops.transpose)
+rev = _binary_op(array_ops.reverse)
+
+bitcast_convert_type = array_ops.bitcast
+
+
+def broadcast(x, dims, name=None):
+ x = ops.convert_to_tensor(x)
+ shape = array_ops.concat(
+ [constant_op.constant(dims),
+ array_ops.shape(x)], axis=0)
+ return array_ops.broadcast_to(x, shape, name=name)
+
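(A quick sketch, not part of the commit: `dims` are prepended to the operand's shape, matching XLA's Broadcast operator.)

    x = xla.constant([[1, 2], [3, 4]])  # shape [2, 2]
    xla.broadcast(x, [3])               # shape [3, 2, 2]: three stacked copies of x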
+
+def clamp(a, x, b, name=None):
+ return min(max(a, x, name=name), b, name=name)
+
+
+concatenate = array_ops.concat
+
+
+def conv(lhs,
+ rhs,
+ window_strides,
+ padding,
+ lhs_dilation,
+ rhs_dilation,
+ dimension_numbers,
+ feature_group_count=1,
+ precision_config=None,
+ name=None):
+ """Wraps the XLA ConvGeneralDilated operator.
+
+ ConvGeneralDilated is the most general form of XLA convolution and is
+ documented at
+ https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution
+
+ Args:
+ lhs: the input tensor.
+ rhs: the kernel tensor.
+ window_strides: the inter-window strides.
+ padding: the padding to apply at the start and end of each input dimension.
+ lhs_dilation: dilation to apply between input elements.
+ rhs_dilation: dilation to apply between kernel elements.
+ dimension_numbers: a `ConvolutionDimensionNumbers` proto.
+ feature_group_count: number of feature groups for grouped convolution.
+ precision_config: a `PrecisionConfigProto` proto.
+ name: an optional name for the operator.
+
+ Returns:
+ A tensor representing the output of the convolution.
+ """
+ precision_config_proto = ""
+ if precision_config:
+ precision_config_proto = precision_config.SerializeToString()
+ return gen_xla_ops.xla_conv(
+ lhs,
+ rhs,
+ window_strides=window_strides,
+ padding=padding,
+ lhs_dilation=lhs_dilation,
+ rhs_dilation=rhs_dilation,
+ feature_group_count=feature_group_count,
+ dimension_numbers=dimension_numbers.SerializeToString(),
+ precision_config=precision_config_proto,
+ name=name)
+
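(A usage sketch, not part of the commit. The dimension-number field names follow xla_data.proto's ConvolutionDimensionNumbers; the padding layout of one (low, high) pair per spatial dimension is an assumption about the generated op. Layouts assumed: NHWC input, HWIO kernel.)

    dnums = xla_data_pb2.ConvolutionDimensionNumbers()
    dnums.input_batch_dimension = 0
    dnums.input_feature_dimension = 3
    dnums.input_spatial_dimensions.extend([1, 2])
    dnums.kernel_input_feature_dimension = 2
    dnums.kernel_output_feature_dimension = 3
    dnums.kernel_spatial_dimensions.extend([0, 1])
    dnums.output_batch_dimension = 0
    dnums.output_feature_dimension = 3
    dnums.output_spatial_dimensions.extend([1, 2])
    out = xla.conv(lhs, rhs,
                   window_strides=[1, 1],
                   padding=[[0, 0], [0, 0]],
                   lhs_dilation=[1, 1],
                   rhs_dilation=[1, 1],
                   dimension_numbers=dnums)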
+
+convert_element_type = math_ops.cast
+
+
+def dot(lhs, rhs, name=None):
+ return math_ops.tensordot(lhs, rhs, axes=1, name=name)
+
+
+def dot_general(lhs, rhs, dimension_numbers, precision_config=None, name=None):
+ precision_config_proto = ""
+ if precision_config:
+ precision_config_proto = precision_config.SerializeToString()
+ return gen_xla_ops.xla_dot(
+ lhs,
+ rhs,
+ dimension_numbers=dimension_numbers.SerializeToString(),
+ precision_config=precision_config_proto,
+ name=name)
+
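(A sketch, not part of the commit: a plain [m, k] x [k, n] matrix multiply expressed through DotDimensionNumbers, contracting the last dimension of lhs with the first of rhs.)

    dnums = xla_data_pb2.DotDimensionNumbers()
    dnums.lhs_contracting_dimensions.append(1)
    dnums.rhs_contracting_dimensions.append(0)
    out = xla.dot_general(lhs, rhs, dimension_numbers=dnums)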
+
+def dynamic_slice(x, starts, sizes, name=None):
+ # TODO(phawkins): the Slice operator lowers to DynamicSlice if `starts` is not
+ # a compile-time constant. This doesn't exactly mimic the semantics of dynamic
+ # slice if the slice is out of bounds.
+ return array_ops.slice(x, starts, sizes, name=name)
-# TODO(phawkins): provide wrappers for all XLA operators.
dynamic_update_slice = gen_xla_ops.xla_dynamic_update_slice
+# TODO(phawkins): generalize tf.pad to support interior padding, and then remove
+# the XLA-specific pad operator.
+pad = gen_xla_ops.xla_pad
+
+
+def random_normal(mu, sigma, dims, name=None):
+ mu = ops.convert_to_tensor(mu)
+ return random_ops.random_normal(
+ dims, mean=mu, stddev=sigma, dtype=mu.dtype, name=name)
+
+
+def random_uniform(minval, maxval, dims, name=None):
+ minval = ops.convert_to_tensor(minval)
+ return random_ops.random_uniform(
+ dims, minval, maxval, dtype=minval.dtype, name=name)
+
+
+recv = gen_xla_ops.xla_recv
+reduce = gen_xla_ops.xla_reduce
+
def reduce_window(operand,
init,
@@ -61,22 +349,38 @@ def reduce_window(operand,
"""
window_strides = window_strides or [1] * len(window_dimensions)
padding = padding or [(0, 0)] * len(window_dimensions)
- padding_low = [x for (x, _) in padding]
- padding_high = [y for (_, y) in padding]
return gen_xla_ops.xla_reduce_window(
- operand,
- init,
- reducer,
- window_dimensions,
- window_strides,
- padding_low,
- padding_high,
+ input=operand,
+ init_value=init,
+ window_dimensions=window_dimensions,
+ window_strides=window_strides,
+ padding=padding,
+ computation=reducer,
name=name)
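(A sketch, not part of the commit, assuming the parameter order shown above and a reducer built with function.Defun as in the accompanying tests: a 2x2 sum-pool over a float32 operand.)

    @function.Defun(dtypes.float32, dtypes.float32)
    def sum_reducer(x, y):
      return x + y

    out = xla.reduce_window(operand, xla.constant(0., dtypes.float32), sum_reducer,
                            window_dimensions=[2, 2], window_strides=[1, 1])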
-recv = gen_xla_ops.xla_recv
+def reshape(x, new_sizes, dimensions=None, name=None):
+ if dimensions is not None:
+ x = array_ops.transpose(x, dimensions)
+ x = array_ops.reshape(x, new_sizes, name=name)
+ return x
+
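(A sketch, not part of the commit: the optional `dimensions` argument mirrors XLA Reshape's collapse order, transposing the operand before reshaping it.)

    x = xla.constant([[1, 2, 3], [4, 5, 6]])   # shape [2, 3]
    xla.reshape(x, [2, 3], dimensions=[1, 0])
    # transposes to [[1, 4], [2, 5], [3, 6]], then reshapes: [[1, 4, 2], [5, 3, 6]]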
+
+def select(condition, x, y, name=None):
+ return array_ops.where(condition, x, y, name)
+
+
+select_and_scatter = gen_xla_ops.xla_select_and_scatter
send = gen_xla_ops.xla_send
-sort = gen_xla_ops.xla_sort
+def slice(x, start_dims, limit_dims, strides):
+ spec = [
+ _slice(start, limit, stride)
+ for (start, limit, stride) in zip(start_dims, limit_dims, strides)
+ ]
+ return x[tuple(spec)]
+
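(A sketch, not part of the commit: each (start, limit, stride) triple maps onto a Python slice object for the corresponding dimension.)

    x = xla.constant([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
    xla.slice(x, start_dims=[0, 1], limit_dims=[3, 3], strides=[2, 1])
    # rows 0 and 2, columns 1..2 => [[1, 2], [7, 8]]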
+
+sort = gen_xla_ops.xla_sort
while_loop = gen_xla_ops.xla_while
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
index 31a41f8719..9e8f5f2a1a 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
@@ -99,6 +99,25 @@ Status XlaOpKernelContext::ConstantInput(int index,
index, context_->input(index).shape().dim_sizes(), constant_literal);
}
+static xla::StatusOr<int> InputIndex(XlaOpKernelContext* context,
+ StringPiece name) {
+ int start, stop;
+ TF_RETURN_IF_ERROR(context->op_kernel().InputRange(name, &start, &stop));
+ if (stop != start + 1) {
+ return errors::InvalidArgument("OpKernel used list-valued input name '",
+ name,
+ "' when single-valued input was "
+ "expected");
+ }
+ return start;
+}
+
+Status XlaOpKernelContext::ConstantInput(StringPiece name,
+ xla::Literal* constant_literal) {
+ TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
+ return ConstantInput(index, constant_literal);
+}
+
Status XlaOpKernelContext::ConstantInputReshaped(
int index, gtl::ArraySlice<int64> new_dims,
xla::Literal* constant_literal) {
@@ -246,6 +265,12 @@ Status XlaOpKernelContext::ConstantInputAsIntScalar(int index, int64* out) {
return LiteralToInt64Scalar(literal, out);
}
+Status XlaOpKernelContext::ConstantInputAsIntScalar(StringPiece name,
+ int64* out) {
+ TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
+ return ConstantInputAsIntScalar(index, out);
+}
+
Status XlaOpKernelContext::ConstantInputAsFloatScalar(int index, double* out) {
xla::Literal literal;
TF_RETURN_IF_ERROR(ConstantInput(index, &literal));
@@ -280,6 +305,12 @@ Status XlaOpKernelContext::ConstantInputAsIntVector(int index,
return LiteralToInt64Vector(literal, out);
}
+Status XlaOpKernelContext::ConstantInputAsIntVector(StringPiece name,
+ std::vector<int64>* out) {
+ TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
+ return ConstantInputAsIntVector(index, out);
+}
+
Status XlaOpKernelContext::ConstantInputReshapedToIntVector(
int index, std::vector<int64>* out) {
xla::Literal literal;
@@ -313,6 +344,12 @@ Status XlaOpKernelContext::ConstantInputAsInt64Literal(int index,
}
}
+Status XlaOpKernelContext::ConstantInputAsInt64Literal(StringPiece name,
+ xla::Literal* out) {
+ TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
+ return ConstantInputAsInt64Literal(index, out);
+}
+
// TODO(phawkins): validate that the dimensions form a valid shape, fail
// gracefully if they do not.
Status XlaOpKernelContext::ConstantInputAsShape(int index, TensorShape* shape) {
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h
index 3f21a2bf41..3e26ba4f01 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.h
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h
@@ -106,6 +106,7 @@ class XlaOpKernelContext {
// expression cannot be evaluated, e.g., because it depends on unbound
// parameters, returns a non-OK status.
Status ConstantInput(int index, xla::Literal* constant_literal);
+ Status ConstantInput(StringPiece name, xla::Literal* constant_literal);
// Evaluates input `index`, reshapes it to `new_shape` if new_shape !=
// InputShape(index), and stores it in `*constant_literal`. If the input
@@ -117,12 +118,14 @@ class XlaOpKernelContext {
// Converts a constant scalar int32 or int64 tensor into an int64.
Status ConstantInputAsIntScalar(int index, int64* out);
+ Status ConstantInputAsIntScalar(StringPiece name, int64* out);
// Converts a constant scalar float32 or float64 tensor into a float64.
Status ConstantInputAsFloatScalar(int index, double* out);
// Converts a constant 1D int32 or int64 tensor into a vector of int64s.
Status ConstantInputAsIntVector(int index, std::vector<int64>* out);
+ Status ConstantInputAsIntVector(StringPiece name, std::vector<int64>* out);
// Reshapes and converts a constant int32 or int64 tensor into a vector of
// int64s.
@@ -130,6 +133,7 @@ class XlaOpKernelContext {
// Converts a constant int32 or int64 Tensor into an xla int64 Literal.
Status ConstantInputAsInt64Literal(int index, xla::Literal* out);
+ Status ConstantInputAsInt64Literal(StringPiece name, xla::Literal* out);
// Converts a constant 1D int32 or int64 tensor into a TensorShape.
Status ConstantInputAsShape(int index, TensorShape* shape);