aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/layout_assignment.cc
diff options
context:
space:
mode:
authorGravatar Justin Lebar <jlebar@google.com>2017-12-17 21:10:28 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-18 11:31:21 -0800
commite7bee822fe6bfa7ca861b09ba10769339692cf86 (patch)
tree3937acfd4df6b9f911f514a65661f42b60da1815 /tensorflow/compiler/xla/service/layout_assignment.cc
parentd2355fcee9f47cc2e8225f8ff54f7c12fa8045f0 (diff)
[XLA] Don't reimplement ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout inside of layout_assignment.
No functional change. PiperOrigin-RevId: 179376105
Diffstat (limited to 'tensorflow/compiler/xla/service/layout_assignment.cc')
-rw-r--r--tensorflow/compiler/xla/service/layout_assignment.cc19
1 files changed, 8 insertions, 11 deletions
diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc
index af726271ae..b598c765fc 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment.cc
@@ -477,16 +477,10 @@ Status LayoutAssignment::AddMandatoryConstraints(
/*mandatory=*/true));
} else if (instruction->opcode() == HloOpcode::kCustomCall) {
// Add constraints for kCustomCall instruction operands and instructions.
- // For now we only support row major layouts for all inputs and outputs.
- auto row_major_shape = [](const Shape& old_shape) {
- Shape new_shape(old_shape);
- std::vector<int64> dimension_order(new_shape.dimensions_size());
- std::iota(dimension_order.rbegin(), dimension_order.rend(), 0);
- *new_shape.mutable_layout() = LayoutUtil::MakeLayout(dimension_order);
- return new_shape;
- };
-
- Shape result_shape(row_major_shape(instruction->shape()));
+ // For now we only support major-first layouts for all inputs and outputs.
+ Shape result_shape = ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout(
+ instruction->shape().element_type(),
+ AsInt64Slice(instruction->shape().dimensions()));
TF_RETURN_IF_ERROR(
constraints->SetInstructionLayout(result_shape, instruction));
for (int64 i = 0; i < instruction->operand_count(); ++i) {
@@ -496,7 +490,10 @@ Status LayoutAssignment::AddMandatoryConstraints(
continue;
}
- Shape row_major_operand_shape(row_major_shape(operand_shape));
+ Shape row_major_operand_shape =
+ ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout(
+ operand_shape.element_type(),
+ AsInt64Slice(operand_shape.dimensions()));
TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
row_major_operand_shape, instruction, i, /*mandatory=*/true));
}