From e57874169fca3cfdd15cf0dda3717a6374a7dcb9 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Wed, 3 Oct 2018 23:03:11 -0700
Subject: [XLA] Update Tf2Xla bridge to use Scatter HLO.

PiperOrigin-RevId: 215687800
---
 tensorflow/compiler/tf2xla/lib/scatter.cc       | 213 +++++++++++++-----------
 tensorflow/compiler/tf2xla/lib/scatter.h        |   6 +-
 tensorflow/compiler/xla/client/xla_builder.cc   |   3 +
 tensorflow/compiler/xla/service/hlo_module.cc   |   3 +-
 tensorflow/compiler/xla/service/inliner.cc      |  32 ++--
 tensorflow/compiler/xla/service/inliner_test.cc |  30 ++++
 6 files changed, 177 insertions(+), 110 deletions(-)

(limited to 'tensorflow/compiler')

diff --git a/tensorflow/compiler/tf2xla/lib/scatter.cc b/tensorflow/compiler/tf2xla/lib/scatter.cc
index 38dfde165d..2b1c2ced92 100644
--- a/tensorflow/compiler/tf2xla/lib/scatter.cc
+++ b/tensorflow/compiler/tf2xla/lib/scatter.cc
@@ -38,12 +38,10 @@ xla::StatusOr<xla::XlaOp> XlaScatter(
         combiner,
     xla::XlaBuilder* builder) {
   TF_ASSIGN_OR_RETURN(xla::Shape buffer_shape, builder->GetShape(buffer));
-  TF_RETURN_IF_ERROR(builder->GetShape(updates).status());
+  TF_ASSIGN_OR_RETURN(xla::Shape updates_shape, builder->GetShape(updates));
   TF_ASSIGN_OR_RETURN(xla::Shape indices_shape, builder->GetShape(indices));
   absl::Span<const int64> indices_dims =
       xla::AsInt64Slice(indices_shape.dimensions());
-  absl::Span<const int64> buffer_dims =
-      xla::AsInt64Slice(buffer_shape.dimensions());
 
   // If the indices are N-dimensional, the minor dimension of indices contains
   // the indices to update. Otherwise the indices are all scalars.
@@ -81,104 +79,129 @@ xla::StatusOr<xla::XlaOp> XlaScatter(
     }
   }
 
-  // Shape of the non-indexed dimensions of the buffer.
-  std::vector<int64> buffer_shape_post_axes(
-      buffer_dims.begin() + num_index_dims, buffer_dims.end());
-
-  // Flatten the major dimensions of indices and updates into a single dimension
-  // for ease of iteration.
-  std::vector<int64> flat_indices_shape({num_indices});
-  if (indices_are_vectors) {
-    flat_indices_shape.push_back(num_index_dims);
+  // Example of a 1-D scatter that updates two [3,1] tensors in a tensor of
+  // shape [3,3]:
+  // NOTE: ***This case will not be generated by any of the tf.scatter ops.***
+  //
+  // operand = s32[3,3] parameter(0)
+  // indices = s32[2] parameter(1)
+  // updates = s32[3,2] parameter(2)
+  // scatter = s32[3,3] scatter(operand, indices, updates),
+  //     to_apply=update_computation,
+  //     update_window_dims={0},
+  //     inserted_window_dims={1},
+  //     scatter_dims_to_operand_dims={1},
+  //     index_vector_dim=1
+  //
+  //
+  // Example of a 1-D scatter that updates two [1,3] tensors in a tensor of
+  // shape [3,3]:
+  //
+  // operand = s32[3,3] parameter(0)
+  // indices = s32[2] parameter(1)
+  // updates = s32[2,3] parameter(2)
+  // scatter = s32[3,3] scatter(operand, indices, updates),
+  //     to_apply=update_computation,
+  //     update_window_dims={1},
+  //     inserted_window_dims={0},
+  //     scatter_dims_to_operand_dims={0},
+  //     index_vector_dim=1
+  //
+  //
+  // Example of an N-D scatter updating slices of shape [1,1,2] in a tensor of
+  // shape [3,3,2]:
+  //
+  // operand = s32[3,3,2] parameter(0)
+  // indices = s32[2,2] parameter(1)
+  // updates = s32[2,2] parameter(2)
+  // scatter = s32[3,3,2] scatter(operand, indices, updates),
+  //     to_apply=update_computation,
+  //     update_window_dims={1},
+  //     inserted_window_dims={0,1},
+  //     scatter_dims_to_operand_dims={0,1},
+  //     index_vector_dim=1
+  //
+  //
+  // Example of a scatter updating slices of shape [] in a tensor of shape
+  // [1,1]:
+  //
+  // operand = s32[1,1] parameter(0)
+  // indices = s32[1] parameter(1)
+  // updates = s32[1] parameter(2)
+  // scatter = s32[1,1] scatter(operand, indices, updates),
+  //     to_apply=update_computation,
+  //     update_window_dims={},
+  //     inserted_window_dims={0,1},
+  //     scatter_dims_to_operand_dims={0},
+  //     index_vector_dim=1
+  // Note that the updates operand would be broadcast to [1] in this case.
+  //
+
+  xla::ScatterDimensionNumbers dim_numbers;
+  dim_numbers.set_index_vector_dim(indices_are_vectors
+                                       ? indices_shape.dimensions_size() - 1
+                                       : indices_shape.dimensions_size());
+
+  int64 updates_rank = xla::ShapeUtil::Rank(updates_shape);
+  int64 buffer_rank = xla::ShapeUtil::Rank(buffer_shape);
+  int64 num_window_dims_in_updates = buffer_rank - num_index_dims;
+
+  // If the rank of `updates` is 0 but a non-scalar update is expected,
+  // broadcast `updates` to the expected shape of the updates.
+  auto new_updates = updates;
+  std::vector<int64> expected_updates_dims(indices_dims.begin(),
+                                           indices_dims.end());
+  for (int64 dim = num_index_dims; dim < buffer_rank; ++dim) {
+    expected_updates_dims.push_back(buffer_shape.dimensions(dim));
+  }
+  int64 expected_updates_rank = expected_updates_dims.size();
+  if (updates_rank == 0 && expected_updates_rank != 0) {
+    new_updates = xla::Broadcast(updates, expected_updates_dims);
+    TF_ASSIGN_OR_RETURN(updates_shape, builder->GetShape(new_updates));
+    updates_rank = xla::ShapeUtil::Rank(updates_shape);
   }
 
-  std::vector<int64> flat_updates_shape({num_indices});
-  flat_updates_shape.insert(flat_updates_shape.end(),
-                            buffer_shape_post_axes.begin(),
-                            buffer_shape_post_axes.end());
-
-  // Construct the initial values of the loop-carried Tensors.
-  auto flat_indices = xla::Reshape(indices, flat_indices_shape);
-  auto flat_updates = xla::Reshape(updates, flat_updates_shape);
-  auto init = {flat_indices, flat_updates, buffer};
-
-  // Constructs the loop body. The implementation of scatter is essentially:
-  //   for i in range(num_indices):
-  //     index = dynamic-slice(indices, i)
-  //     update = dynamic-slice(updates, i)
-  //     buffer = dynamic-update-slice(buffer, update, index)
-  auto body_fn = [&](xla::XlaOp i, absl::Span<const xla::XlaOp> loop_vars,
-                     xla::XlaBuilder* body_builder) {
-    auto indices = loop_vars[0];
-    auto updates = loop_vars[1];
-    auto buffer = loop_vars[2];
-
-    auto zero_index = xla::ConstantLiteral(
-        body_builder, xla::LiteralUtil::Zero(indices_shape.element_type()));
-
-    // Slice the i-th index from the indices array.
-    xla::XlaOp index;
-    auto indices_offset = xla::Reshape(i, {1});
-    if (indices_are_vectors) {
-      indices_offset = xla::Pad(indices_offset, zero_index,
-                                xla::MakeEdgePaddingConfig({{0, 1}}));
-
-      index = xla::DynamicSlice(indices, indices_offset, {1, num_index_dims});
-      index = xla::Collapse(index, {0, 1});
-    } else {
-      index = xla::DynamicSlice(indices, indices_offset, {1});
+  if (updates_rank > 0) {
+    for (int64 i = (updates_rank - num_window_dims_in_updates);
+         i < updates_rank; ++i) {
+      dim_numbers.add_update_window_dims(i);
     }
+  }
 
-    // Discard updates with negative indices, since some users expect this.
-    auto index_in_range = xla::ReduceAll(
-        xla::Le(zero_index, index), xla::ConstantR0<bool>(body_builder, true),
-        xla::CreateScalarAndComputation(xla::PRED, body_builder));
-
-    // Make the index in bounds to prevent implementation defined behavior.
-    index = xla::Max(index, zero_index);
-    index = xla::Pad(
-        index, zero_index,
-        xla::MakeEdgePaddingConfig({{0, buffer_shape_post_axes.size()}}));
-
-    // Slice the i-th index from the updates array.
-    auto updates_offset = xla::Reshape(i, {1});
-    updates_offset = xla::Pad(
-        updates_offset, zero_index,
-        xla::MakeEdgePaddingConfig({{0, buffer_shape_post_axes.size()}}));
-    std::vector<int64> flat_updates_slice_shape({1});
-    flat_updates_slice_shape.insert(flat_updates_slice_shape.end(),
-                                    buffer_shape_post_axes.begin(),
-                                    buffer_shape_post_axes.end());
-    auto update =
-        xla::DynamicSlice(updates, updates_offset, flat_updates_slice_shape);
-
-    // Unflatten the major (iteration) dimensions of the slice to their
-    // original shape.
-    std::vector<int64> updates_slice_shape(num_index_dims, 1);
-    updates_slice_shape.insert(updates_slice_shape.end(),
-                               buffer_shape_post_axes.begin(),
-                               buffer_shape_post_axes.end());
-    update = xla::Reshape(update, updates_slice_shape);
-
-    // Apply the update to the buffer. If there is a combiner, use it to merge
-    // the current values with the update.
-    auto current_value = xla::DynamicSlice(buffer, index, updates_slice_shape);
+  for (int64 i = 0; i < num_index_dims; ++i) {
+    dim_numbers.add_inserted_window_dims(i);
+    dim_numbers.add_scatter_dims_to_operand_dims(i);
+  }
+
+  // Build the combiner computation.
+  xla::XlaComputation combiner_computation;
+  {
+    xla::XlaBuilder cb("scatter-combiner");
+    auto xla_scalar_shape =
+        xla::ShapeUtil::MakeShape(buffer_shape.element_type(), {});
+    auto p0 = xla::Parameter(&cb, 0, xla_scalar_shape, "p0");
+    auto p1 = xla::Parameter(&cb, 1, xla_scalar_shape, "p1");
     if (combiner) {
-      update = combiner(current_value, update, body_builder);
+      combiner(p0, p1, &cb);
     }
-    // Use the current value instead of the update if the index is out of
-    // bounds.
-    update = xla::Select(index_in_range, update, current_value);
-    // Apply the update.
-    buffer = xla::DynamicUpdateSlice(buffer, update, index);
-
-    return std::vector<xla::XlaOp>{indices, updates, buffer};
-  };
-
-  TF_ASSIGN_OR_RETURN(auto outputs,
-                      XlaForEachIndex(num_indices, indices_shape.element_type(),
-                                      body_fn, init, "scatter", builder));
-  return outputs[2];
+    combiner_computation = cb.Build().ConsumeValueOrDie();
+  }
+
+  VLOG(3) << "Scatter op:";
+  VLOG(3) << "  Input: " << xla::ShapeUtil::HumanString(buffer_shape);
+  VLOG(3) << "  Indices: " << xla::ShapeUtil::HumanString(indices_shape);
+  VLOG(3) << "  Updates: " << xla::ShapeUtil::HumanString(updates_shape);
+  VLOG(3) << "  Scatter Dimension Numbers: ";
+  VLOG(3) << "    index_vector_dim: " << dim_numbers.index_vector_dim();
+  VLOG(3) << "    update_window_dims: ["
+          << absl::StrJoin(dim_numbers.update_window_dims(), ",") << "]";
+  VLOG(3) << "    inserted_window_dims: ["
+          << absl::StrJoin(dim_numbers.inserted_window_dims(), ",") << "]";
+  VLOG(3) << "    scatter_dims_to_operand_dims: ["
+          << absl::StrJoin(dim_numbers.scatter_dims_to_operand_dims(), ",")
+          << "]";
+
+  return xla::Scatter(buffer, indices, new_updates, combiner_computation,
+                      dim_numbers);
 }
 
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/lib/scatter.h b/tensorflow/compiler/tf2xla/lib/scatter.h
index 13a5f1b850..4cf478c4b9 100644
--- a/tensorflow/compiler/tf2xla/lib/scatter.h
+++ b/tensorflow/compiler/tf2xla/lib/scatter.h
@@ -34,7 +34,11 @@ namespace tensorflow {
 // Otherwise, `indices_are_vectors`, then indices are multidimensional and the
 // minor dimension of `indices` represents a vector of indices.
 //
-// If any indices are negative, the corresponding update is discarded.
+// If `updates` is a scalar, it will be broadcast to the expected shape of the
+// updates.
+//
+// If any part of the update region is out-of-bounds, the corresponding update
+// is discarded.
 //
 // If a `combiner` is provided, updates are combined with the existing values in
 // the buffer using the combiner function. Otherwise, the updates replace the
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index e0ec91dba1..d196252db1 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -208,6 +208,9 @@ void XlaBuilder::IsConstantVisitor(const int64 op_handle,
     case HloOpcode::kWhile:
       // TODO(b/32495713): We aren't checking the condition and body
       // computations themselves.
+    case HloOpcode::kScatter:
+      // TODO(b/32495713): We aren't checking the embedded computation in
+      // Scatter.
     case HloOpcode::kSend:
     case HloOpcode::kRecv:
     case HloOpcode::kParameter:
diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc
index 7527e35c95..93e04eb3db 100644
--- a/tensorflow/compiler/xla/service/hlo_module.cc
+++ b/tensorflow/compiler/xla/service/hlo_module.cc
@@ -146,7 +146,8 @@ void HloModule::ReplaceComputations(
       case HloOpcode::kCall:
       case HloOpcode::kMap:
       case HloOpcode::kReduce:
-      case HloOpcode::kReduceWindow: {
+      case HloOpcode::kReduceWindow:
+      case HloOpcode::kScatter: {
         HloComputation* new_arg = tensorflow::gtl::FindWithDefault(
             replacements, instruction->to_apply(), nullptr);
         if (new_arg != nullptr) {
diff --git a/tensorflow/compiler/xla/service/inliner.cc b/tensorflow/compiler/xla/service/inliner.cc
index 5fd779ebf9..50c408f5bb 100644
--- a/tensorflow/compiler/xla/service/inliner.cc
+++ b/tensorflow/compiler/xla/service/inliner.cc
@@ -71,26 +71,23 @@ Status InlinerVisitor::HandleMap(HloInstruction* map) {
   // profitability model for inlining is defined.
   if (hlo_query::AllOperandsAreParameters(root)) {
     if (root.opcode() == HloOpcode::kFusion ||
-        root.opcode() == HloOpcode::kParameter ||
         root.opcode() == HloOpcode::kTrace) {
       // Cloning not supported for these instructions.
       return Status::OK();
     }
     VLOG(10) << "inlining map({X ... Y}, op) => : op(X ... Y) with function "
              << root.ToShortString();
-    // If the input is a constant then the shape of the constant could be
-    // different than the map shape. Hence, a broadcast is needed, else the
-    // cloned operand with new shape and operands work.
-    if (root.opcode() != HloOpcode::kConstant) {
-      std::vector<HloInstruction*> params;
-      for (int64 o = 0; o < root.operands().size(); o++) {
-        params.push_back(map->operands()[root.operand(o)->parameter_number()]);
-      }
-      HloInstruction* placed_instruction = computation_->AddInstruction(
-          root.CloneWithNewOperands(map->shape(), params));
+    if (root.opcode() == HloOpcode::kParameter) {
+      // If the root is a parameter, then use the corresponding operand as the
+      // result of the computation.
       TF_RETURN_IF_ERROR(
-          computation_->ReplaceInstruction(map, placed_instruction));
-    } else {
+          map->ReplaceAllUsesWith(map->operands()[root.parameter_number()]));
+      TF_RETURN_IF_ERROR(computation_->RemoveInstruction(map));
+    } else if (root.opcode() == HloOpcode::kConstant) {
+      // If the input is a constant then the shape of the constant could be
+      // different than the map shape. Hence, a broadcast is needed, else the
+      // cloned operand with new shape and operands work.
+      //
       // The constant is in an embedded computation and needs to be recreated
       // as part of the computation that the broadcast is inserted into.
       HloInstruction* constant = computation_->AddInstruction(root.Clone());
@@ -98,6 +95,15 @@ Status InlinerVisitor::HandleMap(HloInstruction* map) {
           HloInstruction::CreateBroadcast(map->shape(), constant, {}));
       TF_RETURN_IF_ERROR(
           computation_->ReplaceInstruction(map, placed_instruction));
+    } else {
+      std::vector<HloInstruction*> params;
+      for (int64 o = 0; o < root.operands().size(); o++) {
+        params.push_back(map->operands()[root.operand(o)->parameter_number()]);
+      }
+      HloInstruction* placed_instruction = computation_->AddInstruction(
+          root.CloneWithNewOperands(map->shape(), params));
+      TF_RETURN_IF_ERROR(
+          computation_->ReplaceInstruction(map, placed_instruction));
     }
     changed_ = true;
     return Status::OK();
diff --git a/tensorflow/compiler/xla/service/inliner_test.cc b/tensorflow/compiler/xla/service/inliner_test.cc
index 7e967f035c..98e0f2cfd7 100644
--- a/tensorflow/compiler/xla/service/inliner_test.cc
+++ b/tensorflow/compiler/xla/service/inliner_test.cc
@@ -146,6 +146,36 @@ TEST_F(InlinerTest, MapSubtractOppositeOrder) {
   EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
 }
 
+TEST_F(InlinerTest, MapParameter) {
+  Shape r0f32 = ShapeUtil::MakeShape(F32, {});
+
+  auto param_builder = HloComputation::Builder(TestName());
+  param_builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32, "p0"));
+  param_builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32, "p1"));
+  auto param_f32 = param_builder.Build();
+
+  auto builder = HloComputation::Builder("MapParamFunction");
+  auto lhs = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1)));
+  auto rhs = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(4)));
+  builder.AddInstruction(
+      HloInstruction::CreateMap(lhs->shape(), {lhs, rhs}, param_f32.get()));
+
+  auto computation = builder.Build();
+  auto hlo_module = CreateNewVerifiedModule();
+  hlo_module->AddEmbeddedComputation(std::move(param_f32));
+  hlo_module->AddEntryComputation(std::move(computation));
+
+  Inliner inliner;
+  EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie());
+  EXPECT_THAT(hlo_module->entry_computation()->root_instruction(), rhs);
+
+  // Verify execution on CPU.
+  auto result = ExecuteAndTransfer(hlo_module->Clone(), {});
+  auto expected = LiteralUtil::CreateR0<float>(4);
+  EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
+}
+
 }  // namespace
 }  // namespace xla
--
cgit v1.2.3
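
A usage sketch of the updated helper, for orientation. This is not part of the
patch: the function name BuildScatterAddExample, the shapes, and the
add-combiner below are illustrative assumptions. With this change, a call like
the following lowers to a single Scatter HLO rather than the old
XlaForEachIndex while-loop expansion:

    #include "tensorflow/compiler/tf2xla/lib/scatter.h"
    #include "tensorflow/compiler/xla/client/xla_builder.h"
    #include "tensorflow/compiler/xla/shape_util.h"

    // Scatter-adds two f32[3] update rows into an f32[5,3] buffer.
    xla::StatusOr<xla::XlaOp> BuildScatterAddExample(xla::XlaBuilder* b) {
      auto buffer = xla::Parameter(
          b, 0, xla::ShapeUtil::MakeShape(xla::F32, {5, 3}), "buffer");
      auto indices = xla::Parameter(
          b, 1, xla::ShapeUtil::MakeShape(xla::S32, {2}), "indices");
      auto updates = xla::Parameter(
          b, 2, xla::ShapeUtil::MakeShape(xla::F32, {2, 3}), "updates");
      // The combiner merges each update with the existing buffer value;
      // addition gives tf.scatter_add-style semantics.
      auto combiner = [](xla::XlaOp x, xla::XlaOp y, xla::XlaBuilder* builder) {
        return xla::Add(x, y);
      };
      return tensorflow::XlaScatter(buffer, updates, indices,
                                    /*indices_are_vectors=*/false, combiner, b);
    }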
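For those assumed shapes (buffer f32[5,3], scalar indices s32[2], updates
f32[2,3]), the dimension-number bookkeeping in the new XlaScatter body reduces
to the following, a sketch of what the general loops compute in this case:

    // buffer_rank - num_index_dims == 1, so one window dimension remains in
    // `updates`; updates dim 0 enumerates the two indices.
    xla::ScatterDimensionNumbers dim_numbers;
    // Scalar indices: index_vector_dim is one past the last indices dimension.
    dim_numbers.set_index_vector_dim(1);
    // Updates dim 1 carries the f32[3] row being written.
    dim_numbers.add_update_window_dims(1);
    // Buffer dim 0 is indexed rather than windowed, and indices map to it.
    dim_numbers.add_inserted_window_dims(0);
    dim_numbers.add_scatter_dims_to_operand_dims(0);

These are the same dimension numbers as the second 1-D example in the comment
block above, just with a larger buffer.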
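On the inliner side: HandleMap previously refused to inline a map whose
computation root is a parameter, and the new kParameter branch instead rewires
the map's uses to the corresponding operand and removes the map outright via
RemoveInstruction, since there is nothing to clone. Schematically, for the
MapParameter test (hand-written HLO in the style of the scatter comments
above, not output captured from the patch):

    param_f32 {
      p0 = f32[] parameter(0)
      ROOT p1 = f32[] parameter(1)
    }

    before:  ROOT map = f32[] map(f32[] constant(1), f32[] constant(4)),
                 to_apply=param_f32
    after:   ROOT constant = f32[] constant(4)   // the map's second operand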