aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-10-03 23:03:11 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-03 23:07:04 -0700
commite57874169fca3cfdd15cf0dda3717a6374a7dcb9 (patch)
tree8c84491ba2200c19d3a6291c26dea6f196ff33c4 /tensorflow/compiler
parent6795491bcc0c276e27be6a9e1a4a14c019c2ba37 (diff)
[XLA] Update Tf2Xla bridge to use Scatter HLO.
PiperOrigin-RevId: 215687800
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r--tensorflow/compiler/tf2xla/lib/scatter.cc213
-rw-r--r--tensorflow/compiler/tf2xla/lib/scatter.h6
-rw-r--r--tensorflow/compiler/xla/client/xla_builder.cc3
-rw-r--r--tensorflow/compiler/xla/service/hlo_module.cc3
-rw-r--r--tensorflow/compiler/xla/service/inliner.cc32
-rw-r--r--tensorflow/compiler/xla/service/inliner_test.cc30
6 files changed, 177 insertions, 110 deletions
diff --git a/tensorflow/compiler/tf2xla/lib/scatter.cc b/tensorflow/compiler/tf2xla/lib/scatter.cc
index 38dfde165d..2b1c2ced92 100644
--- a/tensorflow/compiler/tf2xla/lib/scatter.cc
+++ b/tensorflow/compiler/tf2xla/lib/scatter.cc
@@ -38,12 +38,10 @@ xla::StatusOr<xla::XlaOp> XlaScatter(
combiner,
xla::XlaBuilder* builder) {
TF_ASSIGN_OR_RETURN(xla::Shape buffer_shape, builder->GetShape(buffer));
- TF_RETURN_IF_ERROR(builder->GetShape(updates).status());
+ TF_ASSIGN_OR_RETURN(xla::Shape updates_shape, builder->GetShape(updates));
TF_ASSIGN_OR_RETURN(xla::Shape indices_shape, builder->GetShape(indices));
absl::Span<const int64> indices_dims =
xla::AsInt64Slice(indices_shape.dimensions());
- absl::Span<const int64> buffer_dims =
- xla::AsInt64Slice(buffer_shape.dimensions());
// If the indices are N-dimensional, the minor dimension of indices contains
// the indices to update. Otherwise the indices are all scalars.
@@ -81,104 +79,129 @@ xla::StatusOr<xla::XlaOp> XlaScatter(
}
}
- // Shape of the non-indexed dimensions of the buffer.
- std::vector<int64> buffer_shape_post_axes(
- buffer_dims.begin() + num_index_dims, buffer_dims.end());
-
- // Flatten the major dimensions of indices and updates into a single dimension
- // for ease of iteration.
- std::vector<int64> flat_indices_shape({num_indices});
- if (indices_are_vectors) {
- flat_indices_shape.push_back(num_index_dims);
+ // Example of a 1-D scatter that updates two [3,1] tensors in a tensor of
+ // shape [3,3]:
+ // NOTE: ***This case will not be generated by any of the tf.scatter ops.***
+ //
+ // operand = s32[3,3] parameter(0)
+ // indices = s32[2] parameter(1)
+ // updates = s32[3,2] parameter(2)
+ // scatter = s32[3,3] scatter(operand, indices, updates),
+ // to_apply=update_computation,
+ // update_window_dims={0},
+ // inserted_window_dims={1},
+ // scatter_dims_to_operand_dims={1},
+ // index_vector_dim=1
+ //
+ //
+ // Example of a 1-D scatter that updates two [1,3] tensors in a tensor of
+ // shape [3,3]:
+ //
+ // operand = s32[3,3] parameter(0)
+ // indices = s32[2] parameter(1)
+ // updates = s32[2,3] parameter(2)
+ // scatter = s32[3,3] scatter(operand, indices, updates),
+ // to_apply=update_computation,
+ // update_window_dims={1},
+ // inserted_window_dims={0},
+ // scatter_dims_to_operand_dims={0},
+ // index_vector_dim=1
+ //
+ //
+ // Example of an N-D scatter updating slices of shape [1,1,2] in a tensor of
+ // shape [3,3,2]
+ //
+ // operand = s32[3,3,2] parameter(0)
+ // indices = s32[2,2] parameter(1)
+ // updates = s32[2,2] parameter(2)
+ // scatter = s32[3,3,2] scatter(operand, indices, updates),
+ // to_apply=update_computation,
+ // update_window_dims={1},
+ // inserted_window_dims={0,1},
+ // scatter_dims_to_operand_dims={0,1},
+ // index_vector_dim=1
+ //
+ //
+ // Example of a scatter updating slices of shape [] in a tensor of shape [1,1]
+ //
+ // operand = s32[1,1] parameter(0)
+ // indices = s32[1] parameter(1)
+ // updates = s32[1] parameter(2)
+ // scatter = s32[1,1] scatter(operand, indices, updates),
+ // to_apply=update_computation,
+ // update_window_dims={},
+ // inserted_window_dims={0,1},
+ // scatter_dims_to_operand_dims={0},
+ // index_vector_dim=1
+ // Note that updates operand would be broadcasted into [1] in this case.
+ //
+
+ xla::ScatterDimensionNumbers dim_numbers;
+ dim_numbers.set_index_vector_dim(indices_are_vectors
+ ? indices_shape.dimensions_size() - 1
+ : indices_shape.dimensions_size());
+
+ int64 updates_rank = xla::ShapeUtil::Rank(updates_shape);
+ int64 buffer_rank = xla::ShapeUtil::Rank(buffer_shape);
+ int64 num_window_dims_in_updates = buffer_rank - num_index_dims;
+
+ // If the rank of `updates` is 0 and does not match the expected rank of
+ // updates, broadcast `updates` to the expected shape of updates.
+ auto new_updates = updates;
+ std::vector<int64> expected_updates_dims(indices_dims.begin(),
+ indices_dims.end());
+ for (int64 dim = num_index_dims; dim < buffer_rank; ++dim) {
+ expected_updates_dims.push_back(buffer_shape.dimensions(dim));
+ }
+ int64 expected_updates_rank = expected_updates_dims.size();
+ if (updates_rank == 0 && expected_updates_rank != 0) {
+ new_updates = xla::Broadcast(updates, expected_updates_dims);
+ TF_ASSIGN_OR_RETURN(updates_shape, builder->GetShape(new_updates));
+ updates_rank = xla::ShapeUtil::Rank(updates_shape);
}
- std::vector<int64> flat_updates_shape({num_indices});
- flat_updates_shape.insert(flat_updates_shape.end(),
- buffer_shape_post_axes.begin(),
- buffer_shape_post_axes.end());
-
- // Construct the initial values of the loop-carried Tensors.
- auto flat_indices = xla::Reshape(indices, flat_indices_shape);
- auto flat_updates = xla::Reshape(updates, flat_updates_shape);
- auto init = {flat_indices, flat_updates, buffer};
-
- // Constructs the loop body. The implementation of scatter is essentially:
- // for i in range(num_indices):
- // index = dynamic-slice(indices, i)
- // update = dynamic-slice(updates, i)
- // buffer = dynamic-update-slice(buffer, update, index)
- auto body_fn = [&](xla::XlaOp i, absl::Span<const xla::XlaOp> loop_vars,
- xla::XlaBuilder* body_builder) {
- auto indices = loop_vars[0];
- auto updates = loop_vars[1];
- auto buffer = loop_vars[2];
-
- auto zero_index = xla::ConstantLiteral(
- body_builder, xla::LiteralUtil::Zero(indices_shape.element_type()));
-
- // Slice the i-th index from the indices array.
- xla::XlaOp index;
- auto indices_offset = xla::Reshape(i, {1});
- if (indices_are_vectors) {
- indices_offset = xla::Pad(indices_offset, zero_index,
- xla::MakeEdgePaddingConfig({{0, 1}}));
-
- index = xla::DynamicSlice(indices, indices_offset, {1, num_index_dims});
- index = xla::Collapse(index, {0, 1});
- } else {
- index = xla::DynamicSlice(indices, indices_offset, {1});
+ if (updates_rank > 0) {
+ for (int64 i = (updates_rank - num_window_dims_in_updates);
+ i < updates_rank; ++i) {
+ dim_numbers.add_update_window_dims(i);
}
+ }
- // Discard updates with negative indices, since some users expect this.
- auto index_in_range = xla::ReduceAll(
- xla::Le(zero_index, index), xla::ConstantR0<bool>(body_builder, true),
- xla::CreateScalarAndComputation(xla::PRED, body_builder));
-
- // Make the index in bounds to prevent implementation defined behavior.
- index = xla::Max(index, zero_index);
- index = xla::Pad(
- index, zero_index,
- xla::MakeEdgePaddingConfig({{0, buffer_shape_post_axes.size()}}));
-
- // Slice the i-th index from the updates array.
- auto updates_offset = xla::Reshape(i, {1});
- updates_offset = xla::Pad(
- updates_offset, zero_index,
- xla::MakeEdgePaddingConfig({{0, buffer_shape_post_axes.size()}}));
- std::vector<int64> flat_updates_slice_shape({1});
- flat_updates_slice_shape.insert(flat_updates_slice_shape.end(),
- buffer_shape_post_axes.begin(),
- buffer_shape_post_axes.end());
- auto update =
- xla::DynamicSlice(updates, updates_offset, flat_updates_slice_shape);
-
- // Unflatten the major (iteration) dimensions of the slice to their
- // original shape.
- std::vector<int64> updates_slice_shape(num_index_dims, 1);
- updates_slice_shape.insert(updates_slice_shape.end(),
- buffer_shape_post_axes.begin(),
- buffer_shape_post_axes.end());
- update = xla::Reshape(update, updates_slice_shape);
-
- // Apply the update to the buffer. If there is a combiner, use it to merge
- // the current values with the update.
- auto current_value = xla::DynamicSlice(buffer, index, updates_slice_shape);
+ for (int64 i = 0; i < num_index_dims; ++i) {
+ dim_numbers.add_inserted_window_dims(i);
+ dim_numbers.add_scatter_dims_to_operand_dims(i);
+ }
+
+ // Build the combiner computation.
+ xla::XlaComputation combiner_computation;
+ {
+ xla::XlaBuilder cb("scatter-combiner");
+ auto xla_scalar_shape =
+ xla::ShapeUtil::MakeShape(buffer_shape.element_type(), {});
+ auto p0 = xla::Parameter(&cb, 0, xla_scalar_shape, "p0");
+ auto p1 = xla::Parameter(&cb, 1, xla_scalar_shape, "p1");
if (combiner) {
- update = combiner(current_value, update, body_builder);
+ combiner(p0, p1, &cb);
}
- // Use the current value instead of the update if the index is out of
- // bounds.
- update = xla::Select(index_in_range, update, current_value);
- // Apply the update.
- buffer = xla::DynamicUpdateSlice(buffer, update, index);
-
- return std::vector<xla::XlaOp>{indices, updates, buffer};
- };
-
- TF_ASSIGN_OR_RETURN(auto outputs,
- XlaForEachIndex(num_indices, indices_shape.element_type(),
- body_fn, init, "scatter", builder));
- return outputs[2];
+ combiner_computation = cb.Build().ConsumeValueOrDie();
+ }
+
+ VLOG(3) << "Scatter op:";
+ VLOG(3) << " Input: " << xla::ShapeUtil::HumanString(buffer_shape);
+ VLOG(3) << " Indices: " << xla::ShapeUtil::HumanString(indices_shape);
+ VLOG(3) << " Updates: " << xla::ShapeUtil::HumanString(updates_shape);
+ VLOG(3) << " Scatter Dimension Numbers: ";
+ VLOG(3) << " index_vector_dim: " << dim_numbers.index_vector_dim();
+ VLOG(3) << " update_window_dims: ["
+ << absl::StrJoin(dim_numbers.update_window_dims(), ",") << "]";
+ VLOG(3) << " inserted_window_dims: ["
+ << absl::StrJoin(dim_numbers.inserted_window_dims(), ",") << "]";
+ VLOG(3) << " scatter_dims_to_operand_dims: ["
+ << absl::StrJoin(dim_numbers.scatter_dims_to_operand_dims(), ",")
+ << "]";
+
+ return xla::Scatter(buffer, indices, new_updates, combiner_computation,
+ dim_numbers);
}
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/lib/scatter.h b/tensorflow/compiler/tf2xla/lib/scatter.h
index 13a5f1b850..4cf478c4b9 100644
--- a/tensorflow/compiler/tf2xla/lib/scatter.h
+++ b/tensorflow/compiler/tf2xla/lib/scatter.h
@@ -34,7 +34,11 @@ namespace tensorflow {
// Otherwise, `indices_are_vectors`, then indices are multidimensional and the
// minor dimension of `indices` represents a vector of indices.
//
-// If any indices are negative, the corresponding update is discarded.
+// If `updates` is a scalar, then it will be broadcasted into the expected shape
+// of updates.
+//
+// If any part of the update region is out-of-bounds, the corresponding update
+// is discarded.
//
// If a `combiner` is provided, updates are combined with the existing values in
// the buffer using the combiner function. Otherwise, the updates replace the
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index e0ec91dba1..d196252db1 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -208,6 +208,9 @@ void XlaBuilder::IsConstantVisitor(const int64 op_handle,
case HloOpcode::kWhile:
// TODO(b/32495713): We aren't checking the condition and body
// computations themselves.
+ case HloOpcode::kScatter:
+ // TODO(b/32495713): We aren't checking the embedded computation in
+ // Scatter.
case HloOpcode::kSend:
case HloOpcode::kRecv:
case HloOpcode::kParameter:
diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc
index 7527e35c95..93e04eb3db 100644
--- a/tensorflow/compiler/xla/service/hlo_module.cc
+++ b/tensorflow/compiler/xla/service/hlo_module.cc
@@ -146,7 +146,8 @@ void HloModule::ReplaceComputations(
case HloOpcode::kCall:
case HloOpcode::kMap:
case HloOpcode::kReduce:
- case HloOpcode::kReduceWindow: {
+ case HloOpcode::kReduceWindow:
+ case HloOpcode::kScatter: {
HloComputation* new_arg = tensorflow::gtl::FindWithDefault(
replacements, instruction->to_apply(), nullptr);
if (new_arg != nullptr) {
diff --git a/tensorflow/compiler/xla/service/inliner.cc b/tensorflow/compiler/xla/service/inliner.cc
index 5fd779ebf9..50c408f5bb 100644
--- a/tensorflow/compiler/xla/service/inliner.cc
+++ b/tensorflow/compiler/xla/service/inliner.cc
@@ -71,26 +71,23 @@ Status InlinerVisitor::HandleMap(HloInstruction* map) {
// profitability model for inlining is defined.
if (hlo_query::AllOperandsAreParameters(root)) {
if (root.opcode() == HloOpcode::kFusion ||
- root.opcode() == HloOpcode::kParameter ||
root.opcode() == HloOpcode::kTrace) {
// Cloning not supported for these instructions.
return Status::OK();
}
VLOG(10) << "inlining map({X ... Y}, op) => : op(X ... Y) with function "
<< root.ToShortString();
- // If the input is a constant then the shape of the constant could be
- // different than the map shape. Hence, a broadcast is needed, else the
- // cloned operand with new shape and operands work.
- if (root.opcode() != HloOpcode::kConstant) {
- std::vector<HloInstruction*> params;
- for (int64 o = 0; o < root.operands().size(); o++) {
- params.push_back(map->operands()[root.operand(o)->parameter_number()]);
- }
- HloInstruction* placed_instruction = computation_->AddInstruction(
- root.CloneWithNewOperands(map->shape(), params));
+ if (root.opcode() == HloOpcode::kParameter) {
+ // If the root is a parameter, then use the corresponding operand as the
+ // result of the computation.
TF_RETURN_IF_ERROR(
- computation_->ReplaceInstruction(map, placed_instruction));
- } else {
+ map->ReplaceAllUsesWith(map->operands()[root.parameter_number()]));
+ TF_RETURN_IF_ERROR(computation_->RemoveInstruction(map));
+ } else if (root.opcode() == HloOpcode::kConstant) {
+ // If the input is a constant then the shape of the constant could be
+ // different than the map shape. Hence, a broadcast is needed, else the
+ // cloned operand with new shape and operands work.
+ //
// The constant is in an embedded computation and needs to be recreated
// as part of the computation that the broadcast is inserted into.
HloInstruction* constant = computation_->AddInstruction(root.Clone());
@@ -98,6 +95,15 @@ Status InlinerVisitor::HandleMap(HloInstruction* map) {
HloInstruction::CreateBroadcast(map->shape(), constant, {}));
TF_RETURN_IF_ERROR(
computation_->ReplaceInstruction(map, placed_instruction));
+ } else {
+ std::vector<HloInstruction*> params;
+ for (int64 o = 0; o < root.operands().size(); o++) {
+ params.push_back(map->operands()[root.operand(o)->parameter_number()]);
+ }
+ HloInstruction* placed_instruction = computation_->AddInstruction(
+ root.CloneWithNewOperands(map->shape(), params));
+ TF_RETURN_IF_ERROR(
+ computation_->ReplaceInstruction(map, placed_instruction));
}
changed_ = true;
return Status::OK();
diff --git a/tensorflow/compiler/xla/service/inliner_test.cc b/tensorflow/compiler/xla/service/inliner_test.cc
index 7e967f035c..98e0f2cfd7 100644
--- a/tensorflow/compiler/xla/service/inliner_test.cc
+++ b/tensorflow/compiler/xla/service/inliner_test.cc
@@ -146,6 +146,36 @@ TEST_F(InlinerTest, MapSubtractOppositeOrder) {
EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
}
+TEST_F(InlinerTest, MapParameter) {
+ Shape r0f32 = ShapeUtil::MakeShape(F32, {});
+
+ auto param_builder = HloComputation::Builder(TestName());
+ param_builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32, "p0"));
+ param_builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32, "p1"));
+ auto param_f32 = param_builder.Build();
+
+ auto builder = HloComputation::Builder("MapParamFunction");
+ auto lhs = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1)));
+ auto rhs = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(4)));
+ builder.AddInstruction(
+ HloInstruction::CreateMap(lhs->shape(), {lhs, rhs}, param_f32.get()));
+
+ auto computation = builder.Build();
+ auto hlo_module = CreateNewVerifiedModule();
+ hlo_module->AddEmbeddedComputation(std::move(param_f32));
+ hlo_module->AddEntryComputation(std::move(computation));
+
+ Inliner inliner;
+ EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie());
+ EXPECT_THAT(hlo_module->entry_computation()->root_instruction(), rhs);
+
+ // Verify execution on CPU.
+ auto result = ExecuteAndTransfer(hlo_module->Clone(), {});
+ auto expected = LiteralUtil::CreateR0<float>(4);
+ EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
+}
} // namespace
} // namespace xla