diff options
Diffstat (limited to 'tensorflow/compiler/jit/mark_for_compilation_pass_test.cc')
-rw-r--r-- | tensorflow/compiler/jit/mark_for_compilation_pass_test.cc | 23 |
1 files changed, 22 insertions, 1 deletions
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index 381c0205fd..2e362e0a63 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -138,7 +138,7 @@ TEST(XlaCompilationTest, CompilableCycles) { EXPECT_EQ(clusters["A"], clusters["C"]); } -TEST(XlaCompilationTest, UnsupportedTypes) { +TEST(XlaCompilationTest, Complex128Unsupported) { std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); GraphDef graphdef; { @@ -158,6 +158,27 @@ TEST(XlaCompilationTest, UnsupportedTypes) { EXPECT_TRUE(clusters.empty()); } +TEST(XlaCompilationTest, HalfSupported) { + std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); + GraphDef graphdef; + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Tensor t(DT_HALF, TensorShape()); + t.scalar<Eigen::half>()() = static_cast<Eigen::half>(0.0f); + Node* a = ops::SourceOp("Const", builder.opts() + .WithName("A") + .WithAttr("dtype", DT_HALF) + .WithAttr("value", t)); + Node* b = ops::UnaryOp("Neg", a, builder.opts().WithName("B")); + ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C")); + TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); + } + + TF_ASSERT_OK(MarkForCompilation(&graph)); + auto clusters = GetClusters(*graph); + EXPECT_FALSE(clusters.empty()); +} + TEST(XlaCompilationTest, ConcatWithConstArg) { std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); GraphDef graphdef; |