diff options
Diffstat (limited to 'tensorflow/compiler/tf2xla/xla_compiler.cc')
-rw-r--r-- | tensorflow/compiler/tf2xla/xla_compiler.cc | 72 |
1 files changed, 37 insertions, 35 deletions
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 0c98c20805..678e209cf6 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/executor.h" #include "tensorflow/core/common_runtime/function.h" @@ -231,10 +232,13 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg, case XlaCompiler::Argument::kConstant: LOG(FATAL) << "Unreachable case"; case XlaCompiler::Argument::kParameter: { - TensorShape shape = - is_entry_computation - ? options_.shape_representation_fn(arg.shape, arg.type) - : arg.shape; + TensorShape shape; + if (is_entry_computation) { + TF_ASSIGN_OR_RETURN( + shape, options_.shape_representation_fn(arg.shape, arg.type)); + } else { + shape = arg.shape; + } return TensorShapeToXLAShape(arg.type, shape, xla_shape); } case XlaCompiler::Argument::kResource: { @@ -242,8 +246,9 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg, switch (arg.resource_kind) { case XlaResource::kVariable: { - TensorShape representation_shape = - options_.shape_representation_fn(arg.shape, arg.type); + TF_ASSIGN_OR_RETURN( + TensorShape representation_shape, + options_.shape_representation_fn(arg.shape, arg.type)); return TensorShapeToXLAShape(arg.type, representation_shape, xla_shape); } @@ -418,16 +423,18 @@ Status BuildComputation( // assignment will be placed on this value, which will cause the resource // update to be returned from the same device that provided the resource. handle = xla::GetTupleElement(xla::Tuple(builder, {handle}), 0); - elems.push_back(handle); } } *num_computation_outputs = elems.size(); - // Builds the XLA computation. - if (always_return_tuple || elems.size() != 1) { - xla::Tuple(builder, elems); + // Builds the XLA computation. We *always* form a tuple here to ensure that + // the output value is the last thing added into the XLA computation, even + // if there is only one output value. + auto tuple = xla::Tuple(builder, elems); + if (!always_return_tuple && elems.size() == 1) { + xla::GetTupleElement(tuple, 0); } builder->ClearOpMetadata(); @@ -664,20 +671,17 @@ Status XlaCompiler::CompileSingleOp( namespace { // Check that the ops of all non-functional nodes have been registered. -string ValidateFunctionDef(const FunctionDef* fdef, +Status ValidateFunctionDef(const FunctionDef* fdef, const FunctionLibraryDefinition& flib_def) { - std::vector<string> invalid_ops; for (const NodeDef& node : fdef->node_def()) { const string& op = node.op(); if (op == FunctionLibraryDefinition::kGradientOp || flib_def.Find(op)) { continue; } const OpDef* op_def; - if (!OpRegistry::Global()->LookUpOpDef(op, &op_def).ok()) { - invalid_ops.push_back(op); - } + TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(op, &op_def)); } - return tensorflow::str_util::Join(invalid_ops, ", "); + return Status::OK(); } // Check that the graph doesn't have any invalid nodes (e.g. incompatible with @@ -685,35 +689,33 @@ string ValidateFunctionDef(const FunctionDef* fdef, Status ValidateGraph(const Graph* graph, const FunctionLibraryDefinition& flib_def, const DeviceType& device_type, const string& name) { - std::set<string> invalid_ops; + auto maybe_error = [&](const string& op, const Status& s) -> Status { + if (!s.ok()) { + return errors::InvalidArgument(strings::StrCat( + "Detected unsupported operations when trying to compile graph ", name, + " on ", device_type.type_string(), ": ", op, " (", s.error_message(), + ")")); + } + return Status::OK(); + }; + for (const Node* node : graph->nodes()) { if (node->type_string() == FunctionLibraryDefinition::kGradientOp) { continue; } const FunctionDef* fdef = flib_def.Find(node->def().op()); + Status s; if (fdef) { - string error_msg = ValidateFunctionDef(fdef, flib_def); - if (!error_msg.empty()) { - invalid_ops.insert( - strings::StrCat(node->def().op(), ":{", error_msg, "}")); - } + s = ValidateFunctionDef(fdef, flib_def); + TF_RETURN_IF_ERROR(maybe_error(node->def().op(), s)); continue; } const OpDef* op_def; - if (!OpRegistry::Global()->LookUpOpDef(node->def().op(), &op_def).ok()) { - invalid_ops.insert(node->def().op()); - continue; - } + s = OpRegistry::Global()->LookUpOpDef(node->def().op(), &op_def); + TF_RETURN_IF_ERROR(maybe_error(node->def().op(), s)); TF_RETURN_IF_ERROR(ValidateNodeDef(node->def(), *op_def)); - if (!FindKernelDef(device_type, node->def(), nullptr, nullptr).ok()) { - invalid_ops.insert(node->def().op()); - } - } - if (!invalid_ops.empty()) { - return errors::InvalidArgument(strings::StrCat( - "Detected unsupported operations when trying to compile graph ", name, - " on ", device_type.type_string(), ":", - tensorflow::str_util::Join(invalid_ops, ", "))); + s = FindKernelDef(device_type, node->def(), nullptr, nullptr); + TF_RETURN_IF_ERROR(maybe_error(node->def().op(), s)); } return Status::OK(); } |