diff options
author | Michael Kuperstein <mkuper@google.com> | 2018-08-02 12:46:13 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-02 12:50:33 -0700 |
commit | de4c12857782f65dc4a941776d506ecac50a5934 (patch) | |
tree | f7685195a99d20db045c2ccb50f5cc66f605b8b3 /tensorflow/compiler/xla/service/shape_inference_test.cc | |
parent | debcc45d2dca24a6914fc823477e5a1a43be3028 (diff) |
[XLA] Introduce variadic version of reduce.
This defines the semantics, and adds parser and shape inference support. Since support is not plumbed through the rest of the compiler here, multi-output reduce is still rejected by the HLO verifier, and is not exposed through XlaBuilder.
PiperOrigin-RevId: 207148035
Diffstat (limited to 'tensorflow/compiler/xla/service/shape_inference_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/shape_inference_test.cc | 102 |
1 files changed, 96 insertions, 6 deletions
diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index 511d2c22f8..a73fa181cd 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -63,7 +63,7 @@ class ReduceShapeInferenceTest : public ShapeInferenceTest { tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce) { ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, f32_); auto inferred_status = ShapeInference::InferReduceShape( - arg, f32_, dimensions_to_reduce, to_apply); + {&arg, &f32_}, dimensions_to_reduce, to_apply); EXPECT_IS_OK(inferred_status.status()); EXPECT_TRUE(ShapeUtil::Equal(expected_inferred_shape, inferred_status.ValueOrDie())); @@ -703,11 +703,99 @@ TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongAllDimensions) { /*dimensions_to_reduce=*/{0, 1, 2}); } +TEST_F(ReduceShapeInferenceTest, ReduceMultiOutput) { + Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); + Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3}); + ProgramShape to_apply = ShapeUtil::MakeProgramShape( + {f32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_})); + auto inferred_status = ShapeInference::InferReduceShape( + {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply); + EXPECT_IS_OK(inferred_status.status()); + EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeTupleShape({f32_, s32_}), + inferred_status.ValueOrDie())); +} + +TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerInput1) { + Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); + Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3}); + ProgramShape to_apply = + ShapeUtil::MakeProgramShape({f32_, s32_, f32_, s32_, f32_, s32_}, + ShapeUtil::MakeTupleShape({f32_, s32_})); + auto inferred_status = ShapeInference::InferReduceShape( + {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply); + EXPECT_FALSE(inferred_status.ok()); + EXPECT_THAT(inferred_status.status().error_message(), + HasSubstr("must take 4 parameters, but takes 6 parameter(s)")); +} + +TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerInput2) { + Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); + Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3}); + ProgramShape to_apply = ShapeUtil::MakeProgramShape( + {s32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_})); + auto inferred_status = ShapeInference::InferReduceShape( + {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply); + EXPECT_FALSE(inferred_status.ok()); + EXPECT_THAT( + inferred_status.status().error_message(), + HasSubstr( + "parameter shape differs from the result shape: s32[] vs f32[]")); +} + +TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerInput3) { + ProgramShape to_apply = ShapeUtil::MakeProgramShape( + {s32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_})); + auto inferred_status = ShapeInference::InferReduceShape({}, {0, 1}, to_apply); + EXPECT_FALSE(inferred_status.ok()); + EXPECT_THAT(inferred_status.status().error_message(), + HasSubstr("must have at least 2 arguments, has 0")); +} + +TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerOutput1) { + Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); + Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3}); + ProgramShape to_apply = + ShapeUtil::MakeProgramShape({f32_, s32_, f32_, s32_}, f32_); + auto inferred_status = ShapeInference::InferReduceShape( + {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply); + EXPECT_FALSE(inferred_status.ok()); + EXPECT_THAT( + inferred_status.status().error_message(), + HasSubstr("must produce a tuple with 2 elements, but produces a scalar")); +} + +TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerOutput2) { + Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); + Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3}); + ProgramShape to_apply = ShapeUtil::MakeProgramShape( + {f32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_, s32_})); + auto inferred_status = ShapeInference::InferReduceShape( + {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply); + EXPECT_FALSE(inferred_status.ok()); + EXPECT_THAT( + inferred_status.status().error_message(), + HasSubstr("must produce a tuple with 2 elements, but has 3 elements")); +} + +TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerBoth) { + Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); + Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3}); + ProgramShape to_apply = ShapeUtil::MakeProgramShape( + {s32_, s32_, s32_, s32_}, ShapeUtil::MakeTupleShape({s32_, s32_})); + auto inferred_status = ShapeInference::InferReduceShape( + {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply); + EXPECT_FALSE(inferred_status.ok()); + EXPECT_THAT(inferred_status.status().error_message(), + HasSubstr("accumulator shape at index 0 differs from the " + "init_value shape: s32[] vs f32[]")); +} + TEST_F(ReduceShapeInferenceTest, ErrorOutOfBoundsDimension) { ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, f32_); + Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); auto inferred_status = ShapeInference::InferReduceShape( - ShapeUtil::MakeShape(F32, {5, 3}), f32_, /*dimensions_to_reduce=*/{3, 4}, - to_apply); + {&arg_shape, &f32_}, + /*dimensions_to_reduce=*/{3, 4}, to_apply); EXPECT_FALSE(inferred_status.ok()); EXPECT_THAT(inferred_status.status().error_message(), HasSubstr("out-of-bounds dimension")); @@ -715,8 +803,9 @@ TEST_F(ReduceShapeInferenceTest, ErrorOutOfBoundsDimension) { TEST_F(ReduceShapeInferenceTest, ErrorToApplyArity) { ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_, f32_}, f32_); + Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); auto inferred_status = - ShapeInference::InferReduceShape(ShapeUtil::MakeShape(F32, {5, 3}), f32_, + ShapeInference::InferReduceShape({&arg_shape, &f32_}, /*dimensions_to_reduce=*/{0}, to_apply); EXPECT_FALSE(inferred_status.ok()); EXPECT_THAT(inferred_status.status().error_message(), @@ -725,12 +814,13 @@ TEST_F(ReduceShapeInferenceTest, ErrorToApplyArity) { TEST_F(ReduceShapeInferenceTest, ErrorElementTypeVsApplyType) { ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, s32_); + Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); auto inferred_status = - ShapeInference::InferReduceShape(ShapeUtil::MakeShape(F32, {5, 3}), f32_, + ShapeInference::InferReduceShape({&arg_shape, &f32_}, /*dimensions_to_reduce=*/{0}, to_apply); EXPECT_FALSE(inferred_status.ok()); EXPECT_THAT(inferred_status.status().error_message(), - HasSubstr("first parameter shape differs")); + HasSubstr("0-th parameter shape differs")); } TEST_F(ShapeInferenceTest, InferSliceShapeRank2) { |