diff options
Diffstat (limited to 'tensorflow/compiler/tf2xla/lib/scatter.cc')
-rw-r--r-- | tensorflow/compiler/tf2xla/lib/scatter.cc | 58 |
1 files changed, 26 insertions, 32 deletions
diff --git a/tensorflow/compiler/tf2xla/lib/scatter.cc b/tensorflow/compiler/tf2xla/lib/scatter.cc index 45699233ea..d5a27abb25 100644 --- a/tensorflow/compiler/tf2xla/lib/scatter.cc +++ b/tensorflow/compiler/tf2xla/lib/scatter.cc @@ -30,24 +30,19 @@ limitations under the License. namespace tensorflow { -xla::StatusOr<xla::ComputationDataHandle> XlaScatter( - const xla::ComputationDataHandle& buffer, - const xla::ComputationDataHandle& updates, - const xla::ComputationDataHandle& indices, bool indices_are_vectors, - const std::function<xla::ComputationDataHandle( - xla::ComputationDataHandle, xla::ComputationDataHandle, - xla::ComputationBuilder*)>& combiner, - xla::ComputationBuilder* builder) { - TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> buffer_shape, - builder->GetShape(buffer)); - TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> updates_shape, - builder->GetShape(updates)); - TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> indices_shape, - builder->GetShape(indices)); +xla::StatusOr<xla::XlaOp> XlaScatter( + const xla::XlaOp& buffer, const xla::XlaOp& updates, + const xla::XlaOp& indices, bool indices_are_vectors, + const std::function<xla::XlaOp(xla::XlaOp, xla::XlaOp, xla::XlaBuilder*)>& + combiner, + xla::XlaBuilder* builder) { + TF_ASSIGN_OR_RETURN(xla::Shape buffer_shape, builder->GetShape(buffer)); + TF_RETURN_IF_ERROR(builder->GetShape(updates).status()); + TF_ASSIGN_OR_RETURN(xla::Shape indices_shape, builder->GetShape(indices)); gtl::ArraySlice<int64> indices_dims = - xla::AsInt64Slice(indices_shape->dimensions()); + xla::AsInt64Slice(indices_shape.dimensions()); gtl::ArraySlice<int64> buffer_dims = - xla::AsInt64Slice(buffer_shape->dimensions()); + xla::AsInt64Slice(buffer_shape.dimensions()); // If the indices are N-dimensional, the minor dimension of indices contains // the indices to update. Otherwise the indices are all scalars. @@ -55,12 +50,12 @@ xla::StatusOr<xla::ComputationDataHandle> XlaScatter( if (indices_are_vectors) { TF_RET_CHECK(!indices_dims.empty()); num_index_dims = indices_dims.back(); - if (num_index_dims > xla::ShapeUtil::Rank(*buffer_shape)) { + if (num_index_dims > xla::ShapeUtil::Rank(buffer_shape)) { return errors::InvalidArgument( "The size of the minor dimension of the indices (shape: ", - xla::ShapeUtil::HumanString(*indices_shape), + xla::ShapeUtil::HumanString(indices_shape), ") must be <= the rank of the buffer (shape: ", - xla::ShapeUtil::HumanString(*buffer_shape), ")"); + xla::ShapeUtil::HumanString(buffer_shape), ")"); } indices_dims.pop_back(); } @@ -78,10 +73,10 @@ xla::StatusOr<xla::ComputationDataHandle> XlaScatter( // If any of the indexed dimensions are zero in the buffer, the update cannot // succeed since it updates a slice of size 1. for (int64 i = 0; i < num_index_dims; ++i) { - if (xla::ShapeUtil::GetDimension(*buffer_shape, i) == 0) { - return errors::InvalidArgument( - "Scatter dimension ", i, " is of size zero in tensor with shape ", - xla::ShapeUtil::HumanString(*buffer_shape)); + if (xla::ShapeUtil::GetDimension(buffer_shape, i) == 0) { + return errors::InvalidArgument("Scatter dimension ", i, + " is of size zero in tensor with shape ", + xla::ShapeUtil::HumanString(buffer_shape)); } } @@ -111,18 +106,17 @@ xla::StatusOr<xla::ComputationDataHandle> XlaScatter( // index = dynamic-slice(indices, i) // update = dynamic-slice(updates, i) // buffer = dynamic-update-slice(buffer, update, index) - auto body_fn = [&](xla::ComputationDataHandle i, - gtl::ArraySlice<xla::ComputationDataHandle> loop_vars, - xla::ComputationBuilder* body_builder) { + auto body_fn = [&](xla::XlaOp i, gtl::ArraySlice<xla::XlaOp> loop_vars, + xla::XlaBuilder* body_builder) { auto indices = loop_vars[0]; auto updates = loop_vars[1]; auto buffer = loop_vars[2]; auto zero_index = body_builder->ConstantLiteral( - xla::Literal::Zero(indices_shape->element_type())); + xla::Literal::Zero(indices_shape.element_type())); // Slice the i-th index from the indices array. - xla::ComputationDataHandle index; + xla::XlaOp index; auto indices_offset = body_builder->Reshape(i, {1}); if (indices_are_vectors) { indices_offset = body_builder->Pad(indices_offset, zero_index, @@ -180,12 +174,12 @@ xla::StatusOr<xla::ComputationDataHandle> XlaScatter( // Apply the update. buffer = body_builder->DynamicUpdateSlice(buffer, update, index); - return std::vector<xla::ComputationDataHandle>{indices, updates, buffer}; + return std::vector<xla::XlaOp>{indices, updates, buffer}; }; - TF_ASSIGN_OR_RETURN( - auto outputs, XlaForEachIndex(num_indices, indices_shape->element_type(), - body_fn, init, "scatter", builder)); + TF_ASSIGN_OR_RETURN(auto outputs, + XlaForEachIndex(num_indices, indices_shape.element_type(), + body_fn, init, "scatter", builder)); return outputs[2]; } |