aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla/xla_compiler_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/tf2xla/xla_compiler_test.cc')
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiler_test.cc54
1 files changed, 50 insertions, 4 deletions
diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
index 55772ca324..246b386f38 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
@@ -45,8 +45,6 @@ namespace tensorflow {
class XlaCompilerTest : public ::testing::Test {
protected:
- XlaCompilerTest() : cpu_device_type_(DEVICE_CPU_XLA_JIT) {}
-
void SetUp() override {
client_ = xla::ClientLibrary::LocalClientOrDie();
@@ -58,7 +56,7 @@ class XlaCompilerTest : public ::testing::Test {
XlaCompiler::Options DefaultOptions() {
XlaCompiler::Options options;
- options.device_type = &cpu_device_type_;
+ options.device_type = DeviceType(DEVICE_CPU_XLA_JIT);
options.client = client_;
options.flib_def = flib_def_.get();
return options;
@@ -68,7 +66,6 @@ class XlaCompilerTest : public ::testing::Test {
return compiler->local_flib_def_.get();
}
- DeviceType cpu_device_type_;
xla::Client* client_;
std::unique_ptr<FunctionLibraryDefinition> flib_def_;
};
@@ -979,5 +976,54 @@ TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) {
EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
}
+// Tests a graph which has a function with an invalid op.
+TEST_F(XlaCompilerTest, FunctionWithInvalidOp) {
+ XlaCompiler compiler(DefaultOptions());
+
+ FunctionDefLibrary flib;
+ FunctionDef fn = FillFn();
+ NodeDef* node = fn.add_node_def();
+ node->set_name("Invalid");
+ node->set_op("InvalidOp"); /* unsupported op */
+ node = fn.add_node_def();
+ node->set_name("Switch");
+ node->set_op("Switch"); /* control flow node */
+ *flib.add_function() = fn;
+
+ TF_ASSERT_OK(flib_def_->AddFunctionDef(fn));
+
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+
+ Scope scope = Scope::NewRootScope().ExitOnError();
+ auto value = ops::Const<int32>(scope.WithOpName("value"), 1, {});
+ auto shape = ops::Const<int32>(scope.WithOpName("shape"), {5}, {1});
+ TF_ASSERT_OK(scope.graph()->AddFunctionLibrary(flib));
+
+ NodeDef def;
+ TF_ASSERT_OK(NodeDefBuilder("fill_fn", "FillFn", flib_def_.get())
+ .Input(value.name(), 0, DT_INT32)
+ .Input(shape.name(), 1, DT_INT32)
+ .Finalize(&def));
+ Status status;
+ Node* fill = scope.graph()->AddNode(def, &status);
+ TF_ASSERT_OK(status);
+ TF_ASSERT_OK(scope.DoShapeInference(fill));
+ scope.graph()->AddEdge(value.node(), 0, fill, 0);
+ scope.graph()->AddEdge(shape.node(), 0, fill, 1);
+
+ auto retval = ops::_Retval(scope.WithOpName("retval"), Output(fill), 0);
+
+ TF_ASSERT_OK(scope.ToGraph(graph.get()));
+
+ std::vector<XlaCompiler::Argument> args;
+ XlaCompiler::CompilationResult result;
+ status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "fill",
+ std::move(graph), args, &result);
+ ASSERT_FALSE(status.ok());
+ EXPECT_TRUE(
+ str_util::StrContains(status.error_message(), "FillFn:{InvalidOp}"))
+ << status.error_message();
+}
+
} // namespace
} // namespace tensorflow