aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/shape_inference_test.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-09-26 12:16:20 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-09-26 12:20:19 -0700
commitb29b839215fa9bf5a00ca97e19673cfa5f780314 (patch)
tree26355f17565a9e49bf5e107493b1de51a32f25a6 /tensorflow/compiler/xla/service/shape_inference_test.cc
parentf97fd78f7ef585215d13b39980319b8cad13ddd3 (diff)
[XLA] Map API change to enable mapping over an arbitrary set of dimensions.
PiperOrigin-RevId: 170090055
Diffstat (limited to 'tensorflow/compiler/xla/service/shape_inference_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/shape_inference_test.cc31
1 files changed, 16 insertions, 15 deletions
diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc
index 8c731ae297..7c9c7e8d6a 100644
--- a/tensorflow/compiler/xla/service/shape_inference_test.cc
+++ b/tensorflow/compiler/xla/service/shape_inference_test.cc
@@ -505,7 +505,7 @@ TEST_F(ShapeInferenceTest, ConvolveDimensionNumbersOverlapError) {
TEST_F(ShapeInferenceTest, MapThatChangesElementType) {
Shape arg = ShapeUtil::MakeShape(F32, {20});
ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_}, s32_);
- auto inferred_status = ShapeInference::InferMapShape({&arg}, to_apply);
+ auto inferred_status = ShapeInference::InferMapShape({&arg}, to_apply, {0});
EXPECT_IS_OK(inferred_status.status());
Shape expected = ShapeUtil::MakeShape(S32, {20});
EXPECT_TRUE(ShapeUtil::Equal(expected, inferred_status.ValueOrDie()));
@@ -514,91 +514,92 @@ TEST_F(ShapeInferenceTest, MapThatChangesElementType) {
TEST_F(ShapeInferenceTest, Map) {
auto inferred_status_r1f32 = ShapeInference::InferMapShape(
{&vector_32_, &vector_32_},
- ShapeUtil::MakeProgramShape({f32_, f32_}, f32_));
+ ShapeUtil::MakeProgramShape({f32_, f32_}, f32_), {0});
EXPECT_IS_OK(inferred_status_r1f32.status());
EXPECT_TRUE(ShapeUtil::Equal(vector_32_, inferred_status_r1f32.ValueOrDie()));
// It's OK to provide a single argument, as long as the applied arity matches
// (this degenerates to a Map).
auto inferred_status_r1f32_one = ShapeInference::InferMapShape(
- {&vector_32_}, ShapeUtil::MakeProgramShape({f32_}, f32_));
+ {&vector_32_}, ShapeUtil::MakeProgramShape({f32_}, f32_), {0});
EXPECT_IS_OK(inferred_status_r1f32_one.status());
EXPECT_TRUE(
ShapeUtil::Equal(vector_32_, inferred_status_r1f32_one.ValueOrDie()));
auto inferred_status_r2s32 = ShapeInference::InferMapShape(
{&s32matrix_64_64_, &s32matrix_64_64_, &s32matrix_64_64_},
- ShapeUtil::MakeProgramShape({s32_, s32_, s32_}, s32_));
+ ShapeUtil::MakeProgramShape({s32_, s32_, s32_}, s32_), {0, 1});
EXPECT_IS_OK(inferred_status_r2s32.status());
EXPECT_TRUE(
ShapeUtil::Equal(s32matrix_64_64_, inferred_status_r2s32.ValueOrDie()));
auto no_args_error = ShapeInference::InferMapShape(
- {}, ShapeUtil::MakeProgramShape({f32_, f32_}, f32_));
+ {}, ShapeUtil::MakeProgramShape({f32_, f32_}, f32_), {});
ASSERT_FALSE(no_args_error.ok());
ASSERT_THAT(no_args_error.status().error_message(),
HasSubstr("expects at least one argument"));
auto args_diff_shapes_error = ShapeInference::InferMapShape(
{&vector_32_, &vector_64_},
- ShapeUtil::MakeProgramShape({f32_, f32_}, f32_));
+ ShapeUtil::MakeProgramShape({f32_, f32_}, f32_), {0});
ASSERT_FALSE(args_diff_shapes_error.ok());
ASSERT_THAT(args_diff_shapes_error.status().error_message(),
HasSubstr("requires all operands to have the same shape"));
auto arity_error = ShapeInference::InferMapShape(
- {&vector_32_, &vector_32_}, ShapeUtil::MakeProgramShape({f32_}, f32_));
+ {&vector_32_, &vector_32_}, ShapeUtil::MakeProgramShape({f32_}, f32_),
+ {0});
ASSERT_FALSE(arity_error.ok());
ASSERT_THAT(arity_error.status().error_message(),
HasSubstr("function arity must match"));
auto output_shape_error = ShapeInference::InferMapShape(
{&vector_32_, &vector_32_},
- ShapeUtil::MakeProgramShape({f32_, f32_}, vector_32_));
+ ShapeUtil::MakeProgramShape({f32_, f32_}, vector_32_), {0});
ASSERT_FALSE(output_shape_error.ok());
ASSERT_THAT(output_shape_error.status().error_message(),
HasSubstr("result has to be a scalar"));
auto param_shape_error = ShapeInference::InferMapShape(
{&vector_32_, &vector_32_},
- ShapeUtil::MakeProgramShape({vector_32_, f32_}, f32_));
+ ShapeUtil::MakeProgramShape({vector_32_, f32_}, f32_), {0});
ASSERT_FALSE(param_shape_error.ok());
ASSERT_THAT(param_shape_error.status().error_message(),
HasSubstr("parameter has to be a scalar"));
auto param_element_type_error = ShapeInference::InferMapShape(
{&vector_32_, &vector_32_},
- ShapeUtil::MakeProgramShape({f32_, s32_}, f32_));
+ ShapeUtil::MakeProgramShape({f32_, s32_}, f32_), {0});
ASSERT_FALSE(param_element_type_error.ok());
ASSERT_THAT(param_element_type_error.status().error_message(),
HasSubstr("parameter type has to match argument"));
Shape arg = ShapeUtil::MakeShape(F32, {20});
ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_}, f32_);
- auto inferred_status = ShapeInference::InferMapShape({&arg}, to_apply);
+ auto inferred_status = ShapeInference::InferMapShape({&arg}, to_apply, {0});
EXPECT_IS_OK(inferred_status.status());
EXPECT_TRUE(ShapeUtil::Equal(arg, inferred_status.ValueOrDie()));
auto inferred_status_error1 = ShapeInference::InferMapShape(
- {&arg}, ShapeUtil::MakeProgramShape({f32_, f32_}, f32_));
+ {&arg}, ShapeUtil::MakeProgramShape({f32_, f32_}, f32_), {0});
ASSERT_FALSE(inferred_status_error1.ok());
ASSERT_THAT(inferred_status_error1.status().error_message(),
HasSubstr("arity must match number of arguments"));
auto inferred_status_error2 = ShapeInference::InferMapShape(
- {&arg}, ShapeUtil::MakeProgramShape({vector_32_}, f32_));
+ {&arg}, ShapeUtil::MakeProgramShape({vector_32_}, f32_), {0});
ASSERT_FALSE(inferred_status_error2.ok());
ASSERT_THAT(inferred_status_error2.status().error_message(),
HasSubstr("has to be a scalar"));
auto inferred_status_error3 = ShapeInference::InferMapShape(
- {&arg}, ShapeUtil::MakeProgramShape({f32_}, vector_32_));
+ {&arg}, ShapeUtil::MakeProgramShape({f32_}, vector_32_), {0});
ASSERT_FALSE(inferred_status_error3.ok());
ASSERT_THAT(inferred_status_error3.status().error_message(),
HasSubstr("has to be a scalar"));
auto inferred_status_error5 = ShapeInference::InferMapShape(
- {&arg}, ShapeUtil::MakeProgramShape({s32_}, s32_));
+ {&arg}, ShapeUtil::MakeProgramShape({s32_}, s32_), {0});
ASSERT_FALSE(inferred_status_error5.ok());
ASSERT_THAT(inferred_status_error5.status().error_message(),
HasSubstr("parameter type has to match argument"));