aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/jit/mark_for_compilation_pass_test.cc')
-rw-r--r--tensorflow/compiler/jit/mark_for_compilation_pass_test.cc23
1 files changed, 22 insertions, 1 deletions
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
index 381c0205fd..2e362e0a63 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
@@ -138,7 +138,7 @@ TEST(XlaCompilationTest, CompilableCycles) {
EXPECT_EQ(clusters["A"], clusters["C"]);
}
-TEST(XlaCompilationTest, UnsupportedTypes) {
+TEST(XlaCompilationTest, Complex128Unsupported) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
GraphDef graphdef;
{
@@ -158,6 +158,27 @@ TEST(XlaCompilationTest, UnsupportedTypes) {
EXPECT_TRUE(clusters.empty());
}
+TEST(XlaCompilationTest, HalfSupported) {
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ GraphDef graphdef;
+ {
+ GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
+ Tensor t(DT_HALF, TensorShape());
+ t.scalar<Eigen::half>()() = static_cast<Eigen::half>(0.0f);
+ Node* a = ops::SourceOp("Const", builder.opts()
+ .WithName("A")
+ .WithAttr("dtype", DT_HALF)
+ .WithAttr("value", t));
+ Node* b = ops::UnaryOp("Neg", a, builder.opts().WithName("B"));
+ ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C"));
+ TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
+ }
+
+ TF_ASSERT_OK(MarkForCompilation(&graph));
+ auto clusters = GetClusters(*graph);
+ EXPECT_FALSE(clusters.empty());
+}
+
TEST(XlaCompilationTest, ConcatWithConstArg) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
GraphDef graphdef;