author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-05-29 10:54:51 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2018-05-29 10:58:50 -0700
commit    920ede367cc07a126820059ec165525687291bea (patch)
tree      597353366e1b1d8b2ccef6a884030db1b440f335
parent    ec9cd6f150fd74b0727c6418c69d225702631ce2 (diff)
[TF/XLA] Add validation to find ops incompatible with the given device type at the beginning of graph compilation.
PiperOrigin-RevId: 198421828
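
The core API change: XlaCompiler::Options::device_type becomes a plain DeviceType value instead of a const DeviceType* pointer, so callers no longer need to keep a DeviceType object alive across the XlaCompiler constructor. A minimal caller-side sketch (not part of this patch; MakeOptions is a hypothetical helper):

    #include "tensorflow/compiler/tf2xla/xla_compiler.h"
    #include "tensorflow/compiler/tf2xla/xla_op_registry.h"

    tensorflow::XlaCompiler::Options MakeOptions(xla::Client* client) {
      tensorflow::XlaCompiler::Options options;
      // Before this change, a local DeviceType had to outlive the constructor:
      //   DeviceType device_type(DEVICE_CPU_XLA_JIT);
      //   options.device_type = &device_type;
      // Now the value is simply copied into the struct:
      options.device_type = tensorflow::DeviceType(tensorflow::DEVICE_CPU_XLA_JIT);
      options.client = client;
      return options;  // safe: nothing in options points at a local
    }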
-rw-r--r--  tensorflow/compiler/jit/kernels/xla_launch_op.cc      2
-rw-r--r--  tensorflow/compiler/jit/xla_compile_on_demand_op.cc   3
-rw-r--r--  tensorflow/compiler/tf2xla/tf2xla.cc                  3
-rw-r--r--  tensorflow/compiler/tf2xla/xla_compiler.cc           65
-rw-r--r--  tensorflow/compiler/tf2xla/xla_compiler.h             7
-rw-r--r--  tensorflow/compiler/tf2xla/xla_compiler_test.cc      54
6 files changed, 117 insertions(+), 17 deletions(-)
diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
index 27287e0f96..902fe27acd 100644
--- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc
+++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
@@ -148,7 +148,7 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
XlaCompiler::Options options;
options.client = client;
- options.device_type = &cache->device_type();
+ options.device_type = cache->device_type();
options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
options.graph_def_version = ctx->function_library()->graph_def_version();
options.allow_cpu_custom_calls = (platform_id_ == se::host::kHostPlatformId);
diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
index ab644ff5a6..b1943d3e1a 100644
--- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
+++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
@@ -151,8 +151,7 @@ Status XlaCompileOnDemandOp::Compile(
core::ScopedUnref cache_ref(cache);
XlaCompiler::Options options;
- DeviceType device_type = metadata.jit_device_type();
- options.device_type = &device_type;
+ options.device_type = metadata.jit_device_type();
options.client = metadata.client();
options.flib_def =
new FunctionLibraryDefinition(OpRegistry::Global(), FunctionDefLibrary{});
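
The deleted lines above show why the pointer field was fragile: a stack-local DeviceType existed only so its address could be stored in options, which is safe solely because the pointer happens to be consumed before the local goes out of scope. An illustrative sketch of the general hazard (hypothetical code, not from TensorFlow):

    #include "tensorflow/core/framework/types.h"

    struct OptionsByPointer {
      const tensorflow::DeviceType* device_type = nullptr;
    };

    OptionsByPointer MakeOptionsByPointer() {
      tensorflow::DeviceType device_type("XLA_CPU_JIT");
      OptionsByPointer options;
      options.device_type = &device_type;
      return options;  // dangling: device_type dies with this stack frame
    }

Holding DeviceType by value makes that mistake unrepresentable.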
diff --git a/tensorflow/compiler/tf2xla/tf2xla.cc b/tensorflow/compiler/tf2xla/tf2xla.cc
index 3a08aa8cf4..ac768b206e 100644
--- a/tensorflow/compiler/tf2xla/tf2xla.cc
+++ b/tensorflow/compiler/tf2xla/tf2xla.cc
@@ -263,8 +263,7 @@ Status ConvertGraphToXla(std::unique_ptr<Graph> graph, xla::Client* client,
// Compile the graph into an XLA computation.
XlaCompiler::Options compiler_options;
compiler_options.client = client;
- DeviceType device_type(DEVICE_CPU_XLA_JIT);
- compiler_options.device_type = &device_type;
+ compiler_options.device_type = DeviceType(DEVICE_CPU_XLA_JIT);
compiler_options.flib_def = &graph->flib_def();
compiler_options.graph_def_version = graph->versions().producer();
compiler_options.allow_cpu_custom_calls = true;
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
index f7098917b1..ccbc74eb31 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -83,12 +83,9 @@ XlaCompiler::XlaCompiler(XlaCompiler::Options options)
: options_(options),
initialization_status_(Status::OK()),
next_step_id_(1),
- device_(
- new XlaCompilationDevice(SessionOptions(), *options_.device_type)),
+ device_(new XlaCompilationDevice(SessionOptions(), options_.device_type)),
device_mgr_({device_}) {
- // We no longer need the device_type.
- options_.device_type = nullptr;
-
+ CHECK(!options_.device_type.type_string().empty());
if (options_.populate_resource_manager) {
initialization_status_ =
(*options_.populate_resource_manager)(device_->resource_manager());
@@ -659,6 +656,59 @@ Status XlaCompiler::CompileSingleOp(
return CompileGraph(options, name, std::move(graph), args, result);
}
+namespace {
+
+// Check that the ops of all nodes that are not function calls are registered.
+string ValidateFunctionDef(const FunctionDef* fdef,
+ const FunctionLibraryDefinition& flib_def) {
+ std::vector<string> invalid_ops;
+ for (const NodeDef& node : fdef->node_def()) {
+ const string& op = node.op();
+ if (op == FunctionLibraryDefinition::kGradientOp || flib_def.Find(op)) {
+ continue;
+ }
+ const OpDef* op_def;
+ if (!OpRegistry::Global()->LookUpOpDef(op, &op_def).ok()) {
+ invalid_ops.push_back(op);
+ }
+ }
+ return tensorflow::str_util::Join(invalid_ops, ", ");
+}
+
+// Check that the graph doesn't have any nodes incompatible with the given
+// device_type.
+Status ValidateGraph(const Graph* graph,
+ const FunctionLibraryDefinition& flib_def,
+ const DeviceType& device_type, const string& name) {
+ std::vector<string> invalid_ops;
+ for (const Node* node : graph->nodes()) {
+ if (node->type_string() == FunctionLibraryDefinition::kGradientOp) {
+ continue;
+ }
+ const FunctionDef* fdef = flib_def.Find(node->def().op());
+ if (fdef) {
+ string error_msg = ValidateFunctionDef(fdef, flib_def);
+ if (!error_msg.empty()) {
+ invalid_ops.push_back(
+ strings::StrCat(node->def().op(), ":{", error_msg, "}"));
+ }
+ continue;
+ }
+ if (!FindKernelDef(device_type, node->def(), nullptr, nullptr).ok()) {
+ invalid_ops.push_back(node->def().op());
+ }
+ }
+ if (!invalid_ops.empty()) {
+ return errors::InvalidArgument(strings::StrCat(
+ "Detected unsupported operations when trying to compile graph ", name,
+ " on ", device_type.type_string(), ":",
+ tensorflow::str_util::Join(invalid_ops, ", ")));
+ }
+ return Status::OK();
+}
+
+} // namespace
+
Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
string const& name,
std::unique_ptr<Graph> graph,
@@ -681,6 +731,11 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
FunctionalizeControlFlow(flib_runtime_->GetFunctionLibraryDefinition(),
graph.get(), local_flib_def_.get()));
+ // Detect ops incompatible with the device_type.
+ // FunctionalizeControlFlow may remove some unsupported ops.
+ TF_RETURN_IF_ERROR(ValidateGraph(graph.get(), *options_.flib_def,
+ options_.device_type, name));
+
xla::XlaBuilder builder(name);
XlaContext* context = new XlaContext(
this, &builder, options_.allow_cpu_custom_calls,
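
With the validation in place, an unsupported op now fails CompileGraph up front with an InvalidArgument status listing every offending op, instead of surfacing as a missing-kernel error midway through XLA graph building. A hedged sketch of the caller-visible behavior, assuming a compiler, graph, args, and result set up as in the test below (graph and op names hypothetical):

    tensorflow::Status status = compiler.CompileGraph(
        tensorflow::XlaCompiler::CompileOptions(), "my_graph",
        std::move(graph), args, &result);
    if (!status.ok()) {
      // e.g. "Detected unsupported operations when trying to compile graph
      // my_graph on XLA_CPU_JIT:MyCustomOp, AnotherCustomOp"
      LOG(ERROR) << status.error_message();
    }

The check deliberately runs after FunctionalizeControlFlow, since that pass rewrites Switch/Merge-style control flow into functional ops and can thereby remove nodes that would otherwise be flagged.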
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h
index bf496bd8bc..76f4c4c1ea 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.h
+++ b/tensorflow/compiler/tf2xla/xla_compiler.h
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h"
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
@@ -244,9 +245,9 @@ class XlaCompiler {
typedef std::function<TensorShape(const TensorShape&, DataType)>
ShapeRepresentationFn;
struct Options {
- // Name of the compilation device to use. Needs to be live only during
- // XlaCompiler's constructor.
- const DeviceType* device_type = nullptr;
+ // Name of the compilation device to use. It must be set by the caller.
+ // The default empty value is invalid.
+ DeviceType device_type = DeviceType("");
xla::Client* client = nullptr;
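
Because the field is now a value with an explicitly invalid default, forgetting to set it fails fast: the CHECK added to the constructor (see xla_compiler.cc above) aborts on an empty type string. A minimal sketch of the failure mode (hypothetical caller):

    tensorflow::XlaCompiler::Options options;
    options.client = client;  // device_type accidentally left as DeviceType("")
    tensorflow::XlaCompiler compiler(options);  // CHECK-fails at construction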
diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
index 55772ca324..246b386f38 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
@@ -45,8 +45,6 @@ namespace tensorflow {
class XlaCompilerTest : public ::testing::Test {
protected:
- XlaCompilerTest() : cpu_device_type_(DEVICE_CPU_XLA_JIT) {}
-
void SetUp() override {
client_ = xla::ClientLibrary::LocalClientOrDie();
@@ -58,7 +56,7 @@ class XlaCompilerTest : public ::testing::Test {
XlaCompiler::Options DefaultOptions() {
XlaCompiler::Options options;
- options.device_type = &cpu_device_type_;
+ options.device_type = DeviceType(DEVICE_CPU_XLA_JIT);
options.client = client_;
options.flib_def = flib_def_.get();
return options;
@@ -68,7 +66,6 @@ class XlaCompilerTest : public ::testing::Test {
return compiler->local_flib_def_.get();
}
- DeviceType cpu_device_type_;
xla::Client* client_;
std::unique_ptr<FunctionLibraryDefinition> flib_def_;
};
@@ -979,5 +976,54 @@ TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) {
EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
}
+// Tests a graph which has a function with an invalid op.
+TEST_F(XlaCompilerTest, FunctionWithInvalidOp) {
+ XlaCompiler compiler(DefaultOptions());
+
+ FunctionDefLibrary flib;
+ FunctionDef fn = FillFn();
+ NodeDef* node = fn.add_node_def();
+ node->set_name("Invalid");
+ node->set_op("InvalidOp"); /* unsupported op */
+ node = fn.add_node_def();
+ node->set_name("Switch");
+ node->set_op("Switch"); /* control flow node */
+ *flib.add_function() = fn;
+
+ TF_ASSERT_OK(flib_def_->AddFunctionDef(fn));
+
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+
+ Scope scope = Scope::NewRootScope().ExitOnError();
+ auto value = ops::Const<int32>(scope.WithOpName("value"), 1, {});
+ auto shape = ops::Const<int32>(scope.WithOpName("shape"), {5}, {1});
+ TF_ASSERT_OK(scope.graph()->AddFunctionLibrary(flib));
+
+ NodeDef def;
+ TF_ASSERT_OK(NodeDefBuilder("fill_fn", "FillFn", flib_def_.get())
+ .Input(value.name(), 0, DT_INT32)
+ .Input(shape.name(), 1, DT_INT32)
+ .Finalize(&def));
+ Status status;
+ Node* fill = scope.graph()->AddNode(def, &status);
+ TF_ASSERT_OK(status);
+ TF_ASSERT_OK(scope.DoShapeInference(fill));
+ scope.graph()->AddEdge(value.node(), 0, fill, 0);
+ scope.graph()->AddEdge(shape.node(), 0, fill, 1);
+
+ auto retval = ops::_Retval(scope.WithOpName("retval"), Output(fill), 0);
+
+ TF_ASSERT_OK(scope.ToGraph(graph.get()));
+
+ std::vector<XlaCompiler::Argument> args;
+ XlaCompiler::CompilationResult result;
+ status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "fill",
+ std::move(graph), args, &result);
+ ASSERT_FALSE(status.ok());
+ EXPECT_TRUE(
+ str_util::StrContains(status.error_message(), "FillFn:{InvalidOp}"))
+ << status.error_message();
+}
+
} // namespace
} // namespace tensorflow