-rw-r--r--  tensorflow/compiler/xla/client/computation_builder.cc        4
-rw-r--r--  tensorflow/compiler/xla/client/computation_builder.h         1
-rw-r--r--  tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc    4
-rw-r--r--  tensorflow/compiler/xla/service/hlo_verifier.cc              12
-rw-r--r--  tensorflow/compiler/xla/service/shape_inference.cc           21
-rw-r--r--  tensorflow/compiler/xla/service/shape_inference.h            3
-rw-r--r--  tensorflow/compiler/xla/service/shape_inference_test.cc      31
-rw-r--r--  tensorflow/compiler/xla/service/user_computation.cc          3
-rw-r--r--  tensorflow/compiler/xla/tests/convert_test.cc                4
-rw-r--r--  tensorflow/compiler/xla/tests/map_test.cc                    64
-rw-r--r--  tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc      2
-rw-r--r--  tensorflow/compiler/xla/tests/prng_test.cc                   2
-rw-r--r--  tensorflow/compiler/xla/tests/replay_test.cc                 2
-rw-r--r--  tensorflow/compiler/xla/tests/tuple_test.cc                  2
-rw-r--r--  tensorflow/compiler/xla/tests/vector_ops_simple_test.cc      8
-rw-r--r--  tensorflow/compiler/xla/xla_data.proto                       5
-rw-r--r--  tensorflow/docs_src/performance/xla/operation_semantics.md   1
17 files changed, 105 insertions(+), 64 deletions(-)
diff --git a/tensorflow/compiler/xla/client/computation_builder.cc b/tensorflow/compiler/xla/client/computation_builder.cc
index 210a4d95b9..a80412e951 100644
--- a/tensorflow/compiler/xla/client/computation_builder.cc
+++ b/tensorflow/compiler/xla/client/computation_builder.cc
@@ -1307,6 +1307,7 @@ StatusOr<std::unique_ptr<Literal>> ComputationBuilder::ComputeConstant(
ComputationDataHandle ComputationBuilder::Map(
tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,
const Computation& computation,
+ tensorflow::gtl::ArraySlice<int64> dimensions,
tensorflow::gtl::ArraySlice<ComputationDataHandle> static_operands) {
if (!first_error_.ok() || !PrepareComputation().ok()) {
return ComputationDataHandle();
@@ -1317,6 +1318,9 @@ ComputationDataHandle ComputationBuilder::Map(
*request.add_operands() = operand;
}
*request.mutable_to_apply() = computation.handle();
+ for (int64 dimension : dimensions) {
+ request.add_dimensions(dimension);
+ }
for (const ComputationDataHandle& sop : static_operands) {
*request.add_static_operands() = sop;
}
diff --git a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h
index b0e6720be2..73972c1290 100644
--- a/tensorflow/compiler/xla/client/computation_builder.h
+++ b/tensorflow/compiler/xla/client/computation_builder.h
@@ -604,6 +604,7 @@ class ComputationBuilder {
ComputationDataHandle Map(
tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,
const Computation& computation,
+ tensorflow::gtl::ArraySlice<int64> dimensions,
tensorflow::gtl::ArraySlice<ComputationDataHandle> static_operands = {});
// Enqueues a N(mu, sigma) random number generation instruction onto the
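With this header change, every caller of ComputationBuilder::Map must now name
the dimensions to map over explicitly. A minimal caller-side sketch (the
builder name, the client_ handle, and the add_one computation are illustrative
assumptions, not part of this commit):

    // Map a scalar (f32) -> f32 computation elementwise over a rank-1
    // operand. With the new signature the caller lists every dimension of
    // the operand: {0} for rank 1, {0, 1} for rank 2, {} for a scalar.
    ComputationBuilder builder(client_, "map_example");
    auto input = builder.Parameter(0, ShapeUtil::MakeShape(F32, {10}), "in");
    auto mapped = builder.Map({input}, add_one, /*dimensions=*/{0});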
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc
index 0a288a77ad..0eaa21ef25 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc
@@ -169,7 +169,7 @@ TEST_F(HloCostAnalysisTest, MatrixMultiply) {
TEST_F(HloCostAnalysisTest, Map) {
ComputationBuilder builder(client_, "map");
auto input = builder.Parameter(0, ShapeUtil::MakeShape(F32, {10}), "in");
- auto result = builder.Map({input}, add_and_exp_);
+ auto result = builder.Map({input}, add_and_exp_, {0});
// Run HLO cost analysis.
auto hlo_module = BuildHloGraph(&builder);
@@ -286,7 +286,7 @@ TEST_F(HloCostAnalysisTest, FullyConnectedForward) {
auto bias = builder.Parameter(2, ShapeUtil::MakeShape(F32, {20}), "bias");
// sigmoid(input * weight + bias)
auto result = builder.Map(
- {builder.Add(builder.Dot(input, weight), bias, {1})}, sigmoid_);
+ {builder.Add(builder.Dot(input, weight), bias, {1})}, sigmoid_, {0, 1});
// Run HLO cost analysis.
auto hlo_module = BuildHloGraph(&builder);
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index 2405d44778..c16747c02c 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -241,12 +241,20 @@ class ShapeVerifier : public DfsHloVisitor {
HloComputation* function,
tensorflow::gtl::ArraySlice<HloInstruction*> static_operands) override {
std::vector<const Shape*> operand_shapes;
+ int64 max_operand_rank = 0;
for (const HloInstruction* operand : operands) {
operand_shapes.push_back(&operand->shape());
+ max_operand_rank =
+ std::max(max_operand_rank, ShapeUtil::Rank(operand->shape()));
}
+ // TODO(b/65689298) Remove code below once Map is generalized to accept
+ // arbitrary map dimensions.
+ std::vector<int64> map_dims(max_operand_rank);
+ std::iota(map_dims.begin(), map_dims.end(), 0);
return CheckShape(
- map, ShapeInference::InferMapShape(
- operand_shapes, map->to_apply()->ComputeProgramShape()));
+ map,
+ ShapeInference::InferMapShape(
+ operand_shapes, map->to_apply()->ComputeProgramShape(), map_dims));
}
Status HandleReduceWindow(HloInstruction* reduce_window,
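Because HLO instructions do not yet carry map dimensions, the verifier
synthesizes the identity list {0, ..., max_operand_rank - 1} with std::iota
before calling InferMapShape. A self-contained sketch of that idiom (a
hypothetical helper, not code from this commit):

    #include <cstdint>
    #include <numeric>
    #include <vector>

    // Builds {0, 1, ..., rank - 1}: map over every dimension, which is the
    // only form InferMapShape accepts today.
    std::vector<int64_t> IdentityMapDims(int64_t rank) {
      std::vector<int64_t> map_dims(rank);
      std::iota(map_dims.begin(), map_dims.end(), 0);
      return map_dims;
    }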
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index 5178a750b9..23c8266e77 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -852,7 +852,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
/* static */ StatusOr<Shape> ShapeInference::InferMapShape(
tensorflow::gtl::ArraySlice<const Shape*> arg_shapes,
- const ProgramShape& to_apply) {
+ const ProgramShape& to_apply,
+ tensorflow::gtl::ArraySlice<int64> dimensions) {
if (arg_shapes.empty()) {
return InvalidArgument("Map expects at least one argument");
}
@@ -888,6 +889,24 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
tensorflow::str_util::Join(pieces, ", ").c_str());
}
+ // Check that dimensions.size == arg_shape.dimensions_size() (we currently
+ // only support mapping across all dimensions: i.e. scalar map functions).
+ if (dimensions.size() != arg_shape->dimensions_size()) {
+ return InvalidArgument(
+ "Map applied to a subset of dimensions currently not supported: "
+ "arg_dimension_size: %d, requested_map_dimensions_size: %zu",
+ arg_shape->dimensions_size(), dimensions.size());
+ }
+
+ // Check that the requested map dimension numbers are monotonically increasing.
+ for (int i = 0; i < dimensions.size(); ++i) {
+ if (dimensions[i] != i) {
+ return InvalidArgument(
+ "Map requires monotonically increasing dimension numbers, found: %s ",
+ tensorflow::str_util::Join(dimensions, ", ").c_str());
+ }
+ }
+
// The applied function's arity equals the number of arguments.
if (arg_shapes.size() != to_apply.parameters_size()) {
return InvalidArgument(
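Taken together, the two new checks pin `dimensions` to exactly
{0, 1, ..., rank - 1}: the size check rejects mapping over a subset of
dimensions, and the per-element check rejects permutations such as {1, 0}. A
self-contained sketch of the combined predicate (a hypothetical helper, not
part of this commit):

    #include <cstdint>
    #include <vector>

    // True iff `dimensions` is {0, 1, ..., rank - 1}, the only map
    // dimension list InferMapShape currently accepts.
    bool IsIdentityDimensionList(const std::vector<int64_t>& dimensions,
                                 int64_t rank) {
      if (static_cast<int64_t>(dimensions.size()) != rank) return false;
      for (int64_t i = 0; i < rank; ++i) {
        if (dimensions[i] != i) return false;
      }
      return true;
    }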
diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h
index 379feef5e4..d5d497176d 100644
--- a/tensorflow/compiler/xla/service/shape_inference.h
+++ b/tensorflow/compiler/xla/service/shape_inference.h
@@ -78,7 +78,8 @@ class ShapeInference {
// to the given operand shapes.
static StatusOr<Shape> InferMapShape(
tensorflow::gtl::ArraySlice<const Shape*> arg_shapes,
- const ProgramShape& to_apply);
+ const ProgramShape& to_apply,
+ tensorflow::gtl::ArraySlice<int64> dimensions);
// Infers the shape produced by InferBatchNormTraining with the given
// operands.
diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc
index 8c731ae297..7c9c7e8d6a 100644
--- a/tensorflow/compiler/xla/service/shape_inference_test.cc
+++ b/tensorflow/compiler/xla/service/shape_inference_test.cc
@@ -505,7 +505,7 @@ TEST_F(ShapeInferenceTest, ConvolveDimensionNumbersOverlapError) {
TEST_F(ShapeInferenceTest, MapThatChangesElementType) {
Shape arg = ShapeUtil::MakeShape(F32, {20});
ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_}, s32_);
- auto inferred_status = ShapeInference::InferMapShape({&arg}, to_apply);
+ auto inferred_status = ShapeInference::InferMapShape({&arg}, to_apply, {0});
EXPECT_IS_OK(inferred_status.status());
Shape expected = ShapeUtil::MakeShape(S32, {20});
EXPECT_TRUE(ShapeUtil::Equal(expected, inferred_status.ValueOrDie()));
@@ -514,91 +514,92 @@ TEST_F(ShapeInferenceTest, MapThatChangesElementType) {
TEST_F(ShapeInferenceTest, Map) {
auto inferred_status_r1f32 = ShapeInference::InferMapShape(
{&vector_32_, &vector_32_},
- ShapeUtil::MakeProgramShape({f32_, f32_}, f32_));
+ ShapeUtil::MakeProgramShape({f32_, f32_}, f32_), {0});
EXPECT_IS_OK(inferred_status_r1f32.status());
EXPECT_TRUE(ShapeUtil::Equal(vector_32_, inferred_status_r1f32.ValueOrDie()));
// It's OK to provide a single argument, as long as the applied arity matches
// (this degenerates to a Map).
auto inferred_status_r1f32_one = ShapeInference::InferMapShape(
- {&vector_32_}, ShapeUtil::MakeProgramShape({f32_}, f32_));
+ {&vector_32_}, ShapeUtil::MakeProgramShape({f32_}, f32_), {0});
EXPECT_IS_OK(inferred_status_r1f32_one.status());
EXPECT_TRUE(
ShapeUtil::Equal(vector_32_, inferred_status_r1f32_one.ValueOrDie()));
auto inferred_status_r2s32 = ShapeInference::InferMapShape(
{&s32matrix_64_64_, &s32matrix_64_64_, &s32matrix_64_64_},
- ShapeUtil::MakeProgramShape({s32_, s32_, s32_}, s32_));
+ ShapeUtil::MakeProgramShape({s32_, s32_, s32_}, s32_), {0, 1});
EXPECT_IS_OK(inferred_status_r2s32.status());
EXPECT_TRUE(
ShapeUtil::Equal(s32matrix_64_64_, inferred_status_r2s32.ValueOrDie()));
auto no_args_error = ShapeInference::InferMapShape(
- {}, ShapeUtil::MakeProgramShape({f32_, f32_}, f32_));
+ {}, ShapeUtil::MakeProgramShape({f32_, f32_}, f32_), {});
ASSERT_FALSE(no_args_error.ok());
ASSERT_THAT(no_args_error.status().error_message(),
HasSubstr("expects at least one argument"));
auto args_diff_shapes_error = ShapeInference::InferMapShape(
{&vector_32_, &vector_64_},
- ShapeUtil::MakeProgramShape({f32_, f32_}, f32_));
+ ShapeUtil::MakeProgramShape({f32_, f32_}, f32_), {0});
ASSERT_FALSE(args_diff_shapes_error.ok());
ASSERT_THAT(args_diff_shapes_error.status().error_message(),
HasSubstr("requires all operands to have the same shape"));
auto arity_error = ShapeInference::InferMapShape(
- {&vector_32_, &vector_32_}, ShapeUtil::MakeProgramShape({f32_}, f32_));
+ {&vector_32_, &vector_32_}, ShapeUtil::MakeProgramShape({f32_}, f32_),
+ {0});
ASSERT_FALSE(arity_error.ok());
ASSERT_THAT(arity_error.status().error_message(),
HasSubstr("function arity must match"));
auto output_shape_error = ShapeInference::InferMapShape(
{&vector_32_, &vector_32_},
- ShapeUtil::MakeProgramShape({f32_, f32_}, vector_32_));
+ ShapeUtil::MakeProgramShape({f32_, f32_}, vector_32_), {0});
ASSERT_FALSE(output_shape_error.ok());
ASSERT_THAT(output_shape_error.status().error_message(),
HasSubstr("result has to be a scalar"));
auto param_shape_error = ShapeInference::InferMapShape(
{&vector_32_, &vector_32_},
- ShapeUtil::MakeProgramShape({vector_32_, f32_}, f32_));
+ ShapeUtil::MakeProgramShape({vector_32_, f32_}, f32_), {0});
ASSERT_FALSE(param_shape_error.ok());
ASSERT_THAT(param_shape_error.status().error_message(),
HasSubstr("parameter has to be a scalar"));
auto param_element_type_error = ShapeInference::InferMapShape(
{&vector_32_, &vector_32_},
- ShapeUtil::MakeProgramShape({f32_, s32_}, f32_));
+ ShapeUtil::MakeProgramShape({f32_, s32_}, f32_), {0});
ASSERT_FALSE(param_element_type_error.ok());
ASSERT_THAT(param_element_type_error.status().error_message(),
HasSubstr("parameter type has to match argument"));
Shape arg = ShapeUtil::MakeShape(F32, {20});
ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_}, f32_);
- auto inferred_status = ShapeInference::InferMapShape({&arg}, to_apply);
+ auto inferred_status = ShapeInference::InferMapShape({&arg}, to_apply, {0});
EXPECT_IS_OK(inferred_status.status());
EXPECT_TRUE(ShapeUtil::Equal(arg, inferred_status.ValueOrDie()));
auto inferred_status_error1 = ShapeInference::InferMapShape(
- {&arg}, ShapeUtil::MakeProgramShape({f32_, f32_}, f32_));
+ {&arg}, ShapeUtil::MakeProgramShape({f32_, f32_}, f32_), {0});
ASSERT_FALSE(inferred_status_error1.ok());
ASSERT_THAT(inferred_status_error1.status().error_message(),
HasSubstr("arity must match number of arguments"));
auto inferred_status_error2 = ShapeInference::InferMapShape(
- {&arg}, ShapeUtil::MakeProgramShape({vector_32_}, f32_));
+ {&arg}, ShapeUtil::MakeProgramShape({vector_32_}, f32_), {0});
ASSERT_FALSE(inferred_status_error2.ok());
ASSERT_THAT(inferred_status_error2.status().error_message(),
HasSubstr("has to be a scalar"));
auto inferred_status_error3 = ShapeInference::InferMapShape(
- {&arg}, ShapeUtil::MakeProgramShape({f32_}, vector_32_));
+ {&arg}, ShapeUtil::MakeProgramShape({f32_}, vector_32_), {0});
ASSERT_FALSE(inferred_status_error3.ok());
ASSERT_THAT(inferred_status_error3.status().error_message(),
HasSubstr("has to be a scalar"));
auto inferred_status_error5 = ShapeInference::InferMapShape(
- {&arg}, ShapeUtil::MakeProgramShape({s32_}, s32_));
+ {&arg}, ShapeUtil::MakeProgramShape({s32_}, s32_), {0});
ASSERT_FALSE(inferred_status_error5.ok());
ASSERT_THAT(inferred_status_error5.status().error_message(),
HasSubstr("parameter type has to match argument"));
diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc
index ac7c31bf68..6bdd9978fe 100644
--- a/tensorflow/compiler/xla/service/user_computation.cc
+++ b/tensorflow/compiler/xla/service/user_computation.cc
@@ -421,7 +421,8 @@ StatusOr<ComputationDataHandle> UserComputation::AddMapInstruction(
to_apply_computation.ComputeProgramShape(to_apply_version));
TF_ASSIGN_OR_RETURN(
Shape inferred_shape,
- ShapeInference::InferMapShape(operand_shapes, *to_apply_program_shape));
+ ShapeInference::InferMapShape(operand_shapes, *to_apply_program_shape,
+ map_request.dimensions()));
ComputationDataHandle handle = CreateComputationDataHandle();
diff --git a/tensorflow/compiler/xla/tests/convert_test.cc b/tensorflow/compiler/xla/tests/convert_test.cc
index 12b5e8426a..f66e3b57bf 100644
--- a/tensorflow/compiler/xla/tests/convert_test.cc
+++ b/tensorflow/compiler/xla/tests/convert_test.cc
@@ -176,7 +176,7 @@ TEST_F(ConvertTest, ConvertMapToS32) {
auto param = b->Parameter(0, ShapeUtil::MakeShape(F32, {}), "in");
b->ConvertElementType(param, S32);
auto a = builder.ConstantR1<float>({42.0f, 64.0f});
- builder.Map({a}, b->BuildAndNoteError());
+ builder.Map({a}, b->BuildAndNoteError(), {0});
std::vector<int32> expected = {42, 64};
ComputeAndCompareR1<int32>(&builder, expected, {});
@@ -188,7 +188,7 @@ TEST_F(ConvertTest, ConvertMapToF32) {
auto param = b->Parameter(0, ShapeUtil::MakeShape(S32, {}), "in");
b->ConvertElementType(param, F32);
auto a = builder.ConstantR1<int32>({42, 64});
- builder.Map({a}, b->BuildAndNoteError());
+ builder.Map({a}, b->BuildAndNoteError(), {0});
std::vector<float> expected = {42.0f, 64.0f};
ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
diff --git a/tensorflow/compiler/xla/tests/map_test.cc b/tensorflow/compiler/xla/tests/map_test.cc
index 01ee421baa..2ef392508d 100644
--- a/tensorflow/compiler/xla/tests/map_test.cc
+++ b/tensorflow/compiler/xla/tests/map_test.cc
@@ -125,7 +125,7 @@ class MapTest : public ClientLibraryTestBase {
Computation CreateMapPlusN(const Computation& embedded_computation, float n) {
ComputationBuilder builder(client_, TestName());
auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
- auto map = builder.Map({x}, embedded_computation);
+ auto map = builder.Map({x}, embedded_computation, {});
auto constant_n = builder.ConstantR0<float>(n);
auto add = builder.Add(map, constant_n);
auto computation_status = builder.Build();
@@ -173,7 +173,7 @@ TEST_F(MapTest, MapEachElemPlusOneR0) {
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
auto param = builder.Parameter(0, param0_literal->shape(), "param0");
- auto map = builder.Map({param}, CreateAdderToOne());
+ auto map = builder.Map({param}, CreateAdderToOne(), {});
ComputeAndCompareR0<float>(&builder, 43.0, {param0_data.get()},
ErrorSpec(0.01f));
@@ -187,7 +187,7 @@ XLA_TEST_F(MapTest, MapEachElemPlusOneR1S0) {
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
auto param = builder.Parameter(0, param0_literal->shape(), "param0");
- auto map = builder.Map({param}, CreateAdderToOne());
+ auto map = builder.Map({param}, CreateAdderToOne(), {0});
ComputeAndCompareR1<float>(&builder, {}, {param0_data.get()},
ErrorSpec(0.01f));
@@ -202,7 +202,7 @@ TEST_F(MapTest, MapEachElemPlusOneR1S4) {
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
auto param = builder.Parameter(0, param0_literal->shape(), "param0");
- auto map = builder.Map({param}, CreateAdderToOne());
+ auto map = builder.Map({param}, CreateAdderToOne(), {0});
ComputeAndCompareR1<float>(&builder, {3.2f, 4.3f, 5.4f, 6.5f},
{param0_data.get()}, ErrorSpec(0.01f));
@@ -216,7 +216,7 @@ TEST_F(MapTest, MapEachF32ElementToS32Constant) {
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
auto param = builder.Parameter(0, param0_literal->shape(), "param0");
- auto map = builder.Map({param}, CreateScalarOne<int32>());
+ auto map = builder.Map({param}, CreateScalarOne<int32>(), {0});
ComputeAndCompareR1<int32>(&builder, {1, 1, 1, 1}, {param0_data.get()});
}
@@ -229,7 +229,7 @@ TEST_F(MapTest, MapEachF32ElementToU32Constant) {
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
auto param = builder.Parameter(0, param0_literal->shape(), "param0");
- auto map = builder.Map({param}, CreateScalarOne<uint32>());
+ auto map = builder.Map({param}, CreateScalarOne<uint32>(), {0});
ComputeAndCompareR1<uint32>(&builder, {1, 1, 1, 1}, {param0_data.get()});
}
@@ -243,7 +243,7 @@ TEST_F(MapTest, MapEachElemLongerChainR1) {
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
auto param = builder.Parameter(0, param0_literal->shape(), "param0");
- auto map = builder.Map({param}, CreateAdderToOneTimesItself());
+ auto map = builder.Map({param}, CreateAdderToOneTimesItself(), {0});
ComputeAndCompareR1<float>(
&builder, {9.36f, 20.91f, 0.11f, 0.24f, 999000.0f, 65535.75f},
@@ -259,8 +259,8 @@ XLA_TEST_F(MapTest, MapMultipleMapsR1S0) {
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
auto param = builder.Parameter(0, param0_literal->shape(), "param0");
- auto map1 = builder.Map({param}, CreateAdderToOne());
- auto map2 = builder.Map({map1}, CreateMulByTwo());
+ auto map1 = builder.Map({param}, CreateAdderToOne(), {0});
+ auto map2 = builder.Map({map1}, CreateMulByTwo(), {0});
ComputeAndCompareR1<float>(&builder, {}, {param0_data.get()},
ErrorSpec(0.01f));
@@ -276,8 +276,8 @@ TEST_F(MapTest, MapMultipleMapsR1S4) {
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
auto param = builder.Parameter(0, param0_literal->shape(), "param0");
- auto map1 = builder.Map({param}, CreateAdderToOne());
- auto map2 = builder.Map({map1}, CreateMulByTwo());
+ auto map1 = builder.Map({param}, CreateAdderToOne(), {0});
+ auto map2 = builder.Map({map1}, CreateMulByTwo(), {0});
ComputeAndCompareR1<float>(&builder, {6.4f, 8.6f, 10.8f, 13.0f},
{param0_data.get()}, ErrorSpec(0.01f));
@@ -292,7 +292,7 @@ TEST_F(MapTest, MapEachElemPlusOneR2) {
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
auto param = builder.Parameter(0, param0_literal->shape(), "param0");
- auto map = builder.Map({param}, CreateAdderToOne());
+ auto map = builder.Map({param}, CreateAdderToOne(), {0, 1});
Array2D<float> expected_array(
{{14.25f, 15.0f}, {-6.1f, -6.2f}, {-7.8f, 9.8f}});
@@ -319,8 +319,8 @@ XLA_TEST_F(MapTest, ComplexNestedMaps) {
ComputationBuilder embed4_builder(client_, "embed4");
auto embed4_param = embed4_builder.Parameter(0, scalar_shape, "x");
- auto embed4_map_lhs = embed4_builder.Map({embed4_param}, embed2);
- auto embed4_map_rhs = embed4_builder.Map({embed4_param}, embed3);
+ auto embed4_map_lhs = embed4_builder.Map({embed4_param}, embed2, {});
+ auto embed4_map_rhs = embed4_builder.Map({embed4_param}, embed3, {});
auto embed4_add = embed4_builder.Add(embed4_map_lhs, embed4_map_rhs);
auto embed4_status = embed4_builder.Build();
ASSERT_IS_OK(embed4_status.status());
@@ -331,8 +331,8 @@ XLA_TEST_F(MapTest, ComplexNestedMaps) {
ComputationBuilder builder(client_, TestName());
auto constant_42 = builder.ConstantR0<float>(42.0);
auto constant_7 = builder.ConstantR0<float>(7.0);
- auto map_42 = builder.Map({constant_42}, embed5);
- auto map_7 = builder.Map({constant_7}, embed4);
+ auto map_42 = builder.Map({constant_42}, embed5, {});
+ auto map_7 = builder.Map({constant_7}, embed4, {});
builder.Add(map_42, map_7);
ComputeAndCompareR0<float>(&builder, 73.0, {}, ErrorSpec(0.01f));
@@ -355,7 +355,7 @@ TEST_F(MapTest, VersionedEmbeddedComputation) {
ComputationBuilder builder(client_, TestName());
auto constant_vector = builder.ConstantR1<float>({1.0, 2.0, 3.0, 4.0});
- auto map_plus_1 = builder.Map({constant_vector}, embedded_computation);
+ auto map_plus_1 = builder.Map({constant_vector}, embedded_computation, {0});
// Add another Add(1) operation to the existing embedded computation. This
// requires using the stub interface because the ComputationBuilder does not
@@ -371,7 +371,7 @@ TEST_F(MapTest, VersionedEmbeddedComputation) {
tensorflow::Status s = client_->stub()->Op(&op_request, &response);
ASSERT_TRUE(s.ok());
- auto map_plus_2 = builder.Map({map_plus_1}, embedded_computation);
+ auto map_plus_2 = builder.Map({map_plus_1}, embedded_computation, {0});
// The original vector has Add(1) applied to it with a map, followed by
// Add(1+1) resulting in a net Add(3).
@@ -393,8 +393,8 @@ TEST_F(MapTest, MapBinaryAdder) {
auto param0 = builder.Parameter(0, param0_literal->shape(), "param0");
auto param1 = builder.Parameter(1, param1_literal->shape(), "param1");
- auto map =
- builder.Map({param0, param1}, CreateScalarAddComputation(F32, &builder));
+ auto map = builder.Map({param0, param1},
+ CreateScalarAddComputation(F32, &builder), {0});
ComputeAndCompareR1<float>(&builder, {7.3f, 7.7, 4.3f, 0},
{param0_data.get(), param1_data.get()},
@@ -417,8 +417,8 @@ XLA_TEST_F(MapTest, AddWithMixedLayouts) {
auto param0 = builder.Parameter(0, param0_literal->shape(), "param0");
auto param1 = builder.Parameter(1, param1_literal->shape(), "param1");
- auto map =
- builder.Map({param0, param1}, CreateScalarAddComputation(S32, &builder));
+ auto map = builder.Map({param0, param1},
+ CreateScalarAddComputation(S32, &builder), {0, 1});
Array2D<int32> expected(2, 2);
expected(0, 0) = 11;
@@ -443,8 +443,8 @@ XLA_TEST_F(MapTest, AddR3_3x0x2) {
auto param0 = builder.Parameter(0, param0_literal->shape(), "param0");
auto param1 = builder.Parameter(1, param1_literal->shape(), "param1");
- auto map =
- builder.Map({param0, param1}, CreateScalarAddComputation(S32, &builder));
+ auto map = builder.Map({param0, param1},
+ CreateScalarAddComputation(S32, &builder), {0, 1, 2});
ComputeAndCompareR3<int32>(&builder, Array3D<int32>(3, 0, 2),
{param0_data.get(), param1_data.get()});
@@ -469,7 +469,7 @@ TEST_F(MapTest, MapTernaryAdder) {
auto param0 = builder.Parameter(0, param0_literal->shape(), "param0");
auto param1 = builder.Parameter(1, param1_literal->shape(), "param1");
auto param2 = builder.Parameter(2, param2_literal->shape(), "param2");
- auto map = builder.Map({param0, param1, param2}, CreateTernaryAdder());
+ auto map = builder.Map({param0, param1, param2}, CreateTernaryAdder(), {0});
ComputeAndCompareR1<float>(
&builder, {-2.7f, -92.3f, -895.7f, -400.0f},
@@ -481,7 +481,7 @@ TEST_F(MapTest, MapGt) {
// Maps (x,y) -> x > y onto two R1F32 vectors.
ComputationBuilder b(client_, TestName());
auto gt = CreateGt();
- b.Map({b.ConstantR1<float>({1, 20}), b.ConstantR1<float>({10, 2})}, gt);
+ b.Map({b.ConstantR1<float>({1, 20}), b.ConstantR1<float>({10, 2})}, gt, {0});
ComputeAndCompareR1<bool>(&b, {false, true}, {});
}
@@ -491,14 +491,14 @@ TEST_F(MapTest, NestedBinaryMap) {
// max_with_square(x) = do max(x, x^2) via a map.
ComputationBuilder b(client_, "max_with_square");
auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
- b.Map({x, b.Mul(x, x)}, CreateMax());
+ b.Map({x, b.Mul(x, x)}, CreateMax(), {});
auto computation_status = b.Build();
ASSERT_IS_OK(computation_status.status());
max_with_square = computation_status.ConsumeValueOrDie();
}
ComputationBuilder b(client_, TestName());
auto input = b.ConstantR1<float>({0.1f, 0.5f, -0.5f, 1.0f, 2.0f});
- b.Map({input}, max_with_square);
+ b.Map({input}, max_with_square, {0});
ComputeAndCompareR1<float>(&b, {0.1f, 0.5f, 0.25f, 1.0f, 4.0f}, {});
}
@@ -525,7 +525,7 @@ TEST_F(MapTest, MapOperantionWithBuildError) {
auto param0 = builder.Parameter(0, param0_literal->shape(), "param0");
auto param1 = builder.Parameter(1, param1_literal->shape(), "param1");
- auto map = builder.Map({param0, param1}, error_add);
+ auto map = builder.Map({param0, param1}, error_add, {0});
StatusOr<Computation> computation_status = builder.Build();
ASSERT_TRUE(!computation_status.ok());
@@ -562,7 +562,7 @@ TEST_F(MapTestWithFullOpt, MapScalarPower) {
auto param0 = builder.Parameter(0, param0_literal->shape(), "param0");
auto param1 = builder.Parameter(1, param1_literal->shape(), "param1");
- builder.Map({param0, param1}, power);
+ builder.Map({param0, param1}, power, {});
ComputeAndCompareR0<float>(&builder, 32.0f,
{param0_data.get(), param1_data.get()},
@@ -589,7 +589,7 @@ TEST_F(MapTestWithFullOpt, MapSubtractOppositeOrder) {
auto param0 = builder.Parameter(0, param0_literal->shape(), "param0");
auto param1 = builder.Parameter(1, param1_literal->shape(), "param1");
- builder.Map({param0, param1}, sub_opposite);
+ builder.Map({param0, param1}, sub_opposite, {});
ComputeAndCompareR0<float>(
&builder, 3.0f, {param0_data.get(), param1_data.get()}, ErrorSpec(0.01f));
@@ -610,7 +610,7 @@ TEST_F(MapTestWithFullOpt, MapSquare) {
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
auto param0 = builder.Parameter(0, param0_literal->shape(), "param0");
- builder.Map({param0}, square);
+ builder.Map({param0}, square, {});
ComputeAndCompareR0<float>(&builder, 100.0f, {param0_data.get()},
ErrorSpec(0.01f));
diff --git a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc
index 4c33bb2c36..0fb87c3c2c 100644
--- a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc
+++ b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc
@@ -111,7 +111,7 @@ TEST_F(MatOpsSimpleTest, MapTwoByTwo) {
{1.0, 0.0}, // row 0
{-1.0, 0.5}, // row 1
});
- auto map = builder.Map({data}, add_half);
+ auto map = builder.Map({data}, add_half, {0, 1});
std::unique_ptr<Literal> expected =
Literal::CreateR2<float>({{1.5, 0.5}, // row 0
diff --git a/tensorflow/compiler/xla/tests/prng_test.cc b/tensorflow/compiler/xla/tests/prng_test.cc
index 0f82291fea..209f063cc5 100644
--- a/tensorflow/compiler/xla/tests/prng_test.cc
+++ b/tensorflow/compiler/xla/tests/prng_test.cc
@@ -170,7 +170,7 @@ XLA_TEST_F(PrngTest, MapUsingRng) {
auto param0 = builder.Parameter(0, param0_literal->shape(), "param0");
auto fn = build_sum_rng(builder);
- builder.Map({param0}, fn);
+ builder.Map({param0}, fn, {0});
TF_ASSERT_OK_AND_ASSIGN(auto computation, builder.Build());
diff --git a/tensorflow/compiler/xla/tests/replay_test.cc b/tensorflow/compiler/xla/tests/replay_test.cc
index 92efd2947d..6d063ffc36 100644
--- a/tensorflow/compiler/xla/tests/replay_test.cc
+++ b/tensorflow/compiler/xla/tests/replay_test.cc
@@ -117,7 +117,7 @@ TEST_F(ReplayTest, MapPlusTwoOverR1) {
ComputationBuilder mapper_builder(client_, TestName());
auto original = mapper_builder.ConstantR1<int32>({1, 2, 3});
- mapper_builder.Map({original}, plus_two);
+ mapper_builder.Map({original}, plus_two, {0});
Computation computation = mapper_builder.Build().ConsumeValueOrDie();
diff --git a/tensorflow/compiler/xla/tests/tuple_test.cc b/tensorflow/compiler/xla/tests/tuple_test.cc
index 5533778947..4920f17a7e 100644
--- a/tensorflow/compiler/xla/tests/tuple_test.cc
+++ b/tensorflow/compiler/xla/tests/tuple_test.cc
@@ -293,7 +293,7 @@ XLA_TEST_F(TupleTest, TuplesInAMap) {
ComputationBuilder b(client_, TestName());
auto input = b.ConstantR1<float>({-1.0f, 1.0f, 2.1f});
- b.Map({input}, tuple_computation);
+ b.Map({input}, tuple_computation, {0});
ComputeAndCompareR1<float>(&b, {-99.0f, 101.0f, 214.41f}, {}, error_spec_);
}
diff --git a/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc b/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc
index 48a85f16a2..b52c718814 100644
--- a/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc
+++ b/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc
@@ -195,7 +195,7 @@ XLA_TEST_F(VecOpsSimpleTest, AddTenValuesViaMap) {
{2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
auto y = builder.ConstantR1<float>(
{-0.4, -0.6, -3.0, 0.2, 3.8, -2.2, -1.8, 4.9, 1.4, 0.6});
- auto max = builder.Map({x, y}, add);
+ auto max = builder.Map({x, y}, add, {0});
std::vector<float> expected = {1.7, -3.2, -0.4, -3.8, 5.9,
0.1, -6.8, 4., -1., 2.2};
@@ -385,8 +385,8 @@ XLA_TEST_F(VecOpsSimpleTest, MapTenValues) {
auto two = builder.ConstantR0<float>(2.0);
auto max = builder.Max(z_value, zero);
auto mult = builder.Mul(two, max);
- auto inner = builder.Map({mult}, add_half);
- builder.Map({inner}, clamp);
+ auto inner = builder.Map({mult}, add_half, {});
+ builder.Map({inner}, clamp, {});
auto computation_status = builder.Build();
ASSERT_IS_OK(computation_status.status());
mult_relu_add = computation_status.ConsumeValueOrDie();
@@ -396,7 +396,7 @@ XLA_TEST_F(VecOpsSimpleTest, MapTenValues) {
{
auto x = builder.ConstantR1<float>(
{2.1, -21.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
- auto activations = builder.Map({x}, mult_relu_add);
+ auto activations = builder.Map({x}, mult_relu_add, {0});
}
std::vector<float> expected = {4.7, 0.5, 5.0, 0.5, 4.7,
diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto
index 3327e06ed8..1771a3d5de 100644
--- a/tensorflow/compiler/xla/xla_data.proto
+++ b/tensorflow/compiler/xla/xla_data.proto
@@ -459,6 +459,11 @@ message MapRequest {
repeated ComputationDataHandle operands = 2;
ComputationHandle to_apply = 3;
repeated ComputationDataHandle static_operands = 4;
+ // The dimensions over which to map.
+ // Example: mapping a Dot operation along the batch dimension 0:
+ //   operand0.shape = [2, 2, 2], operand1.shape = [2, 2, 3]
+ //   Map({operand0, operand1}, Dot, {0})
+ repeated int64 dimensions = 5;
}
message ReduceRequest {
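The proto comment sketches where the new field is headed. On the client side
that example would read roughly as below (dot_computation and the operand
handles are illustrative; note that shape inference still rejects non-scalar
map functions, per the TODO(b/65689298) above):

    // operand0: f32[2, 2, 2], operand1: f32[2, 2, 3]. Mapping a
    // (f32[2, 2], f32[2, 3]) -> f32[2, 3] computation over dimension 0
    // alone would describe a batched Dot; today only a dimension list
    // covering every operand dimension is accepted.
    auto batched = builder.Map({operand0, operand1}, dot_computation,
                               /*dimensions=*/{0});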
diff --git a/tensorflow/docs_src/performance/xla/operation_semantics.md b/tensorflow/docs_src/performance/xla/operation_semantics.md
index 9cb27c7e95..4420a207c4 100644
--- a/tensorflow/docs_src/performance/xla/operation_semantics.md
+++ b/tensorflow/docs_src/performance/xla/operation_semantics.md
@@ -844,6 +844,7 @@ See also
: : : T_1, ..., T_{N + M -1} -> S` :
: : : with N parameters of type T :
: : : and M of arbitrary type :
+| `dimensions` | `int64` array | array of map dimensions |
| `static_operands` | sequence of M | M arrays of arbitrary type |
: : `ComputationDataHandle`s : :
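Read with the table above, `dimensions` lists the operand dimensions the
scalar computation is mapped across; until Map is generalized it must cover
all of them. A small worked example using a helper that appears in the tests
in this commit (the builder and concrete values are illustrative):

    // a = {1, 2, 3}, b = {10, 20, 30}, both f32[3].
    // Map({a, b}, add, {0}) applies add elementwise:
    //   result = {11, 22, 33}
    auto a = builder.ConstantR1<float>({1, 2, 3});
    auto b = builder.ConstantR1<float>({10, 20, 30});
    auto add = CreateScalarAddComputation(F32, &builder);
    auto result = builder.Map({a, b}, add, /*dimensions=*/{0});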