diff options
author | Mark Heffernan <meheff@google.com> | 2017-06-06 15:42:43 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-06-06 15:46:30 -0700 |
commit | 05412bd367198ec491ca034b4bc634784c03125c (patch) | |
tree | 6b1e76ec79446337d55055dcc3ca6503bd7b345a /tensorflow/compiler/xla/service/shaped_buffer.cc | |
parent | 69c9365b4b71b9ab9663ee4f2a0fb226ce2fd26d (diff) |
[XLA] Simplify Shape traversal visitors.
Simplify shape traversal visitors in ShapeUtil and ShapeTree. Add a non-Status form because most uses of the traversal methods do not make use of the Status return value, and remove the is_leaf parameter from ShapeTree.ForEach* as it is not frequently used.
PiperOrigin-RevId: 158201574
Diffstat (limited to 'tensorflow/compiler/xla/service/shaped_buffer.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/shaped_buffer.cc | 27 |
1 file changed, 13 insertions, 14 deletions
diff --git a/tensorflow/compiler/xla/service/shaped_buffer.cc b/tensorflow/compiler/xla/service/shaped_buffer.cc index cf49fd72b7..865be1b84f 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer.cc +++ b/tensorflow/compiler/xla/service/shaped_buffer.cc @@ -73,16 +73,13 @@ ShapedBuffer::MakeUnnestedTupleShapedBuffer( } TF_ASSIGN_OR_RETURN(std::unique_ptr<ShapedBuffer> shaped_buffer, MakeShapedBuffer(shape, platform, device_ordinal)); - TF_CHECK_OK(shaped_buffer->mutable_shape_index_to_buffer_entry() - ->ForEachMutableElement( - [](const ShapeIndex& index, bool is_leaf, - size_t* buffer_element) -> tensorflow::Status { - if (is_leaf) { - CHECK_EQ(index.size(), 1); - *buffer_element = index[0]; - } - return tensorflow::Status::OK(); - })); + shaped_buffer->mutable_shape_index_to_buffer_entry()->ForEachMutableElement( + [&shaped_buffer](const ShapeIndex& index, size_t* buffer_element) { + if (ShapeUtil::IsLeafIndex(shaped_buffer->shape(), index)) { + CHECK_EQ(index.size(), 1); + *buffer_element = index[0]; + } + }); shaped_buffer->mutable_buffers()->reserve(buffers.size()); for (const perftools::gputools::DeviceMemoryBase& memory_base : buffers) { shaped_buffer->mutable_buffers()->push_back(memory_base); @@ -126,10 +123,12 @@ ScopedShapedBuffer::MakeScopedShapedBuffer(const Shape& shape, // Allocate an appropriate sized buffer for each array element in the shape. TF_RETURN_IF_ERROR( - shaped_buffer->shape_index_to_buffer_entry_.ForEachMutableElement( - [&shaped_buffer](const ShapeIndex& index, bool is_leaf, - size_t* buffer_entry) -> tensorflow::Status { - if (is_leaf) { + shaped_buffer->shape_index_to_buffer_entry_ + .ForEachMutableElementWithStatus([&shaped_buffer]( + const ShapeIndex& index, + size_t* buffer_entry) + -> tensorflow::Status { + if (ShapeUtil::IsLeafIndex(shaped_buffer->shape(), index)) { TF_ASSIGN_OR_RETURN( perftools::gputools::DeviceMemoryBase memory_base, shaped_buffer->allocator_->Allocate( |