aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla/xla_op_kernel.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/tf2xla/xla_op_kernel.cc')
-rw-r--r--tensorflow/compiler/tf2xla/xla_op_kernel.cc97
1 files changed, 81 insertions, 16 deletions
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
index c2298b97e1..38ec559576 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
@@ -19,8 +19,11 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_computation.h"
+#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
namespace tensorflow {
@@ -64,10 +67,32 @@ const xla::XlaOp& XlaOpKernelContext::Input(int index) {
return GetComputationFromTensor(context_->input(index));
}
+const xla::XlaOp& XlaOpKernelContext::Input(StringPiece name) {
+ return GetComputationFromTensor(GetInputTensorByName(name));
+}
+
TensorShape XlaOpKernelContext::InputShape(int index) {
return context_->input(index).shape();
}
+TensorShape XlaOpKernelContext::InputShape(StringPiece name) {
+ return GetInputTensorByName(name).shape();
+}
+
+DataType XlaOpKernelContext::input_type(int index) const {
+ return context_->input(index).dtype();
+}
+
+xla::PrimitiveType XlaOpKernelContext::input_xla_type(int index) {
+ xla::PrimitiveType type;
+ Status status = DataTypeToPrimitiveType(input_type(index), &type);
+ if (!status.ok()) {
+ SetStatus(status);
+ return xla::PRIMITIVE_TYPE_INVALID;
+ }
+ return type;
+}
+
Status XlaOpKernelContext::ConstantInput(int index,
xla::Literal* constant_literal) {
return ConstantInputReshaped(
@@ -316,10 +341,11 @@ Status XlaOpKernelContext::ConstantInputList(
return Status::OK();
}
-Status XlaOpKernelContext::ReadVariableInput(int index, DataType type,
- TensorShape* shape,
- xla::XlaOp* value) {
- const Tensor& tensor = context_->input(index);
+namespace {
+
+Status ReadVariableInputTensor(const Tensor& tensor, DataType type,
+ const OpKernelContext* ctx, TensorShape* shape,
+ xla::XlaOp* value) {
const XlaExpression* expression = CastExpressionFromTensor(tensor);
XlaResource* variable = expression->resource();
TF_RET_CHECK(variable != nullptr);
@@ -337,9 +363,10 @@ Status XlaOpKernelContext::ReadVariableInput(int index, DataType type,
*shape = variable->shape();
}
- XlaContext& xla_context = XlaContext::Get(context_);
- TensorShape representation_shape =
- xla_context.RepresentationShape(variable->shape(), variable->type());
+ XlaContext& xla_context = XlaContext::Get(ctx);
+ TF_ASSIGN_OR_RETURN(
+ TensorShape representation_shape,
+ xla_context.RepresentationShape(variable->shape(), variable->type()));
if (representation_shape == variable->shape()) {
*value = variable->value();
} else {
@@ -348,6 +375,22 @@ Status XlaOpKernelContext::ReadVariableInput(int index, DataType type,
return Status::OK();
}
+} // namespace
+
+Status XlaOpKernelContext::ReadVariableInput(int index, DataType type,
+ TensorShape* shape,
+ xla::XlaOp* value) {
+ return ReadVariableInputTensor(context_->input(index), type, context_, shape,
+ value);
+}
+
+Status XlaOpKernelContext::ReadVariableInput(StringPiece name, DataType type,
+ TensorShape* shape,
+ xla::XlaOp* value) {
+ return ReadVariableInputTensor(GetInputTensorByName(name), type, context_,
+ shape, value);
+}
+
Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type,
TensorShape* shape) const {
const Tensor& tensor = context_->input(index);
@@ -438,17 +481,17 @@ Status XlaOpKernelContext::GetResourceInput(int index, XlaResource** resource) {
return Status::OK();
}
-Status XlaOpKernelContext::AssignVariable(int input_index, DataType type,
- xla::XlaOp handle) {
- TF_RET_CHECK(handle.valid());
+namespace {
- const XlaExpression* expression =
- CastExpressionFromTensor(context_->input(input_index));
+Status AssignVariableTensor(const Tensor& tensor, DataType type,
+ const OpKernelContext* ctx, xla::XlaOp handle,
+ xla::XlaBuilder* builder) {
+ const XlaExpression* expression = CastExpressionFromTensor(tensor);
XlaResource* variable = expression->resource();
TF_RET_CHECK(variable != nullptr);
TF_RET_CHECK(variable->kind() == XlaResource::kVariable);
- auto shape_or_status = builder()->GetShape(handle);
+ auto shape_or_status = builder->GetShape(handle);
if (!shape_or_status.ok()) {
return shape_or_status.status();
}
@@ -458,15 +501,31 @@ Status XlaOpKernelContext::AssignVariable(int input_index, DataType type,
TF_RETURN_IF_ERROR(variable->SetTypeAndShape(type, shape));
- XlaContext& xla_context = XlaContext::Get(context_);
- TensorShape representation_shape =
- xla_context.RepresentationShape(shape, type);
+ XlaContext& xla_context = XlaContext::Get(ctx);
+ TF_ASSIGN_OR_RETURN(TensorShape representation_shape,
+ xla_context.RepresentationShape(shape, type));
if (shape != representation_shape) {
handle = xla::Reshape(handle, representation_shape.dim_sizes());
}
return variable->SetValue(handle);
}
+} // namespace
+
+Status XlaOpKernelContext::AssignVariable(int input_index, DataType type,
+ xla::XlaOp handle) {
+ TF_RET_CHECK(handle.valid());
+ return AssignVariableTensor(context_->input(input_index), type, context_,
+ handle, builder());
+}
+
+Status XlaOpKernelContext::AssignVariable(StringPiece name, DataType type,
+ xla::XlaOp handle) {
+ TF_RET_CHECK(handle.valid());
+ return AssignVariableTensor(GetInputTensorByName(name), type, context_,
+ handle, builder());
+}
+
XlaCompiler* XlaOpKernelContext::compiler() const {
return XlaContext::Get(context_).compiler();
}
@@ -506,6 +565,12 @@ const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMul(
return XlaContext::Get(context_).GetOrCreateMul(type);
}
+const Tensor& XlaOpKernelContext::GetInputTensorByName(StringPiece name) {
+ const Tensor* tensor;
+ CHECK(context_->input(name, &tensor).ok());
+ return *tensor;
+}
+
XlaOpKernel::XlaOpKernel(OpKernelConstruction* context) : OpKernel(context) {}
void XlaOpKernel::Compute(OpKernelContext* context) {