aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Blake Hechtman <blakehechtman@google.com>2018-05-17 15:08:33 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-17 15:11:21 -0700
commit8e2ff05d31118724eb21c48b98cd45c64884e13c (patch)
treee76906a06e6ca9b7462ac1d70627dee31fda6e78
parentfacd8f50733a398cc0ee08dfe76ad6b4f9e61817 (diff)
[XLA] Remove eq_shapes from Identical SlowPath since it is already checked in
Identical. PiperOrigin-RevId: 197058888
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.cc47
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.h14
2 files changed, 19 insertions, 42 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index d2fbc83ec0..66ff111b04 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -1689,24 +1689,27 @@ bool HloInstruction::HasConstantOperand() const {
bool HloInstruction::IdenticalSlowPath(
const HloInstruction& other,
const std::function<bool(const HloComputation*, const HloComputation*)>&
- eq_computations,
- const std::function<bool(const Shape&, const Shape&)>& eq_shapes) const {
+ eq_computations) const {
// Perform opcode specific checks.
switch (opcode()) {
// The result of these instructions only depend upon their opcode and
// operands.
case HloOpcode::kAbs:
case HloOpcode::kAtan2:
- case HloOpcode::kRoundNearestAfz:
case HloOpcode::kAdd:
+ case HloOpcode::kBitcast:
+ case HloOpcode::kBitcastConvert:
case HloOpcode::kCeil:
case HloOpcode::kClamp:
case HloOpcode::kClz:
case HloOpcode::kComplex:
+ case HloOpcode::kConvert:
case HloOpcode::kCopy:
case HloOpcode::kCos:
case HloOpcode::kCrossReplicaSum:
case HloOpcode::kDivide:
+ case HloOpcode::kDynamicSlice:
+ case HloOpcode::kDynamicUpdateSlice:
case HloOpcode::kEq:
case HloOpcode::kExp:
case HloOpcode::kExpm1:
@@ -1730,6 +1733,8 @@ bool HloInstruction::IdenticalSlowPath(
case HloOpcode::kPower:
case HloOpcode::kReal:
case HloOpcode::kRemainder:
+ case HloOpcode::kReshape:
+ case HloOpcode::kRoundNearestAfz:
case HloOpcode::kSelect:
case HloOpcode::kShiftLeft:
case HloOpcode::kShiftRightArithmetic:
@@ -1741,6 +1746,12 @@ bool HloInstruction::IdenticalSlowPath(
case HloOpcode::kTuple:
return true;
+ // Broadcast, Concatenate, and Transpose need the same dimensions field.
+ case HloOpcode::kBroadcast:
+ case HloOpcode::kConcatenate:
+ case HloOpcode::kTranspose:
+ return dimensions() == other.dimensions();
+
case HloOpcode::kFusion:
return fusion_kind() == other.fusion_kind() &&
eq_computations(fused_instructions_computation(),
@@ -1753,10 +1764,7 @@ bool HloInstruction::IdenticalSlowPath(
return false;
case HloOpcode::kParameter:
- return parameter_number() == other.parameter_number() &&
- // Check the shape too because `this` and `other` may be in
- // different HloComputations.
- eq_shapes(shape(), other.shape());
+ return parameter_number() == other.parameter_number();
case HloOpcode::kBatchNormTraining:
case HloOpcode::kBatchNormInference:
@@ -1768,12 +1776,6 @@ bool HloInstruction::IdenticalSlowPath(
case HloOpcode::kConstant:
return literal() == other.literal();
- // A convert result is determined by the primitive type that the operand is
- // converted into.
- case HloOpcode::kConvert:
- case HloOpcode::kBitcastConvert:
- return shape().element_type() == other.shape().element_type();
-
// A reduce-precision operation is determined by the bit sizes.
case HloOpcode::kReducePrecision:
return exponent_bits() == other.exponent_bits() &&
@@ -1816,22 +1818,8 @@ bool HloInstruction::IdenticalSlowPath(
eq_computations(scatter(), other.scatter()) &&
protobuf_util::ProtobufEquals(window(), other.window());
- case HloOpcode::kReshape:
- return eq_shapes(shape(), other.shape());
-
- // Transpose result is determined by the final shape and the permutation.
- case HloOpcode::kTranspose:
- return eq_shapes(shape(), other.shape()) &&
- dimensions() == other.dimensions();
// Remaining instructions with special values.
- case HloOpcode::kBitcast:
- return eq_shapes(shape(), other.shape());
- case HloOpcode::kBroadcast:
- return eq_shapes(shape(), other.shape()) &&
- dimensions() == other.dimensions();
- case HloOpcode::kConcatenate:
- return dimensions() == other.dimensions();
case HloOpcode::kGetTupleElement:
return tuple_index() == other.tuple_index();
case HloOpcode::kPad:
@@ -1841,11 +1829,6 @@ bool HloInstruction::IdenticalSlowPath(
return slice_starts_ == other.slice_starts_ &&
slice_limits_ == other.slice_limits_ &&
slice_strides_ == other.slice_strides_;
- case HloOpcode::kDynamicSlice:
- return eq_shapes(shape(), other.shape()) &&
- dynamic_slice_sizes_ == other.dynamic_slice_sizes_;
- case HloOpcode::kDynamicUpdateSlice:
- return eq_shapes(shape(), other.shape());
case HloOpcode::kCall:
case HloOpcode::kMap:
return eq_computations(to_apply(), other.to_apply());
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 0831a54a9f..db78539c63 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -746,10 +746,8 @@ class HloInstruction {
if (opcode() != other.opcode()) {
return false;
}
- using EqShapeFuncType = bool (*)(const Shape&, const Shape&);
- EqShapeFuncType eq_shapes =
- layout_sensitive ? ShapeUtil::Equal : ShapeUtil::Compatible;
- if (!eq_shapes(shape(), other.shape())) {
+ if (!(layout_sensitive ? ShapeUtil::Equal(shape(), other.shape())
+ : ShapeUtil::Compatible(shape(), other.shape()))) {
return false;
}
if (operands().size() != other.operands().size()) {
@@ -764,7 +762,7 @@ class HloInstruction {
}
}
- return IdenticalSlowPath(other, eq_computations, eq_shapes);
+ return IdenticalSlowPath(other, eq_computations);
}
// Returns whether the instruction has a constant operand.
@@ -1497,14 +1495,10 @@ class HloInstruction {
class FusionReusesParamElements;
// See comments on Identical().
- // eq_shapes() is used to check shapes for equality, and would normally be
- // expected to be ShapeUtil::Equals or ShapeUtil::Compatible, depending on
- // whether we want a layout-sensitive check or not.
bool IdenticalSlowPath(
const HloInstruction& other,
const std::function<bool(const HloComputation*, const HloComputation*)>&
- eq_computations,
- const std::function<bool(const Shape&, const Shape&)>& eq_shapes) const;
+ eq_computations) const;
// Creates an n-ary elementwise operation.
static std::unique_ptr<HloInstruction> CreateNary(