aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/shape_layout.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/shape_layout.cc')
-rw-r--r--tensorflow/compiler/xla/shape_layout.cc8
1 files changed, 8 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/shape_layout.cc b/tensorflow/compiler/xla/shape_layout.cc
index 7ee366b27a..caad31d6ce 100644
--- a/tensorflow/compiler/xla/shape_layout.cc
+++ b/tensorflow/compiler/xla/shape_layout.cc
@@ -67,6 +67,14 @@ void ShapeLayout::ResetLayout(const Layout& layout) {
TF_CHECK_OK(ShapeUtil::ValidateShape(shape_));
}
+void ShapeLayout::ResetLayout(const Layout& layout,
+ ShapeIndexView shape_index) {
+ CHECK(ShapeUtil::IsTuple(shape_));
+ *ShapeUtil::GetMutableSubshape(&shape_, shape_index)->mutable_layout() =
+ layout;
+ TF_CHECK_OK(ShapeUtil::ValidateShape(shape_));
+}
+
bool ShapeLayout::operator==(const ShapeLayout& other) const {
return ShapeUtil::Equal(shape_, other.shape_);
}