Diffstat (limited to 'tensorflow/compiler/tf2xla/lib/scatter.cc')
-rw-r--r--  tensorflow/compiler/tf2xla/lib/scatter.cc  58
1 file changed, 26 insertions(+), 32 deletions(-)
diff --git a/tensorflow/compiler/tf2xla/lib/scatter.cc b/tensorflow/compiler/tf2xla/lib/scatter.cc
index 45699233ea..d5a27abb25 100644
--- a/tensorflow/compiler/tf2xla/lib/scatter.cc
+++ b/tensorflow/compiler/tf2xla/lib/scatter.cc
@@ -30,24 +30,19 @@ limitations under the License.
namespace tensorflow {
-xla::StatusOr<xla::ComputationDataHandle> XlaScatter(
- const xla::ComputationDataHandle& buffer,
- const xla::ComputationDataHandle& updates,
- const xla::ComputationDataHandle& indices, bool indices_are_vectors,
- const std::function<xla::ComputationDataHandle(
- xla::ComputationDataHandle, xla::ComputationDataHandle,
- xla::ComputationBuilder*)>& combiner,
- xla::ComputationBuilder* builder) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> buffer_shape,
- builder->GetShape(buffer));
- TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> updates_shape,
- builder->GetShape(updates));
- TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> indices_shape,
- builder->GetShape(indices));
+xla::StatusOr<xla::XlaOp> XlaScatter(
+ const xla::XlaOp& buffer, const xla::XlaOp& updates,
+ const xla::XlaOp& indices, bool indices_are_vectors,
+ const std::function<xla::XlaOp(xla::XlaOp, xla::XlaOp, xla::XlaBuilder*)>&
+ combiner,
+ xla::XlaBuilder* builder) {
+ TF_ASSIGN_OR_RETURN(xla::Shape buffer_shape, builder->GetShape(buffer));
+ TF_RETURN_IF_ERROR(builder->GetShape(updates).status());
+ TF_ASSIGN_OR_RETURN(xla::Shape indices_shape, builder->GetShape(indices));
gtl::ArraySlice<int64> indices_dims =
- xla::AsInt64Slice(indices_shape->dimensions());
+ xla::AsInt64Slice(indices_shape.dimensions());
gtl::ArraySlice<int64> buffer_dims =
- xla::AsInt64Slice(buffer_shape->dimensions());
+ xla::AsInt64Slice(buffer_shape.dimensions());
// If the indices are N-dimensional, the minor dimension of indices contains
// the indices to update. Otherwise the indices are all scalars.
@@ -55,12 +50,12 @@ xla::StatusOr<xla::ComputationDataHandle> XlaScatter(
if (indices_are_vectors) {
TF_RET_CHECK(!indices_dims.empty());
num_index_dims = indices_dims.back();
- if (num_index_dims > xla::ShapeUtil::Rank(*buffer_shape)) {
+ if (num_index_dims > xla::ShapeUtil::Rank(buffer_shape)) {
return errors::InvalidArgument(
"The size of the minor dimension of the indices (shape: ",
- xla::ShapeUtil::HumanString(*indices_shape),
+ xla::ShapeUtil::HumanString(indices_shape),
") must be <= the rank of the buffer (shape: ",
- xla::ShapeUtil::HumanString(*buffer_shape), ")");
+ xla::ShapeUtil::HumanString(buffer_shape), ")");
}
indices_dims.pop_back();
}
@@ -78,10 +73,10 @@ xla::StatusOr<xla::ComputationDataHandle> XlaScatter(
// If any of the indexed dimensions are zero in the buffer, the update cannot
// succeed since it updates a slice of size 1.
for (int64 i = 0; i < num_index_dims; ++i) {
- if (xla::ShapeUtil::GetDimension(*buffer_shape, i) == 0) {
- return errors::InvalidArgument(
- "Scatter dimension ", i, " is of size zero in tensor with shape ",
- xla::ShapeUtil::HumanString(*buffer_shape));
+ if (xla::ShapeUtil::GetDimension(buffer_shape, i) == 0) {
+ return errors::InvalidArgument("Scatter dimension ", i,
+ " is of size zero in tensor with shape ",
+ xla::ShapeUtil::HumanString(buffer_shape));
}
}
@@ -111,18 +106,17 @@ xla::StatusOr<xla::ComputationDataHandle> XlaScatter(
// index = dynamic-slice(indices, i)
// update = dynamic-slice(updates, i)
// buffer = dynamic-update-slice(buffer, update, index)
- auto body_fn = [&](xla::ComputationDataHandle i,
- gtl::ArraySlice<xla::ComputationDataHandle> loop_vars,
- xla::ComputationBuilder* body_builder) {
+ auto body_fn = [&](xla::XlaOp i, gtl::ArraySlice<xla::XlaOp> loop_vars,
+ xla::XlaBuilder* body_builder) {
auto indices = loop_vars[0];
auto updates = loop_vars[1];
auto buffer = loop_vars[2];
auto zero_index = body_builder->ConstantLiteral(
- xla::Literal::Zero(indices_shape->element_type()));
+ xla::Literal::Zero(indices_shape.element_type()));
// Slice the i-th index from the indices array.
- xla::ComputationDataHandle index;
+ xla::XlaOp index;
auto indices_offset = body_builder->Reshape(i, {1});
if (indices_are_vectors) {
indices_offset = body_builder->Pad(indices_offset, zero_index,
@@ -180,12 +174,12 @@ xla::StatusOr<xla::ComputationDataHandle> XlaScatter(
// Apply the update.
buffer = body_builder->DynamicUpdateSlice(buffer, update, index);
- return std::vector<xla::ComputationDataHandle>{indices, updates, buffer};
+ return std::vector<xla::XlaOp>{indices, updates, buffer};
};
- TF_ASSIGN_OR_RETURN(
- auto outputs, XlaForEachIndex(num_indices, indices_shape->element_type(),
- body_fn, init, "scatter", builder));
+ TF_ASSIGN_OR_RETURN(auto outputs,
+ XlaForEachIndex(num_indices, indices_shape.element_type(),
+ body_fn, init, "scatter", builder));
return outputs[2];
}
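
Below is a minimal, hypothetical caller sketch against the updated XlaOp/XlaBuilder signature; the helper name ScatterAddExample and the scatter-add combiner are illustrative assumptions and not part of this change. It assumes the builder-method style used in the body above (e.g. builder->Add), which is how the XlaBuilder API looked at this point in the migration.

// Hypothetical usage sketch (not part of this change): scatter-add of
// `updates` into rows of `buffer` selected by a vector of scalar `indices`,
// written against the new XlaOp/XlaBuilder-based XlaScatter signature.
#include "tensorflow/compiler/tf2xla/lib/scatter.h"

namespace tensorflow {

xla::StatusOr<xla::XlaOp> ScatterAddExample(xla::XlaBuilder* builder,
                                            const xla::XlaOp& buffer,
                                            const xla::XlaOp& indices,
                                            const xla::XlaOp& updates) {
  // Combiner that adds the existing buffer slice to the incoming update.
  // Matches the std::function<xla::XlaOp(xla::XlaOp, xla::XlaOp,
  // xla::XlaBuilder*)> parameter of XlaScatter; Add is assumed to be a
  // member of XlaBuilder, consistent with the builder calls in the diff.
  auto add = [](xla::XlaOp x, xla::XlaOp y, xla::XlaBuilder* b) {
    return b->Add(x, y);
  };
  // indices_are_vectors=false: each element of `indices` is a scalar index
  // into the major dimension of `buffer`.
  return XlaScatter(buffer, updates, indices,
                    /*indices_are_vectors=*/false, add, builder);
}

}  // namespace tensorflow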