diff options
Diffstat (limited to 'tensorflow/compiler/tf2xla/xla_op_kernel.cc')
-rw-r--r-- | tensorflow/compiler/tf2xla/xla_op_kernel.cc | 253 |
1 files changed, 253 insertions, 0 deletions
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc new file mode 100644 index 0000000000..3883b907b4 --- /dev/null +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -0,0 +1,253 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" + +#include <numeric> + +#include "tensorflow/compiler/tf2xla/literal_util.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/xla_context.h" + +namespace tensorflow { + +XlaOpKernelContext::XlaOpKernelContext(OpKernelContext* context) + : context_(context) {} + +bool XlaOpKernelContext::ValidateInputsAreSameShape(OpKernel* op) { + return context_->ValidateInputsAreSameShape(op); +} + +xla::ComputationBuilder* XlaOpKernelContext::builder() const { + return &XlaContext::Get(this).builder(); +} + +const xla::ComputationDataHandle& XlaOpKernelContext::Input(int index) { + return XlaContext::GetComputationFromTensor(context_->input(index)); +} + +TensorShape XlaOpKernelContext::InputShape(int index) { + return context_->input(index).shape(); +} + +Status XlaOpKernelContext::ConstantInput(int index, + xla::Literal* constant_literal) { + return ConstantInputReshaped( + index, context_->input(index).shape().dim_sizes(), constant_literal); +} + +Status XlaOpKernelContext::ConstantInputReshaped( + int index, gtl::ArraySlice<int64> new_dims, + xla::Literal* constant_literal) { + const Tensor& tensor = context_->input(index); + TensorShape new_shape(new_dims); + if (tensor.NumElements() != new_shape.num_elements()) { + return errors::InvalidArgument( + context_->op_kernel().name(), " input ", index, " has shape ", + tensor.shape().DebugString(), + " but was asked to be reshaped to incompatible shape ", + new_shape.DebugString()); + } + const XlaExpression* expression = + XlaContext::CastExpressionFromTensor(tensor); + + // If the tensor has a known constant value, there is no need to invoke XLA. + if (expression->has_constant_value()) { + Tensor temp(tensor.dtype()); + if (!temp.CopyFrom(expression->constant_value(), new_shape)) { + // This should never happen. The constant should have a shape compatible + // with the enclosing Tensor. + return errors::Internal("Incompatible shapes in ConstantInputReshaped."); + } + return HostTensorToLiteral(temp, constant_literal); + } + + // Make sure we treat zero-element tensors as constant. + if (new_shape.num_elements() == 0) { + Tensor temp(tensor.dtype(), new_shape); + return HostTensorToLiteral(temp, constant_literal); + } + + xla::ComputationDataHandle handle = expression->handle(); + if (new_shape != tensor.shape()) { + // Reshape the handle to the desired shape. + handle = builder()->Reshape(handle, new_shape.dim_sizes()); + } + + // The XLA layout is specified minor to major, and TensorFlow's minor + // dimension is the last one. + std::vector<int64> layout_indices(new_shape.dims()); + std::iota(layout_indices.rbegin(), layout_indices.rend(), 0); + xla::Layout layout = xla::LayoutUtil::MakeLayout(layout_indices); + + // Ask the XLA compiler to evaluate the data handle to a literal. + xla::StatusOr<std::unique_ptr<xla::GlobalData>> computed = + builder()->ComputeConstant(handle, &layout); + if (!computed.ok()) { + return errors::InvalidArgument( + "Error evaluating ", context_->op_kernel().name(), " input ", index, + ": ", computed.status().error_message()); + } + // Fetch the literal from the compiler service. + xla::StatusOr<std::unique_ptr<xla::Literal>> constant = + builder()->client()->Transfer(*computed.ValueOrDie()); + if (!constant.ok()) { + return errors::InvalidArgument( + "Error evaluating ", context_->op_kernel().name(), " input ", index, + ": ", constant.status().error_message()); + } + constant_literal->Swap(constant.ValueOrDie().get()); + return Status::OK(); +} + +// Converts an int32 or int64 1D literal to an int64 vector. +static Status LiteralToInt64Vector(const xla::Literal& literal, + std::vector<int64>* out) { + if (xla::ShapeUtil::Rank(literal.shape()) != 1) { + return errors::InvalidArgument("value is not 1D"); + } + int64 size = xla::ShapeUtil::ElementsIn(literal.shape()); + if (literal.shape().element_type() == xla::S32) { + for (int64 i = 0; i < size; ++i) { + out->push_back(xla::LiteralUtil::Get<int32>(literal, {i})); + } + } else if (literal.shape().element_type() == xla::S64) { + for (int64 i = 0; i < size; ++i) { + out->push_back(xla::LiteralUtil::Get<int64>(literal, {i})); + } + } else { + return errors::InvalidArgument("value must be either int32 or int64"); + } + return Status::OK(); +} + +Status XlaOpKernelContext::ConstantInputAsIntVector(int index, + std::vector<int64>* out) { + xla::Literal literal; + TF_RETURN_IF_ERROR(ConstantInput(index, &literal)); + return LiteralToInt64Vector(literal, out); +} + +// TODO(phawkins): validate that the dimensions form a valid shape, fail +// gracefully if they do not. +Status XlaOpKernelContext::ConstantInputAsShape(int index, TensorShape* shape) { + xla::Literal literal; + TF_RETURN_IF_ERROR(ConstantInput(index, &literal)); + std::vector<int64> dims; + TF_RETURN_IF_ERROR(LiteralToInt64Vector(literal, &dims)); + *shape = TensorShape(dims); + return Status::OK(); +} + +Status XlaOpKernelContext::InputList( + StringPiece name, std::vector<xla::ComputationDataHandle>* handles, + std::vector<TensorShape>* shapes) { + OpInputList inputs; + TF_RETURN_IF_ERROR(context_->input_list(name, &inputs)); + handles->clear(); + shapes->clear(); + for (const Tensor& input : inputs) { + handles->push_back(XlaContext::GetComputationFromTensor(input)); + shapes->push_back(input.shape()); + } + return Status::OK(); +} + +Status XlaOpKernelContext::ConstantInputList( + StringPiece name, std::vector<xla::Literal>* outputs) { + int start, stop; + TF_RETURN_IF_ERROR(op_kernel().InputRange(name, &start, &stop)); + outputs->resize(stop - start); + for (int i = start; i < stop; ++i) { + TF_RETURN_IF_ERROR(ConstantInput(i, &(*outputs)[i])); + } + return Status::OK(); +} + +void XlaOpKernelContext::SetOutput(int index, + const xla::ComputationDataHandle& handle) { + // Makes the host Tensor that will refer to the expression. + Tensor* output = nullptr; + auto shape = builder()->GetShape(handle); + if (!shape.ok()) { + SetStatus(shape.status()); + return; + } + + // The step's default allocator is the dummy XlaCompilationAllocator which + // simply allocates a metadata buffer to hold the expression to which it + // corresponds. + OP_REQUIRES_OK( + context_, + context_->allocate_output( + index, XLAShapeToTensorShape(*shape.ValueOrDie()), &output)); + + // The expression is stored in the tensor's data buffer. Fill in the + // fields now. + XlaExpression* expression = + XlaContext::CastExpressionFromUninitializedTensor(output); + expression->set_handle(handle); +} + +void XlaOpKernelContext::SetConstantOutput(int index, const Tensor& constant) { + const TensorShape& shape = constant.shape(); + + xla::Literal literal; + OP_REQUIRES_OK(context_, HostTensorToLiteral(constant, &literal)); + xla::ComputationDataHandle handle = builder()->ConstantLiteral(literal); + + // Make the Tensor that will refer to the expression. + Tensor* output = nullptr; + // The step's default allocator is the dummy XlaCompilationAllocator which + // simply allocates a metadata buffer to hold the expression to which it + // corresponds. + OP_REQUIRES_OK(context_, context_->allocate_output(index, shape, &output)); + + // The expression is stored in the tensor's data buffer. Fill in the + // fields now. + XlaExpression* expression = + XlaContext::CastExpressionFromUninitializedTensor(output); + expression->set_handle(handle); + expression->set_constant_value(constant); +} + +void XlaOpKernelContext::CtxFailure(Status s) { context_->CtxFailure(s); } +void XlaOpKernelContext::CtxFailureWithWarning(Status s) { + context_->CtxFailureWithWarning(s); +} + +const xla::Computation* XlaOpKernelContext::GetOrCreateMax( + const DataType type) { + return XlaContext::Get(context_).GetOrCreateMax(type); +} + +const xla::Computation* XlaOpKernelContext::GetOrCreateAdd( + const DataType type) { + return XlaContext::Get(context_).GetOrCreateAdd(type); +} + +const xla::Computation* XlaOpKernelContext::GetOrCreateSigmoid( + const DataType type) { + return XlaContext::Get(context_).GetOrCreateSigmoid(type); +} + +XlaOpKernel::XlaOpKernel(OpKernelConstruction* context) : OpKernel(context) {} + +void XlaOpKernel::Compute(OpKernelContext* context) { + XlaOpKernelContext xla_context(context); + Compile(&xla_context); +} + +} // namespace tensorflow |