aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler
diff options
context:
space:
mode:
authorGravatar Peter Hawkins <phawkins@google.com>2018-10-03 20:48:35 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-03 20:52:28 -0700
commit2e19f32d28ab88b5bd3dd4f6d42a54040591dfbb (patch)
tree4bc094affc575d865e4588c6216dbcf99c98bdb1 /tensorflow/compiler
parent9bd6f5ed55e533ccac055a5bc7fbb771e2d432c5 (diff)
[XLA] Fix handling of tuple constants in HLO constant folding.
PiperOrigin-RevId: 215676675
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r--tensorflow/compiler/xla/service/hlo_constant_folding.cc26
1 files changed, 15 insertions, 11 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.cc b/tensorflow/compiler/xla/service/hlo_constant_folding.cc
index 538816a353..4f898ce61c 100644
--- a/tensorflow/compiler/xla/service/hlo_constant_folding.cc
+++ b/tensorflow/compiler/xla/service/hlo_constant_folding.cc
@@ -77,19 +77,23 @@ StatusOr<bool> HloConstantFolding::Run(HloModule* module) {
}
// Don't constant fold unless it's a net positive or the output is small.
- int64 elements_in_removed_operands = 0;
- for (HloInstruction* operand : instruction->operands()) {
- if (operand->user_count() == 1) {
- elements_in_removed_operands +=
- ShapeUtil::ElementsIn(operand->shape());
+ if (ShapeUtil::IsArray(instruction->shape())) {
+ int64 elements_in_removed_operands = 0;
+ for (HloInstruction* operand : instruction->operands()) {
+ if (operand->user_count() == 1 &&
+ ShapeUtil::IsArray(operand->shape())) {
+ elements_in_removed_operands +=
+ ShapeUtil::ElementsIn(operand->shape());
+ }
}
- }
- int64 elements_in_constant = ShapeUtil::ElementsIn(instruction->shape());
+ int64 elements_in_constant =
+ ShapeUtil::ElementsIn(instruction->shape());
- static const int64 kMaximumConstantSizeElements = 2 * 1000 * 1000;
- if (elements_in_constant > elements_in_removed_operands &&
- elements_in_constant > kMaximumConstantSizeElements) {
- continue;
+ static const int64 kMaximumConstantSizeElements = 2 * 1000 * 1000;
+ if (elements_in_constant > elements_in_removed_operands &&
+ elements_in_constant > kMaximumConstantSizeElements) {
+ continue;
+ }
}
Literal result;