aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/shape_inference_test.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-12-07 17:46:37 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-07 17:50:07 -0800
commit0e9cc7f3113ade82436729bd541f6b501d023ac0 (patch)
tree797d2a0867bba92008d93d9f6cc416bb3b9f8e57 /tensorflow/compiler/xla/service/shape_inference_test.cc
parent1667d4dcd2c7c33a3bcade62014931a1f8d9a2e0 (diff)
[XLA] Implement Conditional in XLA service, client ComputationBuilder, and CPU backend.
PiperOrigin-RevId: 178322445
Diffstat (limited to 'tensorflow/compiler/xla/service/shape_inference_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/shape_inference_test.cc75
1 files changed, 75 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc
index 6e53d2d609..7af2805f12 100644
--- a/tensorflow/compiler/xla/service/shape_inference_test.cc
+++ b/tensorflow/compiler/xla/service/shape_inference_test.cc
@@ -1437,5 +1437,80 @@ TEST_F(ShapeInferenceTest, Transpose) {
ShapeUtil::MakeShape(F32, {3, 4, 5, 2})));
}
+TEST_F(ShapeInferenceTest, Conditional) {
+ auto inferred_status0 = ShapeInference::InferConditionalShape(
+ pred_, vector_32_, vector_64_,
+ ShapeUtil::MakeProgramShape({vector_32_}, f32_),
+ ShapeUtil::MakeProgramShape({vector_64_}, f32_));
+ EXPECT_IS_OK(inferred_status0.status());
+ EXPECT_TRUE(ShapeUtil::Equal(f32_, inferred_status0.ValueOrDie()));
+
+ auto inferred_status1 = ShapeInference::InferConditionalShape(
+ pred_, matrix_32_48_, vector_32_,
+ ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_64_),
+ ShapeUtil::MakeProgramShape({vector_32_}, vector_64_));
+ EXPECT_IS_OK(inferred_status1.status());
+ EXPECT_TRUE(ShapeUtil::Equal(vector_64_, inferred_status1.ValueOrDie()));
+
+ auto tuple_f32_v32 = ShapeUtil::MakeTupleShape({f32_, vector_32_});
+ auto inferred_status2 = ShapeInference::InferConditionalShape(
+ pred_, matrix_32_48_, tuple_f32_v32,
+ ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_),
+ ShapeUtil::MakeProgramShape({tuple_f32_v32}, vector_32_));
+ EXPECT_IS_OK(inferred_status2.status());
+ EXPECT_TRUE(ShapeUtil::Equal(vector_32_, inferred_status2.ValueOrDie()));
+
+ auto inferred_status_error0 = ShapeInference::InferConditionalShape(
+ s32_, vector_32_, vector_64_,
+ ShapeUtil::MakeProgramShape({vector_32_}, f32_),
+ ShapeUtil::MakeProgramShape({vector_64_}, f32_));
+ EXPECT_FALSE(inferred_status_error0.ok());
+ EXPECT_THAT(inferred_status_error0.status().error_message(),
+ HasSubstr("predicate must be a boolean"));
+
+ auto inferred_status_error1 = ShapeInference::InferConditionalShape(
+ pred_, ShapeUtil::MakeTupleShape({f32_, vector_32_}), matrix_32_48_,
+ ShapeUtil::MakeProgramShape({f32_, vector_32_}, vector_32_),
+ ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_));
+ EXPECT_FALSE(inferred_status_error1.ok());
+ EXPECT_THAT(inferred_status_error1.status().error_message(),
+ HasSubstr("true_computation must take 1 argument"));
+
+ auto inferred_status_error2 = ShapeInference::InferConditionalShape(
+ pred_, vector_32_, vector_64_,
+ ShapeUtil::MakeProgramShape({vector_64_}, f32_),
+ ShapeUtil::MakeProgramShape({vector_64_}, f32_));
+ EXPECT_FALSE(inferred_status_error2.ok());
+ EXPECT_THAT(inferred_status_error2.status().error_message(),
+ HasSubstr("true_operand must match the shape of the only "
+ "parameter of true_computation"));
+
+ auto inferred_status_error3 = ShapeInference::InferConditionalShape(
+ pred_, matrix_32_48_, ShapeUtil::MakeTupleShape({f32_, vector_32_}),
+ ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_),
+ ShapeUtil::MakeProgramShape({f32_, vector_32_}, vector_32_));
+ EXPECT_FALSE(inferred_status_error3.ok());
+ EXPECT_THAT(inferred_status_error3.status().error_message(),
+ HasSubstr("false_computation must take 1 argument"));
+
+ auto inferred_status_error4 = ShapeInference::InferConditionalShape(
+ pred_, vector_32_, vector_64_,
+ ShapeUtil::MakeProgramShape({vector_32_}, f32_),
+ ShapeUtil::MakeProgramShape({vector_32_}, f32_));
+ EXPECT_FALSE(inferred_status_error4.ok());
+ EXPECT_THAT(inferred_status_error4.status().error_message(),
+ HasSubstr("false_operand must match the shape of the only "
+ "parameter of false_computation"));
+
+ auto inferred_status_error5 = ShapeInference::InferConditionalShape(
+ pred_, vector_32_, vector_64_,
+ ShapeUtil::MakeProgramShape({vector_32_}, f32_),
+ ShapeUtil::MakeProgramShape({vector_64_}, vector_32_));
+ EXPECT_FALSE(inferred_status_error5.ok());
+ EXPECT_THAT(inferred_status_error5.status().error_message(),
+ HasSubstr("the result of true_computation and false_computation "
+ "must have the same shape"));
+}
+
} // namespace
} // namespace xla