aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/algebraic_simplifier.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/algebraic_simplifier.cc')
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier.cc9
1 files changed, 9 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index 37834e1cc2..f7812d9661 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -1705,6 +1705,10 @@ Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) {
reshape, HloInstruction::CreateReshape(reshape->shape(),
operand->mutable_operand(0)));
}
+ if (operand->opcode() == HloOpcode::kRng && operand->user_count() == 1) {
+ *operand->mutable_shape() = reshape->shape();
+ return ReplaceInstruction(reshape, operand);
+ }
if (HloOpcode::kBroadcast == reshape->operand(0)->opcode()) {
auto opt_dims = ReshapeLeavesDimensionsUnmodified(
@@ -2144,6 +2148,11 @@ Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) {
transpose->dimensions())));
}
+ if (operand->opcode() == HloOpcode::kRng && operand->user_count() == 1) {
+ *operand->mutable_shape() = transpose->shape();
+ return ReplaceInstruction(transpose, operand);
+ }
+
if (is_layout_sensitive_ && TransposeIsBitcast(transpose)) {
ReplaceWithBitcast(transpose);
return Status::OK();