aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/shape_inference_test.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-03-31 09:11:32 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-03-31 10:24:41 -0700
commitbbd2047cf3a715a1431889ad8f558576a5382876 (patch)
treeb13e1fc5453146bb151d14d30c62181323963aa2 /tensorflow/compiler/xla/service/shape_inference_test.cc
parent50be7aa7d72ded57c11c705e9de80da2bdc2220b (diff)
[XLA:HLO] Minor fix for Clamp shape inference, and add some tests.
Previously Clamp(f32[5], f32[], f32[9]) returned success, but it now returns a failure. Noticed while debugging a different problem. Change: 151835981
Diffstat (limited to 'tensorflow/compiler/xla/service/shape_inference_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/shape_inference_test.cc93
1 files changed, 93 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc
index 5a1ae6b002..6f968ded56 100644
--- a/tensorflow/compiler/xla/service/shape_inference_test.cc
+++ b/tensorflow/compiler/xla/service/shape_inference_test.cc
@@ -157,6 +157,99 @@ TEST_F(ShapeInferenceTest, SelectBadShapes) {
testing::ContainsRegex("pred operand must have PRED element type"));
}
+TEST_F(ShapeInferenceTest, ClampAllMatrix) {
+ auto inferred_status = ShapeInference::InferTernaryOpShape(
+ TernaryOperation::TRIOP_CLAMP, matrix_64_48_, matrix_64_48_,
+ matrix_64_48_);
+ ASSERT_IS_OK(inferred_status.status());
+ ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
+}
+
+TEST_F(ShapeInferenceTest, ClampAllScalar) {
+ auto inferred_status = ShapeInference::InferTernaryOpShape(
+ TernaryOperation::TRIOP_CLAMP, f32_, f32_, f32_);
+ ASSERT_IS_OK(inferred_status.status());
+ ASSERT_TRUE(ShapeUtil::Equal(f32_, inferred_status.ValueOrDie()));
+}
+
+TEST_F(ShapeInferenceTest, ClampMinScalar) {
+ auto inferred_status = ShapeInference::InferTernaryOpShape(
+ TernaryOperation::TRIOP_CLAMP, f32_, matrix_64_48_, matrix_64_48_);
+ ASSERT_IS_OK(inferred_status.status());
+ ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
+}
+
+TEST_F(ShapeInferenceTest, ClampMaxScalar) {
+ auto inferred_status = ShapeInference::InferTernaryOpShape(
+ TernaryOperation::TRIOP_CLAMP, matrix_64_48_, matrix_64_48_, f32_);
+ ASSERT_IS_OK(inferred_status.status());
+ ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
+}
+
+TEST_F(ShapeInferenceTest, ClampOperandScalar) {
+ auto inferred_status = ShapeInference::InferTernaryOpShape(
+ TernaryOperation::TRIOP_CLAMP, matrix_64_48_, f32_, matrix_64_48_);
+ ASSERT_IS_OK(inferred_status.status());
+ ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
+}
+
+TEST_F(ShapeInferenceTest, ClampMinMatrix) {
+ auto inferred_status = ShapeInference::InferTernaryOpShape(
+ TernaryOperation::TRIOP_CLAMP, matrix_64_48_, f32_, f32_);
+ ASSERT_IS_OK(inferred_status.status());
+ ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
+}
+
+TEST_F(ShapeInferenceTest, ClampMaxMatrix) {
+ auto inferred_status = ShapeInference::InferTernaryOpShape(
+ TernaryOperation::TRIOP_CLAMP, f32_, f32_, matrix_64_48_);
+ ASSERT_IS_OK(inferred_status.status());
+ ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
+}
+
+TEST_F(ShapeInferenceTest, ClampOperandMatrix) {
+ auto inferred_status = ShapeInference::InferTernaryOpShape(
+ TernaryOperation::TRIOP_CLAMP, f32_, matrix_64_48_, f32_);
+ ASSERT_IS_OK(inferred_status.status());
+ ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
+}
+
+TEST_F(ShapeInferenceTest, ClampBadShapes) {
+ // Type mismatch
+ ASSERT_FALSE(ShapeInference::InferTernaryOpShape(
+ TernaryOperation::TRIOP_CLAMP, s32_, f32_, f32_)
+ .ok());
+ ASSERT_FALSE(ShapeInference::InferTernaryOpShape(
+ TernaryOperation::TRIOP_CLAMP, f32_, s32_, f32_)
+ .ok());
+ ASSERT_FALSE(ShapeInference::InferTernaryOpShape(
+ TernaryOperation::TRIOP_CLAMP, f32_, f32_, s32_)
+ .ok());
+ // Dimension mismatch
+ ASSERT_FALSE(
+ ShapeInference::InferTernaryOpShape(TernaryOperation::TRIOP_CLAMP,
+ vector_64_, vector_32_, vector_32_)
+ .ok());
+ ASSERT_FALSE(
+ ShapeInference::InferTernaryOpShape(TernaryOperation::TRIOP_CLAMP,
+ vector_32_, vector_64_, vector_32_)
+ .ok());
+ ASSERT_FALSE(
+ ShapeInference::InferTernaryOpShape(TernaryOperation::TRIOP_CLAMP,
+ vector_32_, vector_32_, vector_64_)
+ .ok());
+ // Dimension mismatch, where one operand is a scalar
+ ASSERT_FALSE(ShapeInference::InferTernaryOpShape(
+ TernaryOperation::TRIOP_CLAMP, vector_64_, vector_32_, f32_)
+ .ok());
+ ASSERT_FALSE(ShapeInference::InferTernaryOpShape(
+ TernaryOperation::TRIOP_CLAMP, vector_64_, f32_, vector_32_)
+ .ok());
+ ASSERT_FALSE(ShapeInference::InferTernaryOpShape(
+ TernaryOperation::TRIOP_CLAMP, f32_, vector_64_, vector_32_)
+ .ok());
+}
+
TEST_F(ShapeInferenceTest, VariadicOpTuplify) {
StatusOr<Shape> result = ShapeInference::InferVariadicOpShape(
VariadicOperation::VAROP_TUPLE, {&s32_, &f32_});