aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/shaped_buffer.cc
diff options
context:
space:
mode:
authorGravatar Mark Heffernan <meheff@google.com>2017-06-06 15:42:43 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-06-06 15:46:30 -0700
commit05412bd367198ec491ca034b4bc634784c03125c (patch)
tree6b1e76ec79446337d55055dcc3ca6503bd7b345a /tensorflow/compiler/xla/service/shaped_buffer.cc
parent69c9365b4b71b9ab9663ee4f2a0fb226ce2fd26d (diff)
[XLA] Simplify Shape traversal visitors.
Simplify shape traversal visitors in ShapeUtil and ShapeTree. Add a non-Status form because most uses of the traversal methods do not use it, and remove is_leaf parameter from ShapeTree.ForEach* as it is not frequently used. PiperOrigin-RevId: 158201574
Diffstat (limited to 'tensorflow/compiler/xla/service/shaped_buffer.cc')
-rw-r--r--tensorflow/compiler/xla/service/shaped_buffer.cc27
1 files changed, 13 insertions, 14 deletions
diff --git a/tensorflow/compiler/xla/service/shaped_buffer.cc b/tensorflow/compiler/xla/service/shaped_buffer.cc
index cf49fd72b7..865be1b84f 100644
--- a/tensorflow/compiler/xla/service/shaped_buffer.cc
+++ b/tensorflow/compiler/xla/service/shaped_buffer.cc
@@ -73,16 +73,13 @@ ShapedBuffer::MakeUnnestedTupleShapedBuffer(
}
TF_ASSIGN_OR_RETURN(std::unique_ptr<ShapedBuffer> shaped_buffer,
MakeShapedBuffer(shape, platform, device_ordinal));
- TF_CHECK_OK(shaped_buffer->mutable_shape_index_to_buffer_entry()
- ->ForEachMutableElement(
- [](const ShapeIndex& index, bool is_leaf,
- size_t* buffer_element) -> tensorflow::Status {
- if (is_leaf) {
- CHECK_EQ(index.size(), 1);
- *buffer_element = index[0];
- }
- return tensorflow::Status::OK();
- }));
+ shaped_buffer->mutable_shape_index_to_buffer_entry()->ForEachMutableElement(
+ [&shaped_buffer](const ShapeIndex& index, size_t* buffer_element) {
+ if (ShapeUtil::IsLeafIndex(shaped_buffer->shape(), index)) {
+ CHECK_EQ(index.size(), 1);
+ *buffer_element = index[0];
+ }
+ });
shaped_buffer->mutable_buffers()->reserve(buffers.size());
for (const perftools::gputools::DeviceMemoryBase& memory_base : buffers) {
shaped_buffer->mutable_buffers()->push_back(memory_base);
@@ -126,10 +123,12 @@ ScopedShapedBuffer::MakeScopedShapedBuffer(const Shape& shape,
// Allocate an appropriate sized buffer for each array element in the shape.
TF_RETURN_IF_ERROR(
- shaped_buffer->shape_index_to_buffer_entry_.ForEachMutableElement(
- [&shaped_buffer](const ShapeIndex& index, bool is_leaf,
- size_t* buffer_entry) -> tensorflow::Status {
- if (is_leaf) {
+ shaped_buffer->shape_index_to_buffer_entry_
+ .ForEachMutableElementWithStatus([&shaped_buffer](
+ const ShapeIndex& index,
+ size_t* buffer_entry)
+ -> tensorflow::Status {
+ if (ShapeUtil::IsLeafIndex(shaped_buffer->shape(), index)) {
TF_ASSIGN_OR_RETURN(
perftools::gputools::DeviceMemoryBase memory_base,
shaped_buffer->allocator_->Allocate(