aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/shape_inference_test.cc
diff options
context:
space:
mode:
authorGravatar Michael Kuperstein <mkuper@google.com>2018-08-02 12:46:13 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-02 12:50:33 -0700
commitde4c12857782f65dc4a941776d506ecac50a5934 (patch)
treef7685195a99d20db045c2ccb50f5cc66f605b8b3 /tensorflow/compiler/xla/service/shape_inference_test.cc
parentdebcc45d2dca24a6914fc823477e5a1a43be3028 (diff)
[XLA] Introduce variadic version of reduce.
This defines the semantics, and adds parser and shape inference support. Since support is not plumbed through the rest of the compiler here, multi-output reduce is still rejected by the HLO verifier, and is not exposed through XlaBuilder. PiperOrigin-RevId: 207148035
Diffstat (limited to 'tensorflow/compiler/xla/service/shape_inference_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/shape_inference_test.cc102
1 files changed, 96 insertions, 6 deletions
diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc
index 511d2c22f8..a73fa181cd 100644
--- a/tensorflow/compiler/xla/service/shape_inference_test.cc
+++ b/tensorflow/compiler/xla/service/shape_inference_test.cc
@@ -63,7 +63,7 @@ class ReduceShapeInferenceTest : public ShapeInferenceTest {
tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce) {
ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, f32_);
auto inferred_status = ShapeInference::InferReduceShape(
- arg, f32_, dimensions_to_reduce, to_apply);
+ {&arg, &f32_}, dimensions_to_reduce, to_apply);
EXPECT_IS_OK(inferred_status.status());
EXPECT_TRUE(ShapeUtil::Equal(expected_inferred_shape,
inferred_status.ValueOrDie()));
@@ -703,11 +703,99 @@ TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongAllDimensions) {
/*dimensions_to_reduce=*/{0, 1, 2});
}
+TEST_F(ReduceShapeInferenceTest, ReduceMultiOutput) {
+ Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
+ Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3});
+ ProgramShape to_apply = ShapeUtil::MakeProgramShape(
+ {f32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_}));
+ auto inferred_status = ShapeInference::InferReduceShape(
+ {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply);
+ EXPECT_IS_OK(inferred_status.status());
+ EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeTupleShape({f32_, s32_}),
+ inferred_status.ValueOrDie()));
+}
+
+TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerInput1) {
+ Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
+ Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3});
+ ProgramShape to_apply =
+ ShapeUtil::MakeProgramShape({f32_, s32_, f32_, s32_, f32_, s32_},
+ ShapeUtil::MakeTupleShape({f32_, s32_}));
+ auto inferred_status = ShapeInference::InferReduceShape(
+ {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply);
+ EXPECT_FALSE(inferred_status.ok());
+ EXPECT_THAT(inferred_status.status().error_message(),
+ HasSubstr("must take 4 parameters, but takes 6 parameter(s)"));
+}
+
+TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerInput2) {
+ Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
+ Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3});
+ ProgramShape to_apply = ShapeUtil::MakeProgramShape(
+ {s32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_}));
+ auto inferred_status = ShapeInference::InferReduceShape(
+ {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply);
+ EXPECT_FALSE(inferred_status.ok());
+ EXPECT_THAT(
+ inferred_status.status().error_message(),
+ HasSubstr(
+ "parameter shape differs from the result shape: s32[] vs f32[]"));
+}
+
+TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerInput3) {
+ ProgramShape to_apply = ShapeUtil::MakeProgramShape(
+ {s32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_}));
+ auto inferred_status = ShapeInference::InferReduceShape({}, {0, 1}, to_apply);
+ EXPECT_FALSE(inferred_status.ok());
+ EXPECT_THAT(inferred_status.status().error_message(),
+ HasSubstr("must have at least 2 arguments, has 0"));
+}
+
+TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerOutput1) {
+ Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
+ Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3});
+ ProgramShape to_apply =
+ ShapeUtil::MakeProgramShape({f32_, s32_, f32_, s32_}, f32_);
+ auto inferred_status = ShapeInference::InferReduceShape(
+ {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply);
+ EXPECT_FALSE(inferred_status.ok());
+ EXPECT_THAT(
+ inferred_status.status().error_message(),
+ HasSubstr("must produce a tuple with 2 elements, but produces a scalar"));
+}
+
+TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerOutput2) {
+ Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
+ Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3});
+ ProgramShape to_apply = ShapeUtil::MakeProgramShape(
+ {f32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_, s32_}));
+ auto inferred_status = ShapeInference::InferReduceShape(
+ {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply);
+ EXPECT_FALSE(inferred_status.ok());
+ EXPECT_THAT(
+ inferred_status.status().error_message(),
+ HasSubstr("must produce a tuple with 2 elements, but has 3 elements"));
+}
+
+TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerBoth) {
+ Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
+ Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3});
+ ProgramShape to_apply = ShapeUtil::MakeProgramShape(
+ {s32_, s32_, s32_, s32_}, ShapeUtil::MakeTupleShape({s32_, s32_}));
+ auto inferred_status = ShapeInference::InferReduceShape(
+ {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply);
+ EXPECT_FALSE(inferred_status.ok());
+ EXPECT_THAT(inferred_status.status().error_message(),
+ HasSubstr("accumulator shape at index 0 differs from the "
+ "init_value shape: s32[] vs f32[]"));
+}
+
TEST_F(ReduceShapeInferenceTest, ErrorOutOfBoundsDimension) {
ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, f32_);
+ Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
auto inferred_status = ShapeInference::InferReduceShape(
- ShapeUtil::MakeShape(F32, {5, 3}), f32_, /*dimensions_to_reduce=*/{3, 4},
- to_apply);
+ {&arg_shape, &f32_},
+ /*dimensions_to_reduce=*/{3, 4}, to_apply);
EXPECT_FALSE(inferred_status.ok());
EXPECT_THAT(inferred_status.status().error_message(),
HasSubstr("out-of-bounds dimension"));
@@ -715,8 +803,9 @@ TEST_F(ReduceShapeInferenceTest, ErrorOutOfBoundsDimension) {
TEST_F(ReduceShapeInferenceTest, ErrorToApplyArity) {
ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_, f32_}, f32_);
+ Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
auto inferred_status =
- ShapeInference::InferReduceShape(ShapeUtil::MakeShape(F32, {5, 3}), f32_,
+ ShapeInference::InferReduceShape({&arg_shape, &f32_},
/*dimensions_to_reduce=*/{0}, to_apply);
EXPECT_FALSE(inferred_status.ok());
EXPECT_THAT(inferred_status.status().error_message(),
@@ -725,12 +814,13 @@ TEST_F(ReduceShapeInferenceTest, ErrorToApplyArity) {
TEST_F(ReduceShapeInferenceTest, ErrorElementTypeVsApplyType) {
ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, s32_);
+ Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
auto inferred_status =
- ShapeInference::InferReduceShape(ShapeUtil::MakeShape(F32, {5, 3}), f32_,
+ ShapeInference::InferReduceShape({&arg_shape, &f32_},
/*dimensions_to_reduce=*/{0}, to_apply);
EXPECT_FALSE(inferred_status.ok());
EXPECT_THAT(inferred_status.status().error_message(),
- HasSubstr("first parameter shape differs"));
+ HasSubstr("0-th parameter shape differs"));
}
TEST_F(ShapeInferenceTest, InferSliceShapeRank2) {