Diffstat (limited to 'tensorflow/compiler/tf2xla/xla_compiler.cc')
 -rw-r--r--  tensorflow/compiler/tf2xla/xla_compiler.cc  |  72
 1 file changed, 37 insertions(+), 35 deletions(-)
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
index 0c98c20805..678e209cf6 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/function.h"
@@ -231,10 +232,13 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg,
case XlaCompiler::Argument::kConstant:
LOG(FATAL) << "Unreachable case";
case XlaCompiler::Argument::kParameter: {
- TensorShape shape =
- is_entry_computation
- ? options_.shape_representation_fn(arg.shape, arg.type)
- : arg.shape;
+ TensorShape shape;
+ if (is_entry_computation) {
+ TF_ASSIGN_OR_RETURN(
+ shape, options_.shape_representation_fn(arg.shape, arg.type));
+ } else {
+ shape = arg.shape;
+ }
return TensorShapeToXLAShape(arg.type, shape, xla_shape);
}
case XlaCompiler::Argument::kResource: {
@@ -242,8 +246,9 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg,
switch (arg.resource_kind) {
case XlaResource::kVariable: {
- TensorShape representation_shape =
- options_.shape_representation_fn(arg.shape, arg.type);
+ TF_ASSIGN_OR_RETURN(
+ TensorShape representation_shape,
+ options_.shape_representation_fn(arg.shape, arg.type));
return TensorShapeToXLAShape(arg.type, representation_shape,
xla_shape);
}
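
Both hunks above follow from the same interface change: options_.shape_representation_fn evidently now returns a StatusOr-style wrapper around TensorShape rather than a bare TensorShape, so its result is unwrapped with TF_ASSIGN_OR_RETURN instead of being assigned directly, and any failure propagates out of XLAShapeForArgument as a Status. Below is a minimal, self-contained sketch of that pattern; Status, StatusOr, ASSIGN_OR_RETURN, ShapeRepresentationFn and BuildArgumentShape are simplified, hypothetical stand-ins for the TensorFlow/XLA types and macros, not the real implementations.

    // Hypothetical sketch of the StatusOr / ASSIGN_OR_RETURN pattern the diff adopts.
    #include <cstdint>
    #include <iostream>
    #include <optional>
    #include <string>
    #include <utility>
    #include <vector>

    struct Status {
      std::string message;  // empty message means OK
      bool ok() const { return message.empty(); }
    };

    template <typename T>
    struct StatusOr {  // holds either a value or an error Status
      std::optional<T> value;
      Status status;
      StatusOr(T v) : value(std::move(v)) {}
      StatusOr(Status s) : status(std::move(s)) {}
      bool ok() const { return status.ok(); }
    };

    // Simplified ASSIGN_OR_RETURN: on error, return the Status from the caller;
    // otherwise move the wrapped value into `lhs`.
    #define ASSIGN_OR_RETURN(lhs, expr)           \
      do {                                        \
        auto _result = (expr);                    \
        if (!_result.ok()) return _result.status; \
        lhs = std::move(*_result.value);          \
      } while (0)

    using TensorShape = std::vector<int64_t>;  // stand-in for tensorflow::TensorShape

    // Stand-in for options_.shape_representation_fn after the change: it may now
    // report failure instead of having to produce some shape unconditionally.
    StatusOr<TensorShape> ShapeRepresentationFn(const TensorShape& shape) {
      if (shape.empty()) return Status{"scalar arguments not handled in this sketch"};
      return shape;
    }

    Status BuildArgumentShape(const TensorShape& arg_shape, bool is_entry_computation) {
      TensorShape shape;
      if (is_entry_computation) {
        ASSIGN_OR_RETURN(shape, ShapeRepresentationFn(arg_shape));
      } else {
        shape = arg_shape;
      }
      std::cout << "using a rank-" << shape.size() << " shape\n";
      return Status{};
    }

    int main() {
      std::cout << BuildArgumentShape({2, 3}, true).ok() << "\n";  // prints 1
      std::cout << BuildArgumentShape({}, true).ok() << "\n";      // prints 0
    }
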
@@ -418,16 +423,18 @@ Status BuildComputation(
// assignment will be placed on this value, which will cause the resource
// update to be returned from the same device that provided the resource.
handle = xla::GetTupleElement(xla::Tuple(builder, {handle}), 0);
-
elems.push_back(handle);
}
}
*num_computation_outputs = elems.size();
- // Builds the XLA computation.
- if (always_return_tuple || elems.size() != 1) {
- xla::Tuple(builder, elems);
+ // Builds the XLA computation. We *always* form a tuple here to ensure that
+ // the output value is the last thing added into the XLA computation, even
+ // if there is only one output value.
+ auto tuple = xla::Tuple(builder, elems);
+ if (!always_return_tuple && elems.size() == 1) {
+ xla::GetTupleElement(tuple, 0);
}
builder->ClearOpMetadata();
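
The rewritten tail of BuildComputation forms the tuple unconditionally so that the output value (the tuple, or the element peeled off it) is always the last instruction added to the builder; the single-output case then extracts element 0 afterwards instead of skipping the tuple entirely. The self-contained sketch below models only that ordering with a hypothetical FakeBuilder that records op names in insertion order; it is not the xla::XlaBuilder API.

    // Hypothetical sketch: the output is always the last op added.
    #include <iostream>
    #include <string>
    #include <vector>

    struct FakeBuilder {
      std::vector<std::string> ops;  // instruction names in insertion order
      void Add(const std::string& name) { ops.push_back(name); }
    };

    void BuildTail(FakeBuilder* b, int num_outputs, bool always_return_tuple) {
      // Mirror of the patched logic: add every output, always form the tuple,
      // and only afterwards extract element 0 when a single untupled result
      // is wanted.
      for (int i = 0; i < num_outputs; ++i) b->Add("output");
      b->Add("tuple");
      if (!always_return_tuple && num_outputs == 1) b->Add("get-tuple-element");
      std::cout << "last op added: " << b->ops.back() << "\n";
    }

    int main() {
      FakeBuilder single, multi;
      BuildTail(&single, 1, /*always_return_tuple=*/false);  // get-tuple-element
      BuildTail(&multi, 2, /*always_return_tuple=*/false);   // tuple
    }
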
@@ -664,20 +671,17 @@ Status XlaCompiler::CompileSingleOp(
namespace {
// Check that the ops of all non-functional nodes have been registered.
-string ValidateFunctionDef(const FunctionDef* fdef,
+Status ValidateFunctionDef(const FunctionDef* fdef,
const FunctionLibraryDefinition& flib_def) {
- std::vector<string> invalid_ops;
for (const NodeDef& node : fdef->node_def()) {
const string& op = node.op();
if (op == FunctionLibraryDefinition::kGradientOp || flib_def.Find(op)) {
continue;
}
const OpDef* op_def;
- if (!OpRegistry::Global()->LookUpOpDef(op, &op_def).ok()) {
- invalid_ops.push_back(op);
- }
+ TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(op, &op_def));
}
- return tensorflow::str_util::Join(invalid_ops, ", ");
+ return Status::OK();
}
// Check that the graph doesn't have any invalid nodes (e.g. incompatible with
@@ -685,35 +689,33 @@ string ValidateFunctionDef(const FunctionDef* fdef,
Status ValidateGraph(const Graph* graph,
const FunctionLibraryDefinition& flib_def,
const DeviceType& device_type, const string& name) {
- std::set<string> invalid_ops;
+ auto maybe_error = [&](const string& op, const Status& s) -> Status {
+ if (!s.ok()) {
+ return errors::InvalidArgument(strings::StrCat(
+ "Detected unsupported operations when trying to compile graph ", name,
+ " on ", device_type.type_string(), ": ", op, " (", s.error_message(),
+ ")"));
+ }
+ return Status::OK();
+ };
+
for (const Node* node : graph->nodes()) {
if (node->type_string() == FunctionLibraryDefinition::kGradientOp) {
continue;
}
const FunctionDef* fdef = flib_def.Find(node->def().op());
+ Status s;
if (fdef) {
- string error_msg = ValidateFunctionDef(fdef, flib_def);
- if (!error_msg.empty()) {
- invalid_ops.insert(
- strings::StrCat(node->def().op(), ":{", error_msg, "}"));
- }
+ s = ValidateFunctionDef(fdef, flib_def);
+ TF_RETURN_IF_ERROR(maybe_error(node->def().op(), s));
continue;
}
const OpDef* op_def;
- if (!OpRegistry::Global()->LookUpOpDef(node->def().op(), &op_def).ok()) {
- invalid_ops.insert(node->def().op());
- continue;
- }
+ s = OpRegistry::Global()->LookUpOpDef(node->def().op(), &op_def);
+ TF_RETURN_IF_ERROR(maybe_error(node->def().op(), s));
TF_RETURN_IF_ERROR(ValidateNodeDef(node->def(), *op_def));
- if (!FindKernelDef(device_type, node->def(), nullptr, nullptr).ok()) {
- invalid_ops.insert(node->def().op());
- }
- }
- if (!invalid_ops.empty()) {
- return errors::InvalidArgument(strings::StrCat(
- "Detected unsupported operations when trying to compile graph ", name,
- " on ", device_type.type_string(), ":",
- tensorflow::str_util::Join(invalid_ops, ", ")));
+ s = FindKernelDef(device_type, node->def(), nullptr, nullptr);
+ TF_RETURN_IF_ERROR(maybe_error(node->def().op(), s));
}
return Status::OK();
}
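
The last two hunks replace the accumulate-invalid-ops-into-a-string approach with direct Status propagation: ValidateFunctionDef now returns the first lookup failure via TF_RETURN_IF_ERROR, and ValidateGraph wraps each failing Status with the graph name, device type and offending op through the local maybe_error lambda before returning early. A small self-contained sketch of that wrap-and-return pattern follows; Status, RETURN_IF_ERROR, LookUpOp and ValidateOps are simplified, hypothetical stand-ins rather than the TensorFlow definitions.

    // Hypothetical sketch of "wrap a failing Status with local context, then
    // return early", mirroring the maybe_error lambda above.
    #include <iostream>
    #include <set>
    #include <string>
    #include <utility>

    struct Status {
      std::string message;  // empty message means OK
      bool ok() const { return message.empty(); }
      static Status OK() { return {}; }
      static Status InvalidArgument(std::string m) { return {std::move(m)}; }
    };

    // Simplified RETURN_IF_ERROR: propagate the first failure to the caller.
    #define RETURN_IF_ERROR(expr) \
      do {                        \
        Status _s = (expr);       \
        if (!_s.ok()) return _s;  \
      } while (0)

    // Pretend op registry lookup; only these op names are "registered".
    Status LookUpOp(const std::string& op) {
      static const std::set<std::string> registered = {"Add", "MatMul", "Relu"};
      return registered.count(op) ? Status::OK()
                                  : Status::InvalidArgument("op not registered");
    }

    Status ValidateOps(const std::set<std::string>& ops, const std::string& name,
                       const std::string& device_type) {
      // Same shape as the maybe_error lambda in the diff: decorate a failing
      // Status with the graph name, device type and the op that triggered it.
      auto maybe_error = [&](const std::string& op, const Status& s) -> Status {
        if (!s.ok()) {
          return Status::InvalidArgument(
              "Detected unsupported operations when trying to compile graph " +
              name + " on " + device_type + ": " + op + " (" + s.message + ")");
        }
        return Status::OK();
      };
      for (const std::string& op : ops) {
        RETURN_IF_ERROR(maybe_error(op, LookUpOp(op)));  // stop at the first bad op
      }
      return Status::OK();
    }

    int main() {
      std::cout << ValidateOps({"Add", "MatMul"}, "g", "XLA_CPU").ok() << "\n";
      std::cout << ValidateOps({"Add", "Unknown"}, "g", "XLA_CPU").message << "\n";
    }
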