aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla/xla_compiler.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/tf2xla/xla_compiler.cc')
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiler.cc65
1 files changed, 60 insertions, 5 deletions
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
index f7098917b1..ccbc74eb31 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -83,12 +83,9 @@ XlaCompiler::XlaCompiler(XlaCompiler::Options options)
: options_(options),
initialization_status_(Status::OK()),
next_step_id_(1),
- device_(
- new XlaCompilationDevice(SessionOptions(), *options_.device_type)),
+ device_(new XlaCompilationDevice(SessionOptions(), options_.device_type)),
device_mgr_({device_}) {
- // We no longer need the device_type.
- options_.device_type = nullptr;
-
+ CHECK(!options_.device_type.type_string().empty());
if (options_.populate_resource_manager) {
initialization_status_ =
(*options_.populate_resource_manager)(device_->resource_manager());
@@ -659,6 +656,59 @@ Status XlaCompiler::CompileSingleOp(
return CompileGraph(options, name, std::move(graph), args, result);
}
+namespace {
+
+// Check that the ops of all non-functional nodes have been registered.
+string ValidateFunctionDef(const FunctionDef* fdef,
+ const FunctionLibraryDefinition& flib_def) {
+ std::vector<string> invalid_ops;
+ for (const NodeDef& node : fdef->node_def()) {
+ const string& op = node.op();
+ if (op == FunctionLibraryDefinition::kGradientOp || flib_def.Find(op)) {
+ continue;
+ }
+ const OpDef* op_def;
+ if (!OpRegistry::Global()->LookUpOpDef(op, &op_def).ok()) {
+ invalid_ops.push_back(op);
+ }
+ }
+ return tensorflow::str_util::Join(invalid_ops, ", ");
+}
+
+// Check that the graph doesn't have any nodes incompatible with given
+// device_type.
+Status ValidateGraph(const Graph* graph,
+ const FunctionLibraryDefinition& flib_def,
+ const DeviceType& device_type, const string& name) {
+ std::vector<string> invalid_ops;
+ for (const Node* node : graph->nodes()) {
+ if (node->type_string() == FunctionLibraryDefinition::kGradientOp) {
+ continue;
+ }
+ const FunctionDef* fdef = flib_def.Find(node->def().op());
+ if (fdef) {
+ string error_msg = ValidateFunctionDef(fdef, flib_def);
+ if (!error_msg.empty()) {
+ invalid_ops.push_back(
+ strings::StrCat(node->def().op(), ":{", error_msg, "}"));
+ }
+ continue;
+ }
+ if (!FindKernelDef(device_type, node->def(), nullptr, nullptr).ok()) {
+ invalid_ops.push_back(node->def().op());
+ }
+ }
+ if (!invalid_ops.empty()) {
+ return errors::InvalidArgument(strings::StrCat(
+ "Detected unsupported operations when trying to compile graph ", name,
+ " on ", device_type.type_string(), ":",
+ tensorflow::str_util::Join(invalid_ops, ", ")));
+ }
+ return Status::OK();
+}
+
+} // namespace
+
Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
string const& name,
std::unique_ptr<Graph> graph,
@@ -681,6 +731,11 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
FunctionalizeControlFlow(flib_runtime_->GetFunctionLibraryDefinition(),
graph.get(), local_flib_def_.get()));
+ // Detect ops incompatible with the device_type.
+ // FunctionalizeControlFlow may remove some unsupported ops.
+ TF_RETURN_IF_ERROR(ValidateGraph(graph.get(), *options_.flib_def,
+ options_.device_type, name));
+
xla::XlaBuilder builder(name);
XlaContext* context = new XlaContext(
this, &builder, options_.allow_cpu_custom_calls,