diff options
Diffstat (limited to 'tensorflow/compiler/xla/shape_layout.cc')
-rw-r--r-- | tensorflow/compiler/xla/shape_layout.cc | 8 |
1 files changed, 8 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/shape_layout.cc b/tensorflow/compiler/xla/shape_layout.cc index 7ee366b27a..caad31d6ce 100644 --- a/tensorflow/compiler/xla/shape_layout.cc +++ b/tensorflow/compiler/xla/shape_layout.cc @@ -67,6 +67,14 @@ void ShapeLayout::ResetLayout(const Layout& layout) { TF_CHECK_OK(ShapeUtil::ValidateShape(shape_)); } +void ShapeLayout::ResetLayout(const Layout& layout, + ShapeIndexView shape_index) { + CHECK(ShapeUtil::IsTuple(shape_)); + *ShapeUtil::GetMutableSubshape(&shape_, shape_index)->mutable_layout() = + layout; + TF_CHECK_OK(ShapeUtil::ValidateShape(shape_)); +} + bool ShapeLayout::operator==(const ShapeLayout& other) const { return ShapeUtil::Equal(shape_, other.shape_); } |