diff options
17 files changed, 105 insertions, 64 deletions
diff --git a/tensorflow/compiler/xla/client/computation_builder.cc b/tensorflow/compiler/xla/client/computation_builder.cc index 210a4d95b9..a80412e951 100644 --- a/tensorflow/compiler/xla/client/computation_builder.cc +++ b/tensorflow/compiler/xla/client/computation_builder.cc @@ -1307,6 +1307,7 @@ StatusOr<std::unique_ptr<Literal>> ComputationBuilder::ComputeConstant( ComputationDataHandle ComputationBuilder::Map( tensorflow::gtl::ArraySlice<ComputationDataHandle> operands, const Computation& computation, + tensorflow::gtl::ArraySlice<int64> dimensions, tensorflow::gtl::ArraySlice<ComputationDataHandle> static_operands) { if (!first_error_.ok() || !PrepareComputation().ok()) { return ComputationDataHandle(); @@ -1317,6 +1318,9 @@ ComputationDataHandle ComputationBuilder::Map( *request.add_operands() = operand; } *request.mutable_to_apply() = computation.handle(); + for (int64 dimension : dimensions) { + request.add_dimensions(dimension); + } for (const ComputationDataHandle& sop : static_operands) { *request.add_static_operands() = sop; } diff --git a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h index b0e6720be2..73972c1290 100644 --- a/tensorflow/compiler/xla/client/computation_builder.h +++ b/tensorflow/compiler/xla/client/computation_builder.h @@ -604,6 +604,7 @@ class ComputationBuilder { ComputationDataHandle Map( tensorflow::gtl::ArraySlice<ComputationDataHandle> operands, const Computation& computation, + tensorflow::gtl::ArraySlice<int64> dimensions, tensorflow::gtl::ArraySlice<ComputationDataHandle> static_operands = {}); // Enqueues a N(mu, sigma) random number generation instruction onto the diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc index 0a288a77ad..0eaa21ef25 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc @@ -169,7 +169,7 @@ TEST_F(HloCostAnalysisTest, MatrixMultiply) { TEST_F(HloCostAnalysisTest, Map) { ComputationBuilder builder(client_, "map"); auto input = builder.Parameter(0, ShapeUtil::MakeShape(F32, {10}), "in"); - auto result = builder.Map({input}, add_and_exp_); + auto result = builder.Map({input}, add_and_exp_, {0}); // Run HLO cost analysis. auto hlo_module = BuildHloGraph(&builder); @@ -286,7 +286,7 @@ TEST_F(HloCostAnalysisTest, FullyConnectedForward) { auto bias = builder.Parameter(2, ShapeUtil::MakeShape(F32, {20}), "bias"); // sigmoid(input * weight + bias) auto result = builder.Map( - {builder.Add(builder.Dot(input, weight), bias, {1})}, sigmoid_); + {builder.Add(builder.Dot(input, weight), bias, {1})}, sigmoid_, {0, 1}); // Run HLO cost analysis. auto hlo_module = BuildHloGraph(&builder); diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 2405d44778..c16747c02c 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -241,12 +241,20 @@ class ShapeVerifier : public DfsHloVisitor { HloComputation* function, tensorflow::gtl::ArraySlice<HloInstruction*> static_operands) override { std::vector<const Shape*> operand_shapes; + int64 max_operand_rank = 0; for (const HloInstruction* operand : operands) { operand_shapes.push_back(&operand->shape()); + max_operand_rank = + std::max(max_operand_rank, ShapeUtil::Rank(operand->shape())); } + // TODO(b/65689298) Remove code below once Map is generalized to accept + // arbitrary map dimensions. + std::vector<int64> map_dims(max_operand_rank); + std::iota(map_dims.begin(), map_dims.end(), 0); return CheckShape( - map, ShapeInference::InferMapShape( - operand_shapes, map->to_apply()->ComputeProgramShape())); + map, + ShapeInference::InferMapShape( + operand_shapes, map->to_apply()->ComputeProgramShape(), map_dims)); } Status HandleReduceWindow(HloInstruction* reduce_window, diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 5178a750b9..23c8266e77 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -852,7 +852,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( /* static */ StatusOr<Shape> ShapeInference::InferMapShape( tensorflow::gtl::ArraySlice<const Shape*> arg_shapes, - const ProgramShape& to_apply) { + const ProgramShape& to_apply, + tensorflow::gtl::ArraySlice<int64> dimensions) { if (arg_shapes.empty()) { return InvalidArgument("Map expects at least one argument"); } @@ -888,6 +889,24 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( tensorflow::str_util::Join(pieces, ", ").c_str()); } + // Check that dimensions.size == arg_shape.dimensions_size() (we currently + // only support mapping across all dimensions: i.e. scalar map functions). + if (dimensions.size() != arg_shape->dimensions_size()) { + return InvalidArgument( + "Map applied to a subset of dimensions currently not supported: " + "arg_dimension_size: %d, requested_map_dimensions_size: %zu", + arg_shape->dimensions_size(), dimensions.size()); + } + + // Check that requested map dimensions numbers are monotonically increasing. + for (int i = 0; i < dimensions.size(); ++i) { + if (dimensions[i] != i) { + return InvalidArgument( + "Map requires monotonically increasing dimension numbers, found: %s ", + tensorflow::str_util::Join(dimensions, ", ").c_str()); + } + } + // The applied function's arity equals the number of arguments. if (arg_shapes.size() != to_apply.parameters_size()) { return InvalidArgument( diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index 379feef5e4..d5d497176d 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -78,7 +78,8 @@ class ShapeInference { // to the given operand shapes. static StatusOr<Shape> InferMapShape( tensorflow::gtl::ArraySlice<const Shape*> arg_shapes, - const ProgramShape& to_apply); + const ProgramShape& to_apply, + tensorflow::gtl::ArraySlice<int64> dimensions); // Infers the shape produced by InferBatchNormTraining with the given // operands. diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index 8c731ae297..7c9c7e8d6a 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -505,7 +505,7 @@ TEST_F(ShapeInferenceTest, ConvolveDimensionNumbersOverlapError) { TEST_F(ShapeInferenceTest, MapThatChangesElementType) { Shape arg = ShapeUtil::MakeShape(F32, {20}); ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_}, s32_); - auto inferred_status = ShapeInference::InferMapShape({&arg}, to_apply); + auto inferred_status = ShapeInference::InferMapShape({&arg}, to_apply, {0}); EXPECT_IS_OK(inferred_status.status()); Shape expected = ShapeUtil::MakeShape(S32, {20}); EXPECT_TRUE(ShapeUtil::Equal(expected, inferred_status.ValueOrDie())); @@ -514,91 +514,92 @@ TEST_F(ShapeInferenceTest, MapThatChangesElementType) { TEST_F(ShapeInferenceTest, Map) { auto inferred_status_r1f32 = ShapeInference::InferMapShape( {&vector_32_, &vector_32_}, - ShapeUtil::MakeProgramShape({f32_, f32_}, f32_)); + ShapeUtil::MakeProgramShape({f32_, f32_}, f32_), {0}); EXPECT_IS_OK(inferred_status_r1f32.status()); EXPECT_TRUE(ShapeUtil::Equal(vector_32_, inferred_status_r1f32.ValueOrDie())); // It's OK to provide a single argument, as long as the applied arity matches // (this degenerates to a Map). auto inferred_status_r1f32_one = ShapeInference::InferMapShape( - {&vector_32_}, ShapeUtil::MakeProgramShape({f32_}, f32_)); + {&vector_32_}, ShapeUtil::MakeProgramShape({f32_}, f32_), {0}); EXPECT_IS_OK(inferred_status_r1f32_one.status()); EXPECT_TRUE( ShapeUtil::Equal(vector_32_, inferred_status_r1f32_one.ValueOrDie())); auto inferred_status_r2s32 = ShapeInference::InferMapShape( {&s32matrix_64_64_, &s32matrix_64_64_, &s32matrix_64_64_}, - ShapeUtil::MakeProgramShape({s32_, s32_, s32_}, s32_)); + ShapeUtil::MakeProgramShape({s32_, s32_, s32_}, s32_), {0, 1}); EXPECT_IS_OK(inferred_status_r2s32.status()); EXPECT_TRUE( ShapeUtil::Equal(s32matrix_64_64_, inferred_status_r2s32.ValueOrDie())); auto no_args_error = ShapeInference::InferMapShape( - {}, ShapeUtil::MakeProgramShape({f32_, f32_}, f32_)); + {}, ShapeUtil::MakeProgramShape({f32_, f32_}, f32_), {}); ASSERT_FALSE(no_args_error.ok()); ASSERT_THAT(no_args_error.status().error_message(), HasSubstr("expects at least one argument")); auto args_diff_shapes_error = ShapeInference::InferMapShape( {&vector_32_, &vector_64_}, - ShapeUtil::MakeProgramShape({f32_, f32_}, f32_)); + ShapeUtil::MakeProgramShape({f32_, f32_}, f32_), {0}); ASSERT_FALSE(args_diff_shapes_error.ok()); ASSERT_THAT(args_diff_shapes_error.status().error_message(), HasSubstr("requires all operands to have the same shape")); auto arity_error = ShapeInference::InferMapShape( - {&vector_32_, &vector_32_}, ShapeUtil::MakeProgramShape({f32_}, f32_)); + {&vector_32_, &vector_32_}, ShapeUtil::MakeProgramShape({f32_}, f32_), + {0}); ASSERT_FALSE(arity_error.ok()); ASSERT_THAT(arity_error.status().error_message(), HasSubstr("function arity must match")); auto output_shape_error = ShapeInference::InferMapShape( {&vector_32_, &vector_32_}, - ShapeUtil::MakeProgramShape({f32_, f32_}, vector_32_)); + ShapeUtil::MakeProgramShape({f32_, f32_}, vector_32_), {0}); ASSERT_FALSE(output_shape_error.ok()); ASSERT_THAT(output_shape_error.status().error_message(), HasSubstr("result has to be a scalar")); auto param_shape_error = ShapeInference::InferMapShape( {&vector_32_, &vector_32_}, - ShapeUtil::MakeProgramShape({vector_32_, f32_}, f32_)); + ShapeUtil::MakeProgramShape({vector_32_, f32_}, f32_), {0}); ASSERT_FALSE(param_shape_error.ok()); ASSERT_THAT(param_shape_error.status().error_message(), HasSubstr("parameter has to be a scalar")); auto param_element_type_error = ShapeInference::InferMapShape( {&vector_32_, &vector_32_}, - ShapeUtil::MakeProgramShape({f32_, s32_}, f32_)); + ShapeUtil::MakeProgramShape({f32_, s32_}, f32_), {0}); ASSERT_FALSE(param_element_type_error.ok()); ASSERT_THAT(param_element_type_error.status().error_message(), HasSubstr("parameter type has to match argument")); Shape arg = ShapeUtil::MakeShape(F32, {20}); ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_}, f32_); - auto inferred_status = ShapeInference::InferMapShape({&arg}, to_apply); + auto inferred_status = ShapeInference::InferMapShape({&arg}, to_apply, {0}); EXPECT_IS_OK(inferred_status.status()); EXPECT_TRUE(ShapeUtil::Equal(arg, inferred_status.ValueOrDie())); auto inferred_status_error1 = ShapeInference::InferMapShape( - {&arg}, ShapeUtil::MakeProgramShape({f32_, f32_}, f32_)); + {&arg}, ShapeUtil::MakeProgramShape({f32_, f32_}, f32_), {0}); ASSERT_FALSE(inferred_status_error1.ok()); ASSERT_THAT(inferred_status_error1.status().error_message(), HasSubstr("arity must match number of arguments")); auto inferred_status_error2 = ShapeInference::InferMapShape( - {&arg}, ShapeUtil::MakeProgramShape({vector_32_}, f32_)); + {&arg}, ShapeUtil::MakeProgramShape({vector_32_}, f32_), {0}); ASSERT_FALSE(inferred_status_error2.ok()); ASSERT_THAT(inferred_status_error2.status().error_message(), HasSubstr("has to be a scalar")); auto inferred_status_error3 = ShapeInference::InferMapShape( - {&arg}, ShapeUtil::MakeProgramShape({f32_}, vector_32_)); + {&arg}, ShapeUtil::MakeProgramShape({f32_}, vector_32_), {0}); ASSERT_FALSE(inferred_status_error3.ok()); ASSERT_THAT(inferred_status_error3.status().error_message(), HasSubstr("has to be a scalar")); auto inferred_status_error5 = ShapeInference::InferMapShape( - {&arg}, ShapeUtil::MakeProgramShape({s32_}, s32_)); + {&arg}, ShapeUtil::MakeProgramShape({s32_}, s32_), {0}); ASSERT_FALSE(inferred_status_error5.ok()); ASSERT_THAT(inferred_status_error5.status().error_message(), HasSubstr("parameter type has to match argument")); diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc index ac7c31bf68..6bdd9978fe 100644 --- a/tensorflow/compiler/xla/service/user_computation.cc +++ b/tensorflow/compiler/xla/service/user_computation.cc @@ -421,7 +421,8 @@ StatusOr<ComputationDataHandle> UserComputation::AddMapInstruction( to_apply_computation.ComputeProgramShape(to_apply_version)); TF_ASSIGN_OR_RETURN( Shape inferred_shape, - ShapeInference::InferMapShape(operand_shapes, *to_apply_program_shape)); + ShapeInference::InferMapShape(operand_shapes, *to_apply_program_shape, + map_request.dimensions())); ComputationDataHandle handle = CreateComputationDataHandle(); diff --git a/tensorflow/compiler/xla/tests/convert_test.cc b/tensorflow/compiler/xla/tests/convert_test.cc index 12b5e8426a..f66e3b57bf 100644 --- a/tensorflow/compiler/xla/tests/convert_test.cc +++ b/tensorflow/compiler/xla/tests/convert_test.cc @@ -176,7 +176,7 @@ TEST_F(ConvertTest, ConvertMapToS32) { auto param = b->Parameter(0, ShapeUtil::MakeShape(F32, {}), "in"); b->ConvertElementType(param, S32); auto a = builder.ConstantR1<float>({42.0f, 64.0f}); - builder.Map({a}, b->BuildAndNoteError()); + builder.Map({a}, b->BuildAndNoteError(), {0}); std::vector<int32> expected = {42, 64}; ComputeAndCompareR1<int32>(&builder, expected, {}); @@ -188,7 +188,7 @@ TEST_F(ConvertTest, ConvertMapToF32) { auto param = b->Parameter(0, ShapeUtil::MakeShape(S32, {}), "in"); b->ConvertElementType(param, F32); auto a = builder.ConstantR1<int32>({42, 64}); - builder.Map({a}, b->BuildAndNoteError()); + builder.Map({a}, b->BuildAndNoteError(), {0}); std::vector<float> expected = {42.0f, 64.0f}; ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001)); diff --git a/tensorflow/compiler/xla/tests/map_test.cc b/tensorflow/compiler/xla/tests/map_test.cc index 01ee421baa..2ef392508d 100644 --- a/tensorflow/compiler/xla/tests/map_test.cc +++ b/tensorflow/compiler/xla/tests/map_test.cc @@ -125,7 +125,7 @@ class MapTest : public ClientLibraryTestBase { Computation CreateMapPlusN(const Computation& embedded_computation, float n) { ComputationBuilder builder(client_, TestName()); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - auto map = builder.Map({x}, embedded_computation); + auto map = builder.Map({x}, embedded_computation, {}); auto constant_n = builder.ConstantR0<float>(n); auto add = builder.Add(map, constant_n); auto computation_status = builder.Build(); @@ -173,7 +173,7 @@ TEST_F(MapTest, MapEachElemPlusOneR0) { client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); auto param = builder.Parameter(0, param0_literal->shape(), "param0"); - auto map = builder.Map({param}, CreateAdderToOne()); + auto map = builder.Map({param}, CreateAdderToOne(), {}); ComputeAndCompareR0<float>(&builder, 43.0, {param0_data.get()}, ErrorSpec(0.01f)); @@ -187,7 +187,7 @@ XLA_TEST_F(MapTest, MapEachElemPlusOneR1S0) { client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); auto param = builder.Parameter(0, param0_literal->shape(), "param0"); - auto map = builder.Map({param}, CreateAdderToOne()); + auto map = builder.Map({param}, CreateAdderToOne(), {0}); ComputeAndCompareR1<float>(&builder, {}, {param0_data.get()}, ErrorSpec(0.01f)); @@ -202,7 +202,7 @@ TEST_F(MapTest, MapEachElemPlusOneR1S4) { client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); auto param = builder.Parameter(0, param0_literal->shape(), "param0"); - auto map = builder.Map({param}, CreateAdderToOne()); + auto map = builder.Map({param}, CreateAdderToOne(), {0}); ComputeAndCompareR1<float>(&builder, {3.2f, 4.3f, 5.4f, 6.5f}, {param0_data.get()}, ErrorSpec(0.01f)); @@ -216,7 +216,7 @@ TEST_F(MapTest, MapEachF32ElementToS32Constant) { client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); auto param = builder.Parameter(0, param0_literal->shape(), "param0"); - auto map = builder.Map({param}, CreateScalarOne<int32>()); + auto map = builder.Map({param}, CreateScalarOne<int32>(), {0}); ComputeAndCompareR1<int32>(&builder, {1, 1, 1, 1}, {param0_data.get()}); } @@ -229,7 +229,7 @@ TEST_F(MapTest, MapEachF32ElementToU32Constant) { client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); auto param = builder.Parameter(0, param0_literal->shape(), "param0"); - auto map = builder.Map({param}, CreateScalarOne<uint32>()); + auto map = builder.Map({param}, CreateScalarOne<uint32>(), {0}); ComputeAndCompareR1<uint32>(&builder, {1, 1, 1, 1}, {param0_data.get()}); } @@ -243,7 +243,7 @@ TEST_F(MapTest, MapEachElemLongerChainR1) { client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); auto param = builder.Parameter(0, param0_literal->shape(), "param0"); - auto map = builder.Map({param}, CreateAdderToOneTimesItself()); + auto map = builder.Map({param}, CreateAdderToOneTimesItself(), {0}); ComputeAndCompareR1<float>( &builder, {9.36f, 20.91f, 0.11f, 0.24f, 999000.0f, 65535.75f}, @@ -259,8 +259,8 @@ XLA_TEST_F(MapTest, MapMultipleMapsR1S0) { client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); auto param = builder.Parameter(0, param0_literal->shape(), "param0"); - auto map1 = builder.Map({param}, CreateAdderToOne()); - auto map2 = builder.Map({map1}, CreateMulByTwo()); + auto map1 = builder.Map({param}, CreateAdderToOne(), {0}); + auto map2 = builder.Map({map1}, CreateMulByTwo(), {0}); ComputeAndCompareR1<float>(&builder, {}, {param0_data.get()}, ErrorSpec(0.01f)); @@ -276,8 +276,8 @@ TEST_F(MapTest, MapMultipleMapsR1S4) { client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); auto param = builder.Parameter(0, param0_literal->shape(), "param0"); - auto map1 = builder.Map({param}, CreateAdderToOne()); - auto map2 = builder.Map({map1}, CreateMulByTwo()); + auto map1 = builder.Map({param}, CreateAdderToOne(), {0}); + auto map2 = builder.Map({map1}, CreateMulByTwo(), {0}); ComputeAndCompareR1<float>(&builder, {6.4f, 8.6f, 10.8f, 13.0f}, {param0_data.get()}, ErrorSpec(0.01f)); @@ -292,7 +292,7 @@ TEST_F(MapTest, MapEachElemPlusOneR2) { client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); auto param = builder.Parameter(0, param0_literal->shape(), "param0"); - auto map = builder.Map({param}, CreateAdderToOne()); + auto map = builder.Map({param}, CreateAdderToOne(), {0, 1}); Array2D<float> expected_array( {{14.25f, 15.0f}, {-6.1f, -6.2f}, {-7.8f, 9.8f}}); @@ -319,8 +319,8 @@ XLA_TEST_F(MapTest, ComplexNestedMaps) { ComputationBuilder embed4_builder(client_, "embed4"); auto embed4_param = embed4_builder.Parameter(0, scalar_shape, "x"); - auto embed4_map_lhs = embed4_builder.Map({embed4_param}, embed2); - auto embed4_map_rhs = embed4_builder.Map({embed4_param}, embed3); + auto embed4_map_lhs = embed4_builder.Map({embed4_param}, embed2, {}); + auto embed4_map_rhs = embed4_builder.Map({embed4_param}, embed3, {}); auto embed4_add = embed4_builder.Add(embed4_map_lhs, embed4_map_rhs); auto embed4_status = embed4_builder.Build(); ASSERT_IS_OK(embed4_status.status()); @@ -331,8 +331,8 @@ XLA_TEST_F(MapTest, ComplexNestedMaps) { ComputationBuilder builder(client_, TestName()); auto constant_42 = builder.ConstantR0<float>(42.0); auto constant_7 = builder.ConstantR0<float>(7.0); - auto map_42 = builder.Map({constant_42}, embed5); - auto map_7 = builder.Map({constant_7}, embed4); + auto map_42 = builder.Map({constant_42}, embed5, {}); + auto map_7 = builder.Map({constant_7}, embed4, {}); builder.Add(map_42, map_7); ComputeAndCompareR0<float>(&builder, 73.0, {}, ErrorSpec(0.01f)); @@ -355,7 +355,7 @@ TEST_F(MapTest, VersionedEmbeddedComputation) { ComputationBuilder builder(client_, TestName()); auto constant_vector = builder.ConstantR1<float>({1.0, 2.0, 3.0, 4.0}); - auto map_plus_1 = builder.Map({constant_vector}, embedded_computation); + auto map_plus_1 = builder.Map({constant_vector}, embedded_computation, {0}); // Add another Add(1) operation to the existing embedded computation. This // requires using the stub interface because the ComputationBuilder does not @@ -371,7 +371,7 @@ TEST_F(MapTest, VersionedEmbeddedComputation) { tensorflow::Status s = client_->stub()->Op(&op_request, &response); ASSERT_TRUE(s.ok()); - auto map_plus_2 = builder.Map({map_plus_1}, embedded_computation); + auto map_plus_2 = builder.Map({map_plus_1}, embedded_computation, {0}); // The original vector has Add(1) applied to it with a map, followed by // Add(1+1) resulting in a net Add(3). @@ -393,8 +393,8 @@ TEST_F(MapTest, MapBinaryAdder) { auto param0 = builder.Parameter(0, param0_literal->shape(), "param0"); auto param1 = builder.Parameter(1, param1_literal->shape(), "param1"); - auto map = - builder.Map({param0, param1}, CreateScalarAddComputation(F32, &builder)); + auto map = builder.Map({param0, param1}, + CreateScalarAddComputation(F32, &builder), {0}); ComputeAndCompareR1<float>(&builder, {7.3f, 7.7, 4.3f, 0}, {param0_data.get(), param1_data.get()}, @@ -417,8 +417,8 @@ XLA_TEST_F(MapTest, AddWithMixedLayouts) { auto param0 = builder.Parameter(0, param0_literal->shape(), "param0"); auto param1 = builder.Parameter(1, param1_literal->shape(), "param1"); - auto map = - builder.Map({param0, param1}, CreateScalarAddComputation(S32, &builder)); + auto map = builder.Map({param0, param1}, + CreateScalarAddComputation(S32, &builder), {0, 1}); Array2D<int32> expected(2, 2); expected(0, 0) = 11; @@ -443,8 +443,8 @@ XLA_TEST_F(MapTest, AddR3_3x0x2) { auto param0 = builder.Parameter(0, param0_literal->shape(), "param0"); auto param1 = builder.Parameter(1, param1_literal->shape(), "param1"); - auto map = - builder.Map({param0, param1}, CreateScalarAddComputation(S32, &builder)); + auto map = builder.Map({param0, param1}, + CreateScalarAddComputation(S32, &builder), {0, 1, 2}); ComputeAndCompareR3<int32>(&builder, Array3D<int32>(3, 0, 2), {param0_data.get(), param1_data.get()}); @@ -469,7 +469,7 @@ TEST_F(MapTest, MapTernaryAdder) { auto param0 = builder.Parameter(0, param0_literal->shape(), "param0"); auto param1 = builder.Parameter(1, param1_literal->shape(), "param1"); auto param2 = builder.Parameter(2, param2_literal->shape(), "param2"); - auto map = builder.Map({param0, param1, param2}, CreateTernaryAdder()); + auto map = builder.Map({param0, param1, param2}, CreateTernaryAdder(), {0}); ComputeAndCompareR1<float>( &builder, {-2.7f, -92.3f, -895.7f, -400.0f}, @@ -481,7 +481,7 @@ TEST_F(MapTest, MapGt) { // Maps (x,y) -> x > y onto two R1F32 vectors. ComputationBuilder b(client_, TestName()); auto gt = CreateGt(); - b.Map({b.ConstantR1<float>({1, 20}), b.ConstantR1<float>({10, 2})}, gt); + b.Map({b.ConstantR1<float>({1, 20}), b.ConstantR1<float>({10, 2})}, gt, {0}); ComputeAndCompareR1<bool>(&b, {false, true}, {}); } @@ -491,14 +491,14 @@ TEST_F(MapTest, NestedBinaryMap) { // max_with_square(x) = do max(x, x^2) via a map. ComputationBuilder b(client_, "max_with_square"); auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - b.Map({x, b.Mul(x, x)}, CreateMax()); + b.Map({x, b.Mul(x, x)}, CreateMax(), {}); auto computation_status = b.Build(); ASSERT_IS_OK(computation_status.status()); max_with_square = computation_status.ConsumeValueOrDie(); } ComputationBuilder b(client_, TestName()); auto input = b.ConstantR1<float>({0.1f, 0.5f, -0.5f, 1.0f, 2.0f}); - b.Map({input}, max_with_square); + b.Map({input}, max_with_square, {0}); ComputeAndCompareR1<float>(&b, {0.1f, 0.5f, 0.25f, 1.0f, 4.0f}, {}); } @@ -525,7 +525,7 @@ TEST_F(MapTest, MapOperantionWithBuildError) { auto param0 = builder.Parameter(0, param0_literal->shape(), "param0"); auto param1 = builder.Parameter(1, param1_literal->shape(), "param1"); - auto map = builder.Map({param0, param1}, error_add); + auto map = builder.Map({param0, param1}, error_add, {0}); StatusOr<Computation> computation_status = builder.Build(); ASSERT_TRUE(!computation_status.ok()); @@ -562,7 +562,7 @@ TEST_F(MapTestWithFullOpt, MapScalarPower) { auto param0 = builder.Parameter(0, param0_literal->shape(), "param0"); auto param1 = builder.Parameter(1, param1_literal->shape(), "param1"); - builder.Map({param0, param1}, power); + builder.Map({param0, param1}, power, {}); ComputeAndCompareR0<float>(&builder, 32.0f, {param0_data.get(), param1_data.get()}, @@ -589,7 +589,7 @@ TEST_F(MapTestWithFullOpt, MapSubtractOppositeOrder) { auto param0 = builder.Parameter(0, param0_literal->shape(), "param0"); auto param1 = builder.Parameter(1, param1_literal->shape(), "param1"); - builder.Map({param0, param1}, sub_opposite); + builder.Map({param0, param1}, sub_opposite, {}); ComputeAndCompareR0<float>( &builder, 3.0f, {param0_data.get(), param1_data.get()}, ErrorSpec(0.01f)); @@ -610,7 +610,7 @@ TEST_F(MapTestWithFullOpt, MapSquare) { client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); auto param0 = builder.Parameter(0, param0_literal->shape(), "param0"); - builder.Map({param0}, square); + builder.Map({param0}, square, {}); ComputeAndCompareR0<float>(&builder, 100.0f, {param0_data.get()}, ErrorSpec(0.01f)); diff --git a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc index 4c33bb2c36..0fb87c3c2c 100644 --- a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc +++ b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc @@ -111,7 +111,7 @@ TEST_F(MatOpsSimpleTest, MapTwoByTwo) { {1.0, 0.0}, // row 0 {-1.0, 0.5}, // row 1 }); - auto map = builder.Map({data}, add_half); + auto map = builder.Map({data}, add_half, {0, 1}); std::unique_ptr<Literal> expected = Literal::CreateR2<float>({{1.5, 0.5}, // row 0 diff --git a/tensorflow/compiler/xla/tests/prng_test.cc b/tensorflow/compiler/xla/tests/prng_test.cc index 0f82291fea..209f063cc5 100644 --- a/tensorflow/compiler/xla/tests/prng_test.cc +++ b/tensorflow/compiler/xla/tests/prng_test.cc @@ -170,7 +170,7 @@ XLA_TEST_F(PrngTest, MapUsingRng) { auto param0 = builder.Parameter(0, param0_literal->shape(), "param0"); auto fn = build_sum_rng(builder); - builder.Map({param0}, fn); + builder.Map({param0}, fn, {0}); TF_ASSERT_OK_AND_ASSIGN(auto computation, builder.Build()); diff --git a/tensorflow/compiler/xla/tests/replay_test.cc b/tensorflow/compiler/xla/tests/replay_test.cc index 92efd2947d..6d063ffc36 100644 --- a/tensorflow/compiler/xla/tests/replay_test.cc +++ b/tensorflow/compiler/xla/tests/replay_test.cc @@ -117,7 +117,7 @@ TEST_F(ReplayTest, MapPlusTwoOverR1) { ComputationBuilder mapper_builder(client_, TestName()); auto original = mapper_builder.ConstantR1<int32>({1, 2, 3}); - mapper_builder.Map({original}, plus_two); + mapper_builder.Map({original}, plus_two, {0}); Computation computation = mapper_builder.Build().ConsumeValueOrDie(); diff --git a/tensorflow/compiler/xla/tests/tuple_test.cc b/tensorflow/compiler/xla/tests/tuple_test.cc index 5533778947..4920f17a7e 100644 --- a/tensorflow/compiler/xla/tests/tuple_test.cc +++ b/tensorflow/compiler/xla/tests/tuple_test.cc @@ -293,7 +293,7 @@ XLA_TEST_F(TupleTest, TuplesInAMap) { ComputationBuilder b(client_, TestName()); auto input = b.ConstantR1<float>({-1.0f, 1.0f, 2.1f}); - b.Map({input}, tuple_computation); + b.Map({input}, tuple_computation, {0}); ComputeAndCompareR1<float>(&b, {-99.0f, 101.0f, 214.41f}, {}, error_spec_); } diff --git a/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc b/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc index 48a85f16a2..b52c718814 100644 --- a/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc +++ b/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc @@ -195,7 +195,7 @@ XLA_TEST_F(VecOpsSimpleTest, AddTenValuesViaMap) { {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); auto y = builder.ConstantR1<float>( {-0.4, -0.6, -3.0, 0.2, 3.8, -2.2, -1.8, 4.9, 1.4, 0.6}); - auto max = builder.Map({x, y}, add); + auto max = builder.Map({x, y}, add, {0}); std::vector<float> expected = {1.7, -3.2, -0.4, -3.8, 5.9, 0.1, -6.8, 4., -1., 2.2}; @@ -385,8 +385,8 @@ XLA_TEST_F(VecOpsSimpleTest, MapTenValues) { auto two = builder.ConstantR0<float>(2.0); auto max = builder.Max(z_value, zero); auto mult = builder.Mul(two, max); - auto inner = builder.Map({mult}, add_half); - builder.Map({inner}, clamp); + auto inner = builder.Map({mult}, add_half, {}); + builder.Map({inner}, clamp, {}); auto computation_status = builder.Build(); ASSERT_IS_OK(computation_status.status()); mult_relu_add = computation_status.ConsumeValueOrDie(); @@ -396,7 +396,7 @@ XLA_TEST_F(VecOpsSimpleTest, MapTenValues) { { auto x = builder.ConstantR1<float>( {2.1, -21.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); - auto activations = builder.Map({x}, mult_relu_add); + auto activations = builder.Map({x}, mult_relu_add, {0}); } std::vector<float> expected = {4.7, 0.5, 5.0, 0.5, 4.7, diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index 3327e06ed8..1771a3d5de 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -459,6 +459,11 @@ message MapRequest { repeated ComputationDataHandle operands = 2; ComputationHandle to_apply = 3; repeated ComputationDataHandle static_operands = 4; + // The dimensions over which to map. + // Example mapping a Dot operation along the batch dimension 0: + // operand0.shape = [2, 2, 2], operand1.shape = [2,2,3] + // Map({operand0, operand1}, Dot, {0}) + repeated int64 dimensions = 5; } message ReduceRequest { diff --git a/tensorflow/docs_src/performance/xla/operation_semantics.md b/tensorflow/docs_src/performance/xla/operation_semantics.md index 9cb27c7e95..4420a207c4 100644 --- a/tensorflow/docs_src/performance/xla/operation_semantics.md +++ b/tensorflow/docs_src/performance/xla/operation_semantics.md @@ -844,6 +844,7 @@ See also : : : T_1, ..., T_{N + M -1} -> S` : : : : with N parameters of type T : : : : and M of arbitrary type : +| `dimensions` | `int64` array | array of map dimensions | | `static_operands` | sequence of M | M arrays of arbitrary type | : : `ComputationDataHandle`s : : |