diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-02-05 15:46:50 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-02-05 15:51:07 -0800 |
commit | 5476489053f0523b8aebab05bc39a02c089300e0 (patch) | |
tree | 93f411aaee012c7884f47d3310995057cec09ef1 | |
parent | 2271f0f8c463a01af86c9e17be38e3cfc12eae11 (diff) |
[XLA] Sink layout sensitivity from CSE into HloInstruction::Identical, and make it the default.
PiperOrigin-RevId: 184598903
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_cse.cc | 5 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instruction.cc | 17 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instruction.h | 27 |
3 files changed, 31 insertions, 18 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_cse.cc b/tensorflow/compiler/xla/service/hlo_cse.cc index 7feda2b3b0..279edd4ba8 100644 --- a/tensorflow/compiler/xla/service/hlo_cse.cc +++ b/tensorflow/compiler/xla/service/hlo_cse.cc @@ -119,9 +119,8 @@ StatusOr<bool> HloCSE::Run(HloModule* module) { equivalent_instructions; for (HloInstruction* user : operand->users()) { if (user != instruction && - user->Identical(*instruction, eq_instructions, eq_computations) && - (!is_layout_sensitive_ || - ShapeUtil::Equal(user->shape(), instruction->shape()))) { + user->Identical(*instruction, eq_instructions, eq_computations, + is_layout_sensitive_)) { equivalent_instructions.push_back(user); } } diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index fac6b43405..277648f072 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -1612,7 +1612,8 @@ bool HloInstruction::HasConstantOperand() const { bool HloInstruction::IdenticalSlowPath( const HloInstruction& other, const std::function<bool(const HloComputation*, const HloComputation*)>& - eq_computations) const { + eq_computations, + const std::function<bool(const Shape&, const Shape&)>& eq_shapes) const { // Perform opcode specific checks. switch (opcode()) { // The result of these instructions only depend upon their opcode and @@ -1671,7 +1672,7 @@ bool HloInstruction::IdenticalSlowPath( return parameter_number() == other.parameter_number() && // Check the shape too because `this` and `other` may be in // different HloComputations. - ShapeUtil::Compatible(shape(), other.shape()); + eq_shapes(shape(), other.shape()); case HloOpcode::kBatchNormTraining: case HloOpcode::kBatchNormInference: @@ -1727,18 +1728,18 @@ bool HloInstruction::IdenticalSlowPath( protobuf_util::ProtobufEquals(window(), other.window()); case HloOpcode::kReshape: - return ShapeUtil::Compatible(shape(), other.shape()); + return eq_shapes(shape(), other.shape()); // Transpose result is determined by the final shape and the permutation. case HloOpcode::kTranspose: - return ShapeUtil::Compatible(shape(), other.shape()) && + return eq_shapes(shape(), other.shape()) && dimensions() == other.dimensions(); // Remaining instructions with special values. case HloOpcode::kBitcast: - return ShapeUtil::Equal(shape(), other.shape()); + return eq_shapes(shape(), other.shape()); case HloOpcode::kBroadcast: - return ShapeUtil::Compatible(shape(), other.shape()) && + return eq_shapes(shape(), other.shape()) && dimensions() == other.dimensions(); case HloOpcode::kConcatenate: return dimensions() == other.dimensions(); @@ -1752,10 +1753,10 @@ bool HloInstruction::IdenticalSlowPath( slice_limits_ == other.slice_limits_ && slice_strides_ == other.slice_strides_; case HloOpcode::kDynamicSlice: - return ShapeUtil::Compatible(shape(), other.shape()) && + return eq_shapes(shape(), other.shape()) && dynamic_slice_sizes_ == other.dynamic_slice_sizes_; case HloOpcode::kDynamicUpdateSlice: - return ShapeUtil::Compatible(shape(), other.shape()); + return eq_shapes(shape(), other.shape()); case HloOpcode::kCall: case HloOpcode::kMap: return eq_computations(to_apply(), other.to_apply()); diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index bce9ebdda8..50931c563a 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -554,27 +554,36 @@ class HloInstruction { } // Returns true if "other" performs the same computation as this instruction. - // Layout of the instructions' output array is not considered. bool Identical( const HloInstruction& other, const std::function<bool(const HloInstruction*, const HloInstruction*)>& eq_operands = std::equal_to<const HloInstruction*>(), const std::function<bool(const HloComputation*, const HloComputation*)>& - eq_computations = std::equal_to<const HloComputation*>()) const { + eq_computations = std::equal_to<const HloComputation*>(), + bool layout_sensitive = true) const { // An instruction is always identical to itself. if (this == &other) { return true; } - // Identical instruction must have the same opcode and identical operands. - // In general, there is no need to check shape because shape is inferred - // from the shape of the operands. + // Identical instruction must have the same opcode, shape, and identical + // operands. if (opcode() != other.opcode()) { return false; } + auto eq_shapes = layout_sensitive + ? [](const Shape& a, + const Shape& b) { return ShapeUtil::Equal(a, b); } + : [](const Shape& a, const Shape& b) { + return ShapeUtil::Compatible(a, b); + }; + if (!eq_shapes(shape(), other.shape())) { + return false; + } if (operands().size() != other.operands().size()) { return false; } + // Use an explicit loop rather than ContainerEquals, because copying around // std::functions may be too expensive in some cases. for (size_t i = 0; i < operands().size(); ++i) { @@ -583,7 +592,7 @@ class HloInstruction { } } - return IdenticalSlowPath(other, eq_computations); + return IdenticalSlowPath(other, eq_computations, eq_shapes); } // Returns whether the instruction has a constant operand. @@ -1232,10 +1241,14 @@ class HloInstruction { class FusionReusesParamElements; // See comments on Identical(). + // eq_shapes() is used to check shapes for equality, and would normally be + // expected to be ShapeUtil::Equals or ShapeUtil::Compatible, depending on + // whether we want a layout-sensitive check or not. bool IdenticalSlowPath( const HloInstruction& other, const std::function<bool(const HloComputation*, const HloComputation*)>& - eq_computations) const; + eq_computations, + const std::function<bool(const Shape&, const Shape&)>& eq_shapes) const; // Creates an n-ary elementwise operation. static std::unique_ptr<HloInstruction> CreateNary( |