diff options
author | 2017-09-26 12:16:20 -0700 | |
---|---|---|
committer | 2017-09-26 12:20:19 -0700 | |
commit | b29b839215fa9bf5a00ca97e19673cfa5f780314 (patch) | |
tree | 26355f17565a9e49bf5e107493b1de51a32f25a6 /tensorflow/compiler/xla/service/shape_inference_test.cc | |
parent | f97fd78f7ef585215d13b39980319b8cad13ddd3 (diff) |
[XLA] Map API change to enable mapping over an arbitrary set of dimensions.
PiperOrigin-RevId: 170090055
Diffstat (limited to 'tensorflow/compiler/xla/service/shape_inference_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/shape_inference_test.cc | 31 |
1 files changed, 16 insertions, 15 deletions
diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index 8c731ae297..7c9c7e8d6a 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -505,7 +505,7 @@ TEST_F(ShapeInferenceTest, ConvolveDimensionNumbersOverlapError) { TEST_F(ShapeInferenceTest, MapThatChangesElementType) { Shape arg = ShapeUtil::MakeShape(F32, {20}); ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_}, s32_); - auto inferred_status = ShapeInference::InferMapShape({&arg}, to_apply); + auto inferred_status = ShapeInference::InferMapShape({&arg}, to_apply, {0}); EXPECT_IS_OK(inferred_status.status()); Shape expected = ShapeUtil::MakeShape(S32, {20}); EXPECT_TRUE(ShapeUtil::Equal(expected, inferred_status.ValueOrDie())); @@ -514,91 +514,92 @@ TEST_F(ShapeInferenceTest, MapThatChangesElementType) { TEST_F(ShapeInferenceTest, Map) { auto inferred_status_r1f32 = ShapeInference::InferMapShape( {&vector_32_, &vector_32_}, - ShapeUtil::MakeProgramShape({f32_, f32_}, f32_)); + ShapeUtil::MakeProgramShape({f32_, f32_}, f32_), {0}); EXPECT_IS_OK(inferred_status_r1f32.status()); EXPECT_TRUE(ShapeUtil::Equal(vector_32_, inferred_status_r1f32.ValueOrDie())); // It's OK to provide a single argument, as long as the applied arity matches // (this degenerates to a Map). auto inferred_status_r1f32_one = ShapeInference::InferMapShape( - {&vector_32_}, ShapeUtil::MakeProgramShape({f32_}, f32_)); + {&vector_32_}, ShapeUtil::MakeProgramShape({f32_}, f32_), {0}); EXPECT_IS_OK(inferred_status_r1f32_one.status()); EXPECT_TRUE( ShapeUtil::Equal(vector_32_, inferred_status_r1f32_one.ValueOrDie())); auto inferred_status_r2s32 = ShapeInference::InferMapShape( {&s32matrix_64_64_, &s32matrix_64_64_, &s32matrix_64_64_}, - ShapeUtil::MakeProgramShape({s32_, s32_, s32_}, s32_)); + ShapeUtil::MakeProgramShape({s32_, s32_, s32_}, s32_), {0, 1}); EXPECT_IS_OK(inferred_status_r2s32.status()); EXPECT_TRUE( ShapeUtil::Equal(s32matrix_64_64_, inferred_status_r2s32.ValueOrDie())); auto no_args_error = ShapeInference::InferMapShape( - {}, ShapeUtil::MakeProgramShape({f32_, f32_}, f32_)); + {}, ShapeUtil::MakeProgramShape({f32_, f32_}, f32_), {}); ASSERT_FALSE(no_args_error.ok()); ASSERT_THAT(no_args_error.status().error_message(), HasSubstr("expects at least one argument")); auto args_diff_shapes_error = ShapeInference::InferMapShape( {&vector_32_, &vector_64_}, - ShapeUtil::MakeProgramShape({f32_, f32_}, f32_)); + ShapeUtil::MakeProgramShape({f32_, f32_}, f32_), {0}); ASSERT_FALSE(args_diff_shapes_error.ok()); ASSERT_THAT(args_diff_shapes_error.status().error_message(), HasSubstr("requires all operands to have the same shape")); auto arity_error = ShapeInference::InferMapShape( - {&vector_32_, &vector_32_}, ShapeUtil::MakeProgramShape({f32_}, f32_)); + {&vector_32_, &vector_32_}, ShapeUtil::MakeProgramShape({f32_}, f32_), + {0}); ASSERT_FALSE(arity_error.ok()); ASSERT_THAT(arity_error.status().error_message(), HasSubstr("function arity must match")); auto output_shape_error = ShapeInference::InferMapShape( {&vector_32_, &vector_32_}, - ShapeUtil::MakeProgramShape({f32_, f32_}, vector_32_)); + ShapeUtil::MakeProgramShape({f32_, f32_}, vector_32_), {0}); ASSERT_FALSE(output_shape_error.ok()); ASSERT_THAT(output_shape_error.status().error_message(), HasSubstr("result has to be a scalar")); auto param_shape_error = ShapeInference::InferMapShape( {&vector_32_, &vector_32_}, - ShapeUtil::MakeProgramShape({vector_32_, f32_}, f32_)); + ShapeUtil::MakeProgramShape({vector_32_, f32_}, f32_), {0}); ASSERT_FALSE(param_shape_error.ok()); ASSERT_THAT(param_shape_error.status().error_message(), HasSubstr("parameter has to be a scalar")); auto param_element_type_error = ShapeInference::InferMapShape( {&vector_32_, &vector_32_}, - ShapeUtil::MakeProgramShape({f32_, s32_}, f32_)); + ShapeUtil::MakeProgramShape({f32_, s32_}, f32_), {0}); ASSERT_FALSE(param_element_type_error.ok()); ASSERT_THAT(param_element_type_error.status().error_message(), HasSubstr("parameter type has to match argument")); Shape arg = ShapeUtil::MakeShape(F32, {20}); ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_}, f32_); - auto inferred_status = ShapeInference::InferMapShape({&arg}, to_apply); + auto inferred_status = ShapeInference::InferMapShape({&arg}, to_apply, {0}); EXPECT_IS_OK(inferred_status.status()); EXPECT_TRUE(ShapeUtil::Equal(arg, inferred_status.ValueOrDie())); auto inferred_status_error1 = ShapeInference::InferMapShape( - {&arg}, ShapeUtil::MakeProgramShape({f32_, f32_}, f32_)); + {&arg}, ShapeUtil::MakeProgramShape({f32_, f32_}, f32_), {0}); ASSERT_FALSE(inferred_status_error1.ok()); ASSERT_THAT(inferred_status_error1.status().error_message(), HasSubstr("arity must match number of arguments")); auto inferred_status_error2 = ShapeInference::InferMapShape( - {&arg}, ShapeUtil::MakeProgramShape({vector_32_}, f32_)); + {&arg}, ShapeUtil::MakeProgramShape({vector_32_}, f32_), {0}); ASSERT_FALSE(inferred_status_error2.ok()); ASSERT_THAT(inferred_status_error2.status().error_message(), HasSubstr("has to be a scalar")); auto inferred_status_error3 = ShapeInference::InferMapShape( - {&arg}, ShapeUtil::MakeProgramShape({f32_}, vector_32_)); + {&arg}, ShapeUtil::MakeProgramShape({f32_}, vector_32_), {0}); ASSERT_FALSE(inferred_status_error3.ok()); ASSERT_THAT(inferred_status_error3.status().error_message(), HasSubstr("has to be a scalar")); auto inferred_status_error5 = ShapeInference::InferMapShape( - {&arg}, ShapeUtil::MakeProgramShape({s32_}, s32_)); + {&arg}, ShapeUtil::MakeProgramShape({s32_}, s32_), {0}); ASSERT_FALSE(inferred_status_error5.ok()); ASSERT_THAT(inferred_status_error5.status().error_message(), HasSubstr("parameter type has to match argument")); |