author     Jacques Pienaar <jpienaar@google.com>  2018-10-04 14:59:43 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-10-04 15:04:44 -0700
commit  2e2e89699c1186eef157911b57e4b062de376ce9 (patch)
tree    53f729e3fe75b375e32d9e17b634532872a7ea33 /tensorflow/compiler/tf2xla
parent  a742575879db1df48daf929b8d29e43a1d168dd7 (diff)
Add basic TensorList op support in bridge.

* Add kernels for TensorListReserve, EmptyTensorList, TensorListElementShape, TensorListPushBack, and TensorListPopBack;
* Treat the list type much like Stack in the bridge for now;
* Support variant output by treating a variant like a uint8 and leaving the interpretation up to the XlaExpression (the variant type does not support tensor_data()).

PiperOrigin-RevId: 215809335
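As an orientation for the diff below, here is a minimal, illustrative sketch (not code from this change; the helper name and the element shape [2] are made up) of the representation the new kernels build: a TensorList is lowered to an XLA tuple of an element buffer whose leading dimension is the list length, plus an int32 push index, mirroring how the bridge already handles TensorArray/Stack.

#include <cstdint>
#include "tensorflow/compiler/xla/client/xla_builder.h"

// Hypothetical sketch of what TensorListReserve(element_shape=[2],
// num_elements=n) conceptually produces under this change.
xla::XlaOp MakeReservedList(xla::XlaBuilder* b, int64_t num_elements) {
  // Buffer of zeros holding every element; the leading dim is the list length.
  xla::XlaOp buffer = xla::Broadcast(xla::ConstantR0<float>(b, 0.0f),
                                     {num_elements, 2});
  // Number of elements pushed so far.
  xla::XlaOp index = xla::ConstantR0<int32_t>(b, 0);
  return xla::Tuple(b, {buffer, index});
}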
Diffstat (limited to 'tensorflow/compiler/tf2xla')
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/BUILD               |   2
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc  | 226
-rw-r--r--  tensorflow/compiler/tf2xla/xla_op_kernel.cc            |  40
-rw-r--r--  tensorflow/compiler/tf2xla/xla_op_kernel.h             |   5
4 files changed, 263 insertions, 10 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD
index 9a7130f253..95a010a119 100644
--- a/tensorflow/compiler/tf2xla/kernels/BUILD
+++ b/tensorflow/compiler/tf2xla/kernels/BUILD
@@ -95,6 +95,7 @@ tf_kernel_library(
"stateless_random_ops.cc",
"strided_slice_op.cc",
"tensor_array_ops.cc",
+ "tensor_list_ops.cc",
"tile_ops.cc",
"topk_op.cc",
"training_ops.cc",
@@ -158,6 +159,7 @@ tf_kernel_library(
"//tensorflow/core/kernels:control_flow_ops",
"//tensorflow/core/kernels:conv_ops",
"//tensorflow/core/kernels:cwise_op",
+ "//tensorflow/core/kernels:list_kernels",
"//tensorflow/core/kernels:no_op",
"//tensorflow/core/kernels:ops_util",
"//tensorflow/core/kernels:pooling_ops",
diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc
new file mode 100644
index 0000000000..74d4fcc425
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc
@@ -0,0 +1,226 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// XLA TensorList operators.
+
+#include <limits>
+#include <vector>
+
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/type_util.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/literal.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/partial_tensor_shape.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/kernels/concat_lib.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+namespace {
+
+Status GetTensorListShape(xla::XlaBuilder* builder, xla::XlaOp op,
+ TensorShape* tensor_list_shape) {
+ auto shape_or_status = builder->GetShape(op);
+ if (!shape_or_status.ok()) {
+ return shape_or_status.status();
+ }
+ xla::Shape shape = shape_or_status.ValueOrDie();
+ TF_RET_CHECK(xla::ShapeUtil::IsTuple(shape));
+ return XLAShapeToTensorShape(xla::ShapeUtil::GetTupleElementShape(shape, 0),
+ tensor_list_shape);
+}
+
+class TensorListReserveOp : public XlaOpKernel {
+ public:
+ explicit TensorListReserveOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ TensorShape element_shape;
+ OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &element_shape));
+ int64 num_elements;
+ OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &num_elements));
+
+ TensorShape tensor_shape;
+ tensor_shape.AddDim(num_elements);
+ tensor_shape.AppendShape(element_shape);
+
+ xla::XlaBuilder* b = ctx->builder();
+ ctx->SetOutput(0, xla::Tuple(b, {xla::Broadcast(XlaHelpers::Zero(b, dtype_),
+ tensor_shape.dim_sizes()),
+ xla::ConstantR0<int32>(b, 0)}));
+ }
+
+ private:
+ DataType dtype_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(TensorListReserveOp);
+};
+
+REGISTER_XLA_OP(Name("TensorListReserve")
+ .CompileTimeConstInput("element_shape")
+ .CompileTimeConstInput("num_elements"),
+ TensorListReserveOp);
+
+class EmptyTensorListOp : public XlaOpKernel {
+ public:
+ explicit EmptyTensorListOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ ctx->CtxFailure(
+ errors::InvalidArgument("XLA compilation requires a fixed tensor list "
+ "size. Use TensorListReserve instead."));
+ }
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(EmptyTensorListOp);
+};
+
+REGISTER_XLA_OP(Name("EmptyTensorList"), EmptyTensorListOp);
+
+class TensorListElementShapeOp : public XlaOpKernel {
+ public:
+ explicit TensorListElementShapeOp(OpKernelConstruction* ctx)
+ : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("shape_type", &shape_type_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ xla::XlaBuilder* b = ctx->builder();
+ TensorShape shape;
+ OP_REQUIRES_OK(ctx, GetTensorListShape(b, ctx->Input(0), &shape));
+ shape.RemoveDim(0);
+
+ switch (shape_type_) {
+ case DT_INT64:
+ ctx->SetOutput(0, xla::ConstantR1<int64>(b, shape.dim_sizes()));
+ break;
+ case DT_INT32: {
+ std::vector<int32> size;
+ for (int64 s : shape.dim_sizes()) {
+ size.push_back(s);
+ }
+ ctx->SetOutput(0, xla::ConstantR1<int32>(b, size));
+ break;
+ }
+ default:
+ ctx->CtxFailure(
+ errors::InvalidArgument("Unsupported shape type requested"));
+ return;
+ }
+ }
+
+ private:
+ DataType shape_type_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(TensorListElementShapeOp);
+};
+
+REGISTER_XLA_OP(Name("TensorListElementShape"), TensorListElementShapeOp);
+
+class TensorListPushBackOp : public XlaOpKernel {
+ public:
+ explicit TensorListPushBackOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ xla::XlaBuilder* b = ctx->builder();
+ xla::XlaOp list = ctx->Input(0);
+ TensorShape elem_shape = ctx->InputShape(1);
+
+ xla::XlaOp ta = xla::GetTupleElement(list, 0);
+ xla::XlaOp index = xla::GetTupleElement(list, 1);
+ xla::XlaOp value = ctx->Input(1);
+
+ // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0].
+ auto start_indices =
+ xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0<int32>(b, 0),
+ xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}}));
+
+ TensorShape slice_shape = elem_shape;
+ slice_shape.InsertDim(0, 1LL);
+ auto update = xla::Reshape(value, slice_shape.dim_sizes());
+
+ // TODO(phawkins): We don't check the index is in bounds --- there is no
+ // error mechanism in XLA.
+ ctx->SetOutput(
+ 0, xla::Tuple(b, {xla::DynamicUpdateSlice(ta, update, start_indices),
+ index + xla::ConstantR0<int32>(b, 1)}));
+ }
+
+ private:
+ DataType dtype_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(TensorListPushBackOp);
+};
+
+REGISTER_XLA_OP(Name("TensorListPushBack"), TensorListPushBackOp);
+
+class TensorListPopBackOp : public XlaOpKernel {
+ public:
+ explicit TensorListPopBackOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ xla::XlaBuilder* b = ctx->builder();
+ xla::XlaOp state = ctx->Input(0);
+
+ TensorShape shape;
+ OP_REQUIRES_OK(ctx, GetTensorListShape(b, state, &shape));
+
+ xla::XlaOp ta = xla::GetTupleElement(state, 0);
+ xla::XlaOp index = xla::GetTupleElement(state, 1);
+
+ index = index - xla::ConstantR0<int32>(b, 1);
+
+ // start_indices of the DynamicSlice are [index, 0, 0, ..., 0].
+ auto start_indices =
+ xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0<int32>(b, 0),
+ xla::MakeEdgePaddingConfig({{0, shape.dims() - 1}}));
+
+ auto slice_shape = shape.dim_sizes();
+ slice_shape[0] = 1LL;
+
+ // TODO(phawkins): We don't check the index is in bounds --- there is no
+ // error mechanism in XLA.
+ xla::XlaOp read = xla::DynamicSlice(ta, start_indices, slice_shape);
+ // Remove the leading '1' dimension.
+ std::vector<int64> value_shape(slice_shape.begin() + 1, slice_shape.end());
+
+ ctx->SetOutput(0, xla::Tuple(b, {ta, index}));
+ ctx->SetOutput(1, xla::Reshape(read, value_shape));
+ }
+
+ private:
+ DataType dtype_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(TensorListPopBackOp);
+};
+
+REGISTER_XLA_OP(Name("TensorListPopBack"), TensorListPopBackOp);
+
+} // anonymous namespace
+} // namespace tensorflow
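To make the push dataflow above concrete, the following is a small builder-level sketch (illustrative only, not a test from this change; the function name and constants are arbitrary) of what TensorListPushBack does to the (buffer, index) tuple for a list of four scalar float elements.

#include <cstdint>
#include "tensorflow/compiler/xla/client/xla_builder.h"

// Hypothetical sketch of the PushBack lowering: write the value at row
// `index` with DynamicUpdateSlice, then bump the push index by one.
xla::XlaOp PushBackSketch(xla::XlaBuilder* b) {
  // A reserved list of 4 scalar float elements: zero buffer + push index 0.
  xla::XlaOp list = xla::Tuple(
      b, {xla::Broadcast(xla::ConstantR0<float>(b, 0.0f), {4}),
          xla::ConstantR0<int32_t>(b, 0)});
  xla::XlaOp buffer = xla::GetTupleElement(list, 0);
  xla::XlaOp index = xla::GetTupleElement(list, 1);
  // The update must have the same rank as the buffer, hence a rank-1 value.
  xla::XlaOp value = xla::ConstantR1<float>(b, {42.0f});
  return xla::Tuple(
      b, {xla::DynamicUpdateSlice(buffer, value, xla::Reshape(index, {1})),
          index + xla::ConstantR0<int32_t>(b, 1)});
}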
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
index 2a9eaeee14..dd3498ef7a 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
@@ -455,23 +455,43 @@ Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type,
return Status::OK();
}
+Status XlaOpKernelContext::allocate_output(int index, const xla::Shape& shape,
+ Tensor** output) {
+ // The step's default allocator is the dummy XlaCompilationAllocator which
+ // simply allocates a metadata buffer to hold the expression to which it
+ // corresponds.
+ if (expected_output_dtype(index) == DT_VARIANT) {
+ // tensor_data() is not supported for variant Tensor (i.e.,
+ // DataTypeCanUseMemcpy is false for DT_VARIANT), and so storing the
+ // XlaExpression inside the Tensor's tensor_data() does not work for
+ // variant. Instead construct a uint8 tensor and store the expression in its
+ // value.
+ // TODO(jpienaar): This should be refactored to stop masquerading
+ // XlaExpressions as Tensors.
+ *output = new Tensor();
+ TensorShape tensor_shape;
+ TF_RETURN_IF_ERROR(
+ context_->allocate_temp(DT_UINT8, tensor_shape, *output));
+ context_->set_output(index, **output);
+ } else {
+ TensorShape tensor_shape;
+ TF_RETURN_IF_ERROR(XLAShapeToTensorShape(shape, &tensor_shape));
+ TF_RETURN_IF_ERROR(context_->allocate_output(index, tensor_shape, output));
+ }
+ return Status::OK();
+}
+
void XlaOpKernelContext::SetOutput(int index, const xla::XlaOp& handle) {
// Makes the host Tensor that will refer to the expression.
Tensor* output = nullptr;
- auto shape = builder()->GetShape(handle);
- if (!shape.ok()) {
- SetStatus(shape.status());
+ auto shape_or = builder()->GetShape(handle);
+ if (!shape_or.ok()) {
+ SetStatus(shape_or.status());
return;
}
- // The step's default allocator is the dummy XlaCompilationAllocator which
- // simply allocates a metadata buffer to hold the expression to which it
- // corresponds.
- TensorShape tensor_shape;
- OP_REQUIRES_OK(context_,
- XLAShapeToTensorShape(shape.ValueOrDie(), &tensor_shape));
OP_REQUIRES_OK(context_,
- context_->allocate_output(index, tensor_shape, &output));
+ allocate_output(index, shape_or.ValueOrDie(), &output));
// The expression is stored in the tensor's data buffer. Fill in the
// fields now.
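For context on the DT_VARIANT branch added above: the bridge stores an XlaExpression in the output Tensor's raw byte buffer, which only works for dtypes whose data can be treated as plain memory. A minimal, standalone sketch (not code from this change; the helper name is made up) of the check that motivates the DT_UINT8 stand-in:

#include "tensorflow/core/framework/types.h"

// DataTypeCanUseMemcpy(DT_VARIANT) is false, so a variant Tensor cannot be
// used as raw backing storage for an XlaExpression; allocate_output therefore
// falls back to a DT_UINT8 scalar for variant outputs.
bool CanBackExpressionWithTensorData(tensorflow::DataType dtype) {
  return tensorflow::DataTypeCanUseMemcpy(dtype);
}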
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h
index a3a0d10cc0..aa00a45496 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.h
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h
@@ -255,6 +255,11 @@ class XlaOpKernelContext {
// Returns the tensor of input `name`.
const Tensor& GetInputTensorByName(absl::string_view name);
+ // Wraps OpKernelContext's allocate_output method. For DT_VARIANT outputs
+ // the shape is ignored and a DT_UINT8 scalar tensor is allocated instead,
+ // since a variant tensor cannot hold the XlaExpression via tensor_data().
+ Status allocate_output(int index, const xla::Shape& shape, Tensor** output);
+
OpKernelContext* const context_;
};