From 2e2e89699c1186eef157911b57e4b062de376ce9 Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Thu, 4 Oct 2018 14:59:43 -0700 Subject: Add basic TensorList op support in bridge. * Add kernels for TensorListReserve. EmptyTensorList, TensorListElementShape, TensorListPushBack, TensorlistPopBack; * Treat list type pretty much identical to Stack in the bridge for now; * Support variant output by treating variant like a uint8 and leaving the interpretation up to the XlaExpression (variant type does not support tensor_data()); PiperOrigin-RevId: 215809335 --- tensorflow/compiler/tf2xla/kernels/BUILD | 2 + .../compiler/tf2xla/kernels/tensor_list_ops.cc | 226 +++++++++++++++++++++ tensorflow/compiler/tf2xla/xla_op_kernel.cc | 40 +++- tensorflow/compiler/tf2xla/xla_op_kernel.h | 5 + 4 files changed, 263 insertions(+), 10 deletions(-) create mode 100644 tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc (limited to 'tensorflow/compiler/tf2xla') diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index 9a7130f253..95a010a119 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -95,6 +95,7 @@ tf_kernel_library( "stateless_random_ops.cc", "strided_slice_op.cc", "tensor_array_ops.cc", + "tensor_list_ops.cc", "tile_ops.cc", "topk_op.cc", "training_ops.cc", @@ -158,6 +159,7 @@ tf_kernel_library( "//tensorflow/core/kernels:control_flow_ops", "//tensorflow/core/kernels:conv_ops", "//tensorflow/core/kernels:cwise_op", + "//tensorflow/core/kernels:list_kernels", "//tensorflow/core/kernels:no_op", "//tensorflow/core/kernels:ops_util", "//tensorflow/core/kernels:pooling_ops", diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc new file mode 100644 index 0000000000..74d4fcc425 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc @@ -0,0 +1,226 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// XLA TensorList operators. + +#include +#include + +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/bounds_check.h" +#include "tensorflow/core/kernels/concat_lib.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace { + +Status GetTensorListShape(xla::XlaBuilder* builder, xla::XlaOp op, + TensorShape* tensor_list_shape) { + auto shape_or_status = builder->GetShape(op); + if (!shape_or_status.ok()) { + return shape_or_status.status(); + } + xla::Shape shape = shape_or_status.ValueOrDie(); + TF_RET_CHECK(xla::ShapeUtil::IsTuple(shape)); + return XLAShapeToTensorShape(xla::ShapeUtil::GetTupleElementShape(shape, 0), + tensor_list_shape); +} + +class TensorListReserveOp : public XlaOpKernel { + public: + explicit TensorListReserveOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + TensorShape element_shape; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &element_shape)); + int64 num_elements; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &num_elements)); + + TensorShape tensor_shape; + tensor_shape.AddDim(num_elements); + tensor_shape.AppendShape(element_shape); + + xla::XlaBuilder* b = ctx->builder(); + ctx->SetOutput(0, xla::Tuple(b, {xla::Broadcast(XlaHelpers::Zero(b, dtype_), + tensor_shape.dim_sizes()), + xla::ConstantR0(b, 0)})); + } + + private: + DataType dtype_; + + TF_DISALLOW_COPY_AND_ASSIGN(TensorListReserveOp); +}; + +REGISTER_XLA_OP(Name("TensorListReserve") + .CompileTimeConstInput("element_shape") + .CompileTimeConstInput("num_elements"), + TensorListReserveOp); + +class EmptyTensorListOp : public XlaOpKernel { + public: + explicit EmptyTensorListOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + ctx->CtxFailure( + errors::InvalidArgument("XLA compilation requires a fixed tensor list " + "size. Use TensorListReserve instead.")); + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(EmptyTensorListOp); +}; + +REGISTER_XLA_OP(Name("EmptyTensorList"), EmptyTensorListOp); + +class TensorListElementShapeOp : public XlaOpKernel { + public: + explicit TensorListElementShapeOp(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("shape_type", &shape_type_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::XlaBuilder* b = ctx->builder(); + TensorShape shape; + OP_REQUIRES_OK(ctx, GetTensorListShape(b, ctx->Input(0), &shape)); + shape.RemoveDim(0); + + switch (shape_type_) { + case DT_INT64: + ctx->SetOutput(0, xla::ConstantR1(b, shape.dim_sizes())); + break; + case DT_INT32: { + std::vector size; + for (int64 s : shape.dim_sizes()) { + size.push_back(s); + } + ctx->SetOutput(0, xla::ConstantR1(b, size)); + break; + } + default: + ctx->CtxFailure( + errors::InvalidArgument("Unsupported shape type requested")); + return; + } + } + + private: + DataType shape_type_; + + TF_DISALLOW_COPY_AND_ASSIGN(TensorListElementShapeOp); +}; + +REGISTER_XLA_OP(Name("TensorListElementShape"), TensorListElementShapeOp); + +class TensorListPushBackOp : public XlaOpKernel { + public: + explicit TensorListPushBackOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::XlaBuilder* b = ctx->builder(); + xla::XlaOp list = ctx->Input(0); + TensorShape elem_shape = ctx->InputShape(1); + + xla::XlaOp ta = xla::GetTupleElement(list, 0); + xla::XlaOp index = xla::GetTupleElement(list, 1); + xla::XlaOp value = ctx->Input(1); + + // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. + auto start_indices = + xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0(b, 0), + xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}})); + + TensorShape slice_shape = elem_shape; + slice_shape.InsertDim(0, 1LL); + auto update = xla::Reshape(value, slice_shape.dim_sizes()); + + // TODO(phawkins): We don't check the index is in bounds --- there is no + // error mechanism in XLA. + ctx->SetOutput( + 0, xla::Tuple(b, {xla::DynamicUpdateSlice(ta, update, start_indices), + index + xla::ConstantR0(b, 1)})); + } + + private: + DataType dtype_; + + TF_DISALLOW_COPY_AND_ASSIGN(TensorListPushBackOp); +}; + +REGISTER_XLA_OP(Name("TensorListPushBack"), TensorListPushBackOp); + +class TensorListPopBackOp : public XlaOpKernel { + public: + explicit TensorListPopBackOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::XlaBuilder* b = ctx->builder(); + xla::XlaOp state = ctx->Input(0); + + TensorShape shape; + OP_REQUIRES_OK(ctx, GetTensorListShape(b, state, &shape)); + + xla::XlaOp ta = xla::GetTupleElement(state, 0); + xla::XlaOp index = xla::GetTupleElement(state, 1); + + index = index - xla::ConstantR0(b, 1); + + // start_indices of the DynamicSlice are [index, 0, 0, ..., 0]. + auto start_indices = + xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0(b, 0), + xla::MakeEdgePaddingConfig({{0, shape.dims() - 1}})); + + auto slice_shape = shape.dim_sizes(); + slice_shape[0] = 1LL; + + // TODO(phawkins): We don't check the index is in bounds --- there is no + // error mechanism in XLA. + xla::XlaOp read = xla::DynamicSlice(ta, start_indices, slice_shape); + // Remove the leading '1' dimension. + std::vector value_shape(slice_shape.begin() + 1, slice_shape.end()); + + ctx->SetOutput(0, xla::Tuple(b, {ta, index})); + ctx->SetOutput(1, xla::Reshape(read, value_shape)); + } + + private: + DataType dtype_; + + TF_DISALLOW_COPY_AND_ASSIGN(TensorListPopBackOp); +}; + +REGISTER_XLA_OP(Name("TensorListPopBack"), TensorListPopBackOp); + +} // anonymous namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index 2a9eaeee14..dd3498ef7a 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -455,23 +455,43 @@ Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type, return Status::OK(); } +Status XlaOpKernelContext::allocate_output(int index, const xla::Shape& shape, + Tensor** output) { + // The step's default allocator is the dummy XlaCompilationAllocator which + // simply allocates a metadata buffer to hold the expression to which it + // corresponds. + if (expected_output_dtype(index) == DT_VARIANT) { + // tensor_data() is not supported for variant Tensor (i.e., + // DataTypeCanUseMemcpy is false for DT_VARIANT), and so storing the + // XlaExpression inside the Tensor's tensor_data() does not work for + // variant. Instead construct a uint8 tensor and store the expression in its + // value. + // TODO(jpienaar): This should be refactored to stop masquerading + // XlaExpressions as Tensors. + *output = new Tensor(); + TensorShape tensor_shape; + TF_RETURN_IF_ERROR( + context_->allocate_temp(DT_UINT8, tensor_shape, *output)); + context_->set_output(index, **output); + } else { + TensorShape tensor_shape; + TF_RETURN_IF_ERROR(XLAShapeToTensorShape(shape, &tensor_shape)); + TF_RETURN_IF_ERROR(context_->allocate_output(index, tensor_shape, output)); + } + return Status::OK(); +} + void XlaOpKernelContext::SetOutput(int index, const xla::XlaOp& handle) { // Makes the host Tensor that will refer to the expression. Tensor* output = nullptr; - auto shape = builder()->GetShape(handle); - if (!shape.ok()) { - SetStatus(shape.status()); + auto shape_or = builder()->GetShape(handle); + if (!shape_or.ok()) { + SetStatus(shape_or.status()); return; } - // The step's default allocator is the dummy XlaCompilationAllocator which - // simply allocates a metadata buffer to hold the expression to which it - // corresponds. - TensorShape tensor_shape; - OP_REQUIRES_OK(context_, - XLAShapeToTensorShape(shape.ValueOrDie(), &tensor_shape)); OP_REQUIRES_OK(context_, - context_->allocate_output(index, tensor_shape, &output)); + allocate_output(index, shape_or.ValueOrDie(), &output)); // The expression is stored in the tensor's data buffer. Fill in the // fields now. diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h index a3a0d10cc0..aa00a45496 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.h +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h @@ -255,6 +255,11 @@ class XlaOpKernelContext { // Returns the tensor of input `name`. const Tensor& GetInputTensorByName(absl::string_view name); + // Wraps OpKernelContext's allocate_output method while providing special + // behavior for DT_VARIANT: a variant is treated as DT_UINT8 scalar as the + // type to allow mapping for variant to more generic types. + Status allocate_output(int index, const xla::Shape& shape, Tensor** output); + OpKernelContext* const context_; }; -- cgit v1.2.3