about | summary | refs | log | tree | commit | diff | homepage
diff options
context:
space:
mode:
author: A. Unique TensorFlower <gardener@tensorflow.org> 2018-02-05 15:46:50 -0800
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-02-05 15:51:07 -0800
commit: 5476489053f0523b8aebab05bc39a02c089300e0 (patch)
tree: 93f411aaee012c7884f47d3310995057cec09ef1
parent: 2271f0f8c463a01af86c9e17be38e3cfc12eae11 (diff)
[XLA] Sink layout sensitivity from CSE into HloInstruction::Identical, and make it the default.
PiperOrigin-RevId: 184598903
-rw-r--r-- tensorflow/compiler/xla/service/hlo_cse.cc | 5
-rw-r--r-- tensorflow/compiler/xla/service/hlo_instruction.cc | 17
-rw-r--r-- tensorflow/compiler/xla/service/hlo_instruction.h | 27
3 files changed, 31 insertions, 18 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_cse.cc b/tensorflow/compiler/xla/service/hlo_cse.cc
index 7feda2b3b0..279edd4ba8 100644
--- a/tensorflow/compiler/xla/service/hlo_cse.cc
+++ b/tensorflow/compiler/xla/service/hlo_cse.cc
@@ -119,9 +119,8 @@ StatusOr<bool> HloCSE::Run(HloModule* module) {
equivalent_instructions;
for (HloInstruction* user : operand->users()) {
if (user != instruction &&
- user->Identical(*instruction, eq_instructions, eq_computations) &&
- (!is_layout_sensitive_ ||
- ShapeUtil::Equal(user->shape(), instruction->shape()))) {
+ user->Identical(*instruction, eq_instructions, eq_computations,
+ is_layout_sensitive_)) {
equivalent_instructions.push_back(user);
}
}
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index fac6b43405..277648f072 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -1612,7 +1612,8 @@ bool HloInstruction::HasConstantOperand() const {
bool HloInstruction::IdenticalSlowPath(
const HloInstruction& other,
const std::function<bool(const HloComputation*, const HloComputation*)>&
- eq_computations) const {
+ eq_computations,
+ const std::function<bool(const Shape&, const Shape&)>& eq_shapes) const {
// Perform opcode specific checks.
switch (opcode()) {
// The result of these instructions only depend upon their opcode and
@@ -1671,7 +1672,7 @@ bool HloInstruction::IdenticalSlowPath(
return parameter_number() == other.parameter_number() &&
// Check the shape too because `this` and `other` may be in
// different HloComputations.
- ShapeUtil::Compatible(shape(), other.shape());
+ eq_shapes(shape(), other.shape());
case HloOpcode::kBatchNormTraining:
case HloOpcode::kBatchNormInference:
@@ -1727,18 +1728,18 @@ bool HloInstruction::IdenticalSlowPath(
protobuf_util::ProtobufEquals(window(), other.window());
case HloOpcode::kReshape:
- return ShapeUtil::Compatible(shape(), other.shape());
+ return eq_shapes(shape(), other.shape());
// Transpose result is determined by the final shape and the permutation.
case HloOpcode::kTranspose:
- return ShapeUtil::Compatible(shape(), other.shape()) &&
+ return eq_shapes(shape(), other.shape()) &&
dimensions() == other.dimensions();
// Remaining instructions with special values.
case HloOpcode::kBitcast:
- return ShapeUtil::Equal(shape(), other.shape());
+ return eq_shapes(shape(), other.shape());
case HloOpcode::kBroadcast:
- return ShapeUtil::Compatible(shape(), other.shape()) &&
+ return eq_shapes(shape(), other.shape()) &&
dimensions() == other.dimensions();
case HloOpcode::kConcatenate:
return dimensions() == other.dimensions();
@@ -1752,10 +1753,10 @@ bool HloInstruction::IdenticalSlowPath(
slice_limits_ == other.slice_limits_ &&
slice_strides_ == other.slice_strides_;
case HloOpcode::kDynamicSlice:
- return ShapeUtil::Compatible(shape(), other.shape()) &&
+ return eq_shapes(shape(), other.shape()) &&
dynamic_slice_sizes_ == other.dynamic_slice_sizes_;
case HloOpcode::kDynamicUpdateSlice:
- return ShapeUtil::Compatible(shape(), other.shape());
+ return eq_shapes(shape(), other.shape());
case HloOpcode::kCall:
case HloOpcode::kMap:
return eq_computations(to_apply(), other.to_apply());
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index bce9ebdda8..50931c563a 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -554,27 +554,36 @@ class HloInstruction {
}
// Returns true if "other" performs the same computation as this instruction.
- // Layout of the instructions' output array is not considered.
bool Identical(
const HloInstruction& other,
const std::function<bool(const HloInstruction*, const HloInstruction*)>&
eq_operands = std::equal_to<const HloInstruction*>(),
const std::function<bool(const HloComputation*, const HloComputation*)>&
- eq_computations = std::equal_to<const HloComputation*>()) const {
+ eq_computations = std::equal_to<const HloComputation*>(),
+ bool layout_sensitive = true) const {
// An instruction is always identical to itself.
if (this == &other) {
return true;
}
- // Identical instruction must have the same opcode and identical operands.
- // In general, there is no need to check shape because shape is inferred
- // from the shape of the operands.
+ // Identical instructions must have the same opcode, the same shape, and
+ // identical operands.
if (opcode() != other.opcode()) {
return false;
}
+ auto eq_shapes = layout_sensitive
+ ? [](const Shape& a,
+ const Shape& b) { return ShapeUtil::Equal(a, b); }
+ : [](const Shape& a, const Shape& b) {
+ return ShapeUtil::Compatible(a, b);
+ };
+ if (!eq_shapes(shape(), other.shape())) {
+ return false;
+ }
if (operands().size() != other.operands().size()) {
return false;
}
+
// Use an explicit loop rather than ContainerEquals, because copying around
// std::functions may be too expensive in some cases.
for (size_t i = 0; i < operands().size(); ++i) {
@@ -583,7 +592,7 @@ class HloInstruction {
}
}
- return IdenticalSlowPath(other, eq_computations);
+ return IdenticalSlowPath(other, eq_computations, eq_shapes);
}
// Returns whether the instruction has a constant operand.
@@ -1232,10 +1241,14 @@ class HloInstruction {
class FusionReusesParamElements;
// See comments on Identical().
+ // eq_shapes() is used to check shapes for equality, and would normally be
+ // expected to be ShapeUtil::Equal or ShapeUtil::Compatible, depending on
+ // whether we want a layout-sensitive check or not.
bool IdenticalSlowPath(
const HloInstruction& other,
const std::function<bool(const HloComputation*, const HloComputation*)>&
- eq_computations) const;
+ eq_computations,
+ const std::function<bool(const Shape&, const Shape&)>& eq_shapes) const;
// Creates an n-ary elementwise operation.
static std::unique_ptr<HloInstruction> CreateNary(