aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/shaped_buffer.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/shaped_buffer.cc')
-rw-r--r--tensorflow/compiler/xla/service/shaped_buffer.cc27
1 files changed, 13 insertions, 14 deletions
diff --git a/tensorflow/compiler/xla/service/shaped_buffer.cc b/tensorflow/compiler/xla/service/shaped_buffer.cc
index cf49fd72b7..865be1b84f 100644
--- a/tensorflow/compiler/xla/service/shaped_buffer.cc
+++ b/tensorflow/compiler/xla/service/shaped_buffer.cc
@@ -73,16 +73,13 @@ ShapedBuffer::MakeUnnestedTupleShapedBuffer(
}
TF_ASSIGN_OR_RETURN(std::unique_ptr<ShapedBuffer> shaped_buffer,
MakeShapedBuffer(shape, platform, device_ordinal));
- TF_CHECK_OK(shaped_buffer->mutable_shape_index_to_buffer_entry()
- ->ForEachMutableElement(
- [](const ShapeIndex& index, bool is_leaf,
- size_t* buffer_element) -> tensorflow::Status {
- if (is_leaf) {
- CHECK_EQ(index.size(), 1);
- *buffer_element = index[0];
- }
- return tensorflow::Status::OK();
- }));
+ shaped_buffer->mutable_shape_index_to_buffer_entry()->ForEachMutableElement(
+ [&shaped_buffer](const ShapeIndex& index, size_t* buffer_element) {
+ if (ShapeUtil::IsLeafIndex(shaped_buffer->shape(), index)) {
+ CHECK_EQ(index.size(), 1);
+ *buffer_element = index[0];
+ }
+ });
shaped_buffer->mutable_buffers()->reserve(buffers.size());
for (const perftools::gputools::DeviceMemoryBase& memory_base : buffers) {
shaped_buffer->mutable_buffers()->push_back(memory_base);
@@ -126,10 +123,12 @@ ScopedShapedBuffer::MakeScopedShapedBuffer(const Shape& shape,
// Allocate an appropriate sized buffer for each array element in the shape.
TF_RETURN_IF_ERROR(
- shaped_buffer->shape_index_to_buffer_entry_.ForEachMutableElement(
- [&shaped_buffer](const ShapeIndex& index, bool is_leaf,
- size_t* buffer_entry) -> tensorflow::Status {
- if (is_leaf) {
+ shaped_buffer->shape_index_to_buffer_entry_
+ .ForEachMutableElementWithStatus([&shaped_buffer](
+ const ShapeIndex& index,
+ size_t* buffer_entry)
+ -> tensorflow::Status {
+ if (ShapeUtil::IsLeafIndex(shaped_buffer->shape(), index)) {
TF_ASSIGN_OR_RETURN(
perftools::gputools::DeviceMemoryBase memory_base,
shaped_buffer->allocator_->Allocate(