aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla/xla_op_kernel.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/tf2xla/xla_op_kernel.cc')
-rw-r--r--tensorflow/compiler/tf2xla/xla_op_kernel.cc24
1 files changed, 12 insertions, 12 deletions
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
index 636cb71e21..2a9eaeee14 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
@@ -83,6 +83,10 @@ DataType XlaOpKernelContext::input_type(int index) const {
return context_->input(index).dtype();
}
+DataType XlaOpKernelContext::InputType(absl::string_view name) {
+ return GetInputTensorByName(name).dtype();
+}
+
xla::PrimitiveType XlaOpKernelContext::input_xla_type(int index) {
xla::PrimitiveType type;
Status status = DataTypeToPrimitiveType(input_type(index), &type);
@@ -102,8 +106,7 @@ Status XlaOpKernelContext::ConstantInput(int index,
static xla::StatusOr<int> InputIndex(XlaOpKernelContext* context,
absl::string_view name) {
int start, stop;
- TF_RETURN_IF_ERROR(context->op_kernel().InputRange(
- StringPiece(name.data(), name.length()), &start, &stop));
+ TF_RETURN_IF_ERROR(context->op_kernel().InputRange(name, &start, &stop));
if (stop != start + 1) {
return errors::InvalidArgument("OpKernel used list-valued input name '",
name,
@@ -214,16 +217,15 @@ Status XlaOpKernelContext::ConstantInputReshaped(
context_->op_kernel().name(), " input ", index,
".\nError: ", constant_graph.status().error_message());
}
- xla::StatusOr<std::unique_ptr<xla::Literal>> computed =
- compiler()->client()->ComputeConstant(constant_graph.ValueOrDie(),
- &layout);
+ xla::StatusOr<xla::Literal> computed = compiler()->client()->ComputeConstant(
+ constant_graph.ValueOrDie(), &layout);
if (!computed.ok()) {
return errors::Internal("Error evaluating ", context_->op_kernel().name(),
" input ", index,
- "as a compile-time constant.\nError: ",
+ " as a compile-time constant.\nError: ",
computed.status().error_message());
}
- *constant_literal = std::move(*computed.ValueOrDie());
+ *constant_literal = std::move(computed).ValueOrDie();
return Status::OK();
}
@@ -366,8 +368,7 @@ Status XlaOpKernelContext::InputList(absl::string_view name,
std::vector<xla::XlaOp>* handles,
std::vector<TensorShape>* shapes) {
OpInputList inputs;
- TF_RETURN_IF_ERROR(
- context_->input_list(StringPiece(name.data(), name.size()), &inputs));
+ TF_RETURN_IF_ERROR(context_->input_list(name, &inputs));
handles->clear();
shapes->clear();
for (const Tensor& input : inputs) {
@@ -380,8 +381,7 @@ Status XlaOpKernelContext::InputList(absl::string_view name,
Status XlaOpKernelContext::ConstantInputList(
absl::string_view name, std::vector<xla::Literal>* outputs) {
int start, stop;
- TF_RETURN_IF_ERROR(op_kernel().InputRange(
- StringPiece(name.data(), name.size()), &start, &stop));
+ TF_RETURN_IF_ERROR(op_kernel().InputRange(name, &start, &stop));
outputs->resize(stop - start);
for (int i = start; i < stop; ++i) {
TF_RETURN_IF_ERROR(ConstantInput(i, &(*outputs)[i]));
@@ -615,7 +615,7 @@ const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMul(
const Tensor& XlaOpKernelContext::GetInputTensorByName(absl::string_view name) {
const Tensor* tensor;
- CHECK(context_->input(StringPiece(name.data(), name.length()), &tensor).ok());
+ CHECK(context_->input(name, &tensor).ok());
return *tensor;
}