diff options
Diffstat (limited to 'tensorflow/compiler/tf2xla/xla_op_kernel.cc')
-rw-r--r-- | tensorflow/compiler/tf2xla/xla_op_kernel.cc | 24 |
1 files changed, 12 insertions, 12 deletions
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index 636cb71e21..2a9eaeee14 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -83,6 +83,10 @@ DataType XlaOpKernelContext::input_type(int index) const { return context_->input(index).dtype(); } +DataType XlaOpKernelContext::InputType(absl::string_view name) { + return GetInputTensorByName(name).dtype(); +} + xla::PrimitiveType XlaOpKernelContext::input_xla_type(int index) { xla::PrimitiveType type; Status status = DataTypeToPrimitiveType(input_type(index), &type); @@ -102,8 +106,7 @@ Status XlaOpKernelContext::ConstantInput(int index, static xla::StatusOr<int> InputIndex(XlaOpKernelContext* context, absl::string_view name) { int start, stop; - TF_RETURN_IF_ERROR(context->op_kernel().InputRange( - StringPiece(name.data(), name.length()), &start, &stop)); + TF_RETURN_IF_ERROR(context->op_kernel().InputRange(name, &start, &stop)); if (stop != start + 1) { return errors::InvalidArgument("OpKernel used list-valued input name '", name, @@ -214,16 +217,15 @@ Status XlaOpKernelContext::ConstantInputReshaped( context_->op_kernel().name(), " input ", index, ".\nError: ", constant_graph.status().error_message()); } - xla::StatusOr<std::unique_ptr<xla::Literal>> computed = - compiler()->client()->ComputeConstant(constant_graph.ValueOrDie(), - &layout); + xla::StatusOr<xla::Literal> computed = compiler()->client()->ComputeConstant( + constant_graph.ValueOrDie(), &layout); if (!computed.ok()) { return errors::Internal("Error evaluating ", context_->op_kernel().name(), " input ", index, - "as a compile-time constant.\nError: ", + " as a compile-time constant.\nError: ", computed.status().error_message()); } - *constant_literal = std::move(*computed.ValueOrDie()); + *constant_literal = std::move(computed).ValueOrDie(); return Status::OK(); } @@ -366,8 +368,7 @@ Status XlaOpKernelContext::InputList(absl::string_view name, std::vector<xla::XlaOp>* handles, std::vector<TensorShape>* shapes) { OpInputList inputs; - TF_RETURN_IF_ERROR( - context_->input_list(StringPiece(name.data(), name.size()), &inputs)); + TF_RETURN_IF_ERROR(context_->input_list(name, &inputs)); handles->clear(); shapes->clear(); for (const Tensor& input : inputs) { @@ -380,8 +381,7 @@ Status XlaOpKernelContext::InputList(absl::string_view name, Status XlaOpKernelContext::ConstantInputList( absl::string_view name, std::vector<xla::Literal>* outputs) { int start, stop; - TF_RETURN_IF_ERROR(op_kernel().InputRange( - StringPiece(name.data(), name.size()), &start, &stop)); + TF_RETURN_IF_ERROR(op_kernel().InputRange(name, &start, &stop)); outputs->resize(stop - start); for (int i = start; i < stop; ++i) { TF_RETURN_IF_ERROR(ConstantInput(i, &(*outputs)[i])); @@ -615,7 +615,7 @@ const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMul( const Tensor& XlaOpKernelContext::GetInputTensorByName(absl::string_view name) { const Tensor* tensor; - CHECK(context_->input(StringPiece(name.data(), name.length()), &tensor).ok()); + CHECK(context_->input(name, &tensor).ok()); return *tensor; } |