aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
author: A. Unique TensorFlower <gardener@tensorflow.org> 2018-06-12 22:39:11 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-06-12 22:41:36 -0700
commit: 5f3281dd4a0d72cb51064599118088167878e0ef (patch)
tree: 87b01224ffb88e3dfcb290d256870ce85c71b186
parent: ebcc765d70257061bcbf1f50377e54cc9c91d388 (diff)
Split out HloGetTupleElementInstruction and HloReducePrecisionInstruction as subclasses from HloInstruction.
PiperOrigin-RevId: 200337508
-rw-r--r-- tensorflow/compiler/xla/service/hlo_instruction.cc 76
-rw-r--r-- tensorflow/compiler/xla/service/hlo_instruction.h 37
-rw-r--r-- tensorflow/compiler/xla/service/hlo_instructions.cc 81
-rw-r--r-- tensorflow/compiler/xla/service/hlo_instructions.h 55
4 files changed, 174 insertions, 75 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index aafb3b9dfd..39662d1735 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -233,6 +233,16 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
instruction = CreateParameter(proto.parameter_number(), proto.shape(),
proto.name());
break;
+ case HloOpcode::kGetTupleElement:
+ CHECK_EQ(proto.operand_ids_size(), 1);
+ instruction = CreateGetTupleElement(proto.shape(), operands(0),
+ proto.tuple_index());
+ break;
+ case HloOpcode::kReducePrecision:
+ instruction =
+ CreateReducePrecision(proto.shape(), operands(0),
+ proto.exponent_bits(), proto.mantissa_bits());
+ break;
default: {
instruction = WrapUnique(new HloInstruction(opcode, proto.shape()));
for (const int64 operand_id : proto.operand_ids()) {
@@ -260,11 +270,9 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
TF_RET_CHECK(!proto.name().empty());
instruction->SetAndSanitizeName(proto.name());
-
instruction->metadata_ = proto.metadata();
instruction->backend_config_ = proto.backend_config();
- instruction->tuple_index_ = proto.tuple_index();
if (proto.has_window()) {
instruction->window_ = MakeUnique<Window>(proto.window());
}
@@ -278,8 +286,6 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
MakeUnique<DotDimensionNumbers>(proto.dot_dimension_numbers());
}
- instruction->exponent_bits_ = proto.exponent_bits();
- instruction->mantissa_bits_ = proto.mantissa_bits();
for (int64 dynamic_slice_size : proto.dynamic_slice_sizes()) {
instruction->dynamic_slice_sizes_.push_back(dynamic_slice_size);
}
@@ -334,12 +340,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateGetTupleElement(const Shape& shape,
HloInstruction* operand, int64 index) {
- CHECK(ShapeUtil::IsTuple(operand->shape()));
- auto instruction =
- WrapUnique(new HloInstruction(HloOpcode::kGetTupleElement, shape));
- instruction->tuple_index_ = index;
- instruction->AppendOperand(operand);
- return instruction;
+ // The tuple-shape CHECK on `operand`, the tuple index assignment and the
+ // AppendOperand call removed above are now performed by the
+ // HloGetTupleElementInstruction constructor.
+ return MakeUnique<HloGetTupleElementInstruction>(shape, operand, index);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRng(
@@ -520,12 +521,8 @@ HloInstruction::CreateReducePrecision(const Shape& shape,
HloInstruction* operand,
const int exponent_bits,
const int mantissa_bits) {
- auto instruction =
- WrapUnique(new HloInstruction(HloOpcode::kReducePrecision, shape));
- instruction->AppendOperand(operand);
- instruction->exponent_bits_ = exponent_bits;
- instruction->mantissa_bits_ = mantissa_bits;
- return instruction;
+ return MakeUnique<HloReducePrecisionInstruction>(
+ shape, operand, exponent_bits, mantissa_bits);
}
/* static */ std::unique_ptr<HloInstruction>
@@ -1041,6 +1038,8 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
case HloOpcode::kFusion:
case HloOpcode::kRng:
case HloOpcode::kParameter:
+ case HloOpcode::kGetTupleElement:
+ case HloOpcode::kReducePrecision:
clone = CloneWithNewOperandsImpl(shape, new_operands, context);
break;
// Unary ops.
@@ -1127,11 +1126,6 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
CHECK_EQ(new_operands.size(), 1);
clone = CreateBitcastConvert(shape, new_operands[0]);
break;
- case HloOpcode::kReducePrecision:
- CHECK_EQ(new_operands.size(), 1);
- clone = CreateReducePrecision(shape, new_operands[0], exponent_bits_,
- mantissa_bits_);
- break;
case HloOpcode::kConvolution:
CHECK_EQ(new_operands.size(), 2);
clone = CreateConvolve(shape, new_operands[0], new_operands[1], *window_,
@@ -1147,10 +1141,6 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
CreateCrossReplicaSum(shape, new_operands, to_apply(),
replica_group_ids_, cross_replica_sum_barrier_);
break;
- case HloOpcode::kGetTupleElement:
- CHECK_EQ(new_operands.size(), 1);
- clone = CreateGetTupleElement(shape, new_operands[0], tuple_index());
- break;
case HloOpcode::kPad:
CHECK_EQ(new_operands.size(), 2);
clone =
@@ -1297,11 +1287,6 @@ const HloInstruction* HloInstruction::LatestNonGteAncestor() const {
return hlo;
}
-int64 HloInstruction::tuple_index() const {
- CHECK_EQ(HloOpcode::kGetTupleElement, opcode_);
- return tuple_index_;
-}
-
const HloInstruction* HloInstruction::operand(int64 i) const {
return operands_[i];
}
@@ -1464,11 +1449,6 @@ bool HloInstruction::IdenticalSlowPath(
case HloOpcode::kGenerateToken:
return false;
- // A reduce-precision operation is determined by the bit sizes.
- case HloOpcode::kReducePrecision:
- return exponent_bits() == other.exponent_bits() &&
- mantissa_bits() == other.mantissa_bits();
-
// Convolution has a window and dimensions.
case HloOpcode::kConvolution:
return protobuf_util::ProtobufEquals(window(), other.window()) &&
@@ -1497,8 +1477,6 @@ bool HloInstruction::IdenticalSlowPath(
protobuf_util::ProtobufEquals(window(), other.window());
// Remaining instructions with special values.
- case HloOpcode::kGetTupleElement:
- return tuple_index() == other.tuple_index();
case HloOpcode::kPad:
return protobuf_util::ProtobufEquals(padding_config(),
other.padding_config());
@@ -1555,6 +1533,8 @@ bool HloInstruction::IdenticalSlowPath(
case HloOpcode::kFusion:
case HloOpcode::kRng:
case HloOpcode::kParameter:
+ case HloOpcode::kGetTupleElement:
+ case HloOpcode::kReducePrecision:
LOG(FATAL) << "Base class impl called for opcode with subclass: "
<< opcode();
}
@@ -2044,9 +2024,6 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
}
}
- if (opcode() == HloOpcode::kGetTupleElement) {
- extra.push_back(StrCat("index=", tuple_index()));
- }
if (has_sharding()) {
extra.push_back(StrCat("sharding=", sharding().ToString()));
}
@@ -2066,10 +2043,6 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
extra.push_back(
StrCat("outfeed_config=\"", CEscape(outfeed_config_), "\""));
}
- if (opcode() == HloOpcode::kReducePrecision) {
- extra.push_back(StrCat("exponent_bits=", exponent_bits_));
- extra.push_back(StrCat("mantissa_bits=", mantissa_bits_));
- }
if (operand_side_metadata_ != nullptr && user_side_metadata_ != nullptr) {
extra.push_back(StrCat("domain={kind=\"", operand_side_metadata_->Kind(),
"\", entry=", operand_side_metadata_->ToString(),
@@ -2127,7 +2100,6 @@ HloInstructionProto HloInstruction::ToProto() const {
}
}
- proto.set_tuple_index(tuple_index_);
if (window_ != nullptr) {
*proto.mutable_window() = *window_;
}
@@ -2147,8 +2119,6 @@ HloInstructionProto HloInstruction::ToProto() const {
}
}
- proto.set_exponent_bits(exponent_bits_);
- proto.set_mantissa_bits(mantissa_bits_);
for (int64 slice_size : dynamic_slice_sizes_) {
proto.add_dynamic_slice_sizes(slice_size);
}
@@ -3186,4 +3156,16 @@ RandomDistribution HloInstruction::random_distribution() const {
int64 HloInstruction::parameter_number() const {
return Cast<HloParameterInstruction>(this)->parameter_number();
}
+
+// Transitional forwarder on the base class: casts `this` to
+// HloGetTupleElementInstruction and returns that subclass's tuple index.
+int64 HloInstruction::tuple_index() const {
+  const auto* gte = Cast<HloGetTupleElementInstruction>(this);
+  return gte->tuple_index();
+}
+
+// Transitional forwarder on the base class: casts `this` to
+// HloReducePrecisionInstruction and returns that subclass's exponent bit
+// count.
+int32 HloInstruction::exponent_bits() const {
+  const auto* reduce_precision = Cast<HloReducePrecisionInstruction>(this);
+  return reduce_precision->exponent_bits();
+}
+
+// Transitional forwarder on the base class: casts `this` to
+// HloReducePrecisionInstruction and returns that subclass's mantissa bit
+// count.
+int32 HloInstruction::mantissa_bits() const {
+  const auto* reduce_precision = Cast<HloReducePrecisionInstruction>(this);
+  return reduce_precision->mantissa_bits();
+}
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 245c9e56f1..a206cdab27 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -876,11 +876,6 @@ class HloInstruction {
template <typename HloInstructionPtr>
Status Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor);
- // Returns the tuple index associated with this instruction.
- //
- // Precondition: opcode() == HloOpcode::kGetTupleElement
- int64 tuple_index() const;
-
// Returns the first non-GetTupleElement ancestor instruction of 'hlo'.
// If the first non-GTE ancestor is tuple-shaped, populates 'index' with the
// (possibly nested) tuple indices used on the path from ancestor to 'hlo'.
@@ -1078,22 +1073,6 @@ class HloInstruction {
return dynamic_slice_sizes_;
}
- // Returns the number of exponent bits for a reduce-precision node.
- //
- // Precondition: opcode() == HloOpcode::kReducePrecision
- int32 exponent_bits() const {
- CHECK_EQ(HloOpcode::kReducePrecision, opcode_);
- return exponent_bits_;
- }
-
- // Returns the number of mantissa bits for a reduce-precision node.
- //
- // Precondition: opcode() == HloOpcode::kReducePrecision
- int32 mantissa_bits() const {
- CHECK_EQ(HloOpcode::kReducePrecision, opcode_);
- return mantissa_bits_;
- }
-
// Returns data on the window in a windowed operation such as
// convolution.
const Window& window() const {
@@ -1439,6 +1418,15 @@ class HloInstruction {
// Delegates to HloParameterInstruction::parameter_number.
int64 parameter_number() const;
+
+ // Delegates to HloGetTupleElementInstruction::tuple_index.
+ int64 tuple_index() const;
+
+ // Returns the number of exponent bits for a reduce-precision node.
+ int32 exponent_bits() const;
+
+ // Returns the number of mantissa bits for a reduce-precision node.
+ int32 mantissa_bits() const;
// Old methods kept for smooth subclassing transition END.
// Returns the group ids of each replica for CrossReplicaSum op.
@@ -1573,9 +1561,6 @@ class HloInstruction {
// Result shape of this instruction.
Shape shape_;
- // Constant index, only present for kGetTupleElement.
- int64 tuple_index_ = -1;
-
// Describes the window in a windowed operation such as convolution.
std::unique_ptr<Window> window_;
@@ -1588,10 +1573,6 @@ class HloInstruction {
std::unique_ptr<GatherDimensionNumbers> gather_dimension_numbers_;
std::vector<int64> gather_window_bounds_;
- // The bit sizes for a reduce-precision operation.
- int32 exponent_bits_ = 0;
- int32 mantissa_bits_ = 0;
-
// Describes the [start, start + size) range size for a dynamic slice
// ('start' is specified dynamically in the second operand of the operation).
std::vector<int64> dynamic_slice_sizes_;
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index 22c8707e37..d326d5d009 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -1203,4 +1203,85 @@ HloParameterInstruction::CloneWithNewOperandsImpl(
HloCloneContext* context) const {
return MakeUnique<HloParameterInstruction>(parameter_number_, shape, name());
}
+
+// Constructs a kGetTupleElement instruction with result shape `shape`,
+// stores `index` as its tuple index, and appends `operand` as an operand.
+HloGetTupleElementInstruction::HloGetTupleElementInstruction(
+ const Shape& shape, HloInstruction* operand, int64 index)
+ : HloInstruction(HloOpcode::kGetTupleElement, shape), tuple_index_(index) {
+ // Fatal check: the operand must be tuple-shaped. `index` itself is not
+ // validated against the operand's shape here.
+ CHECK(ShapeUtil::IsTuple(operand->shape()));
+ AppendOperand(operand);
+}
+
+// Serializes this instruction: starts from the proto produced by the
+// base-class HloInstruction::ToProto() and then records the tuple index.
+HloInstructionProto HloGetTupleElementInstruction::ToProto() const {
+  HloInstructionProto serialized = HloInstruction::ToProto();
+  serialized.set_tuple_index(tuple_index_);
+  return serialized;
+}
+
+// Returns the subclass-specific printable attributes: a single
+// "index=<tuple index>" entry. `options` is not consulted.
+std::vector<string> HloGetTupleElementInstruction::ExtraAttributesToStringImpl(
+    const HloPrintOptions& options) const {
+  std::vector<string> attributes;
+  attributes.push_back(StrCat("index=", tuple_index()));
+  return attributes;
+}
+
+// Subclass-specific part of the identity comparison: two get-tuple-element
+// instructions match here when their tuple indices are equal.
+// `eq_computations` is not used.
+bool HloGetTupleElementInstruction::IdenticalSlowPath(
+    const HloInstruction& other,
+    const std::function<bool(const HloComputation*, const HloComputation*)>&
+        eq_computations) const {
+  // Unchecked downcast. NOTE(review): presumably the caller has already
+  // established that `other` has the same opcode -- confirm.
+  const auto* other_gte =
+      static_cast<const HloGetTupleElementInstruction*>(&other);
+  return other_gte->tuple_index() == tuple_index();
+}
+
+// Builds a new get-tuple-element instruction with result shape `shape` over
+// the single replacement operand, carrying over this instruction's tuple
+// index. CHECK-fails unless exactly one operand is supplied. `context` is
+// not used.
+std::unique_ptr<HloInstruction>
+HloGetTupleElementInstruction::CloneWithNewOperandsImpl(
+    const Shape& shape,
+    tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+    HloCloneContext* context) const {
+  CHECK_EQ(new_operands.size(), 1);
+  HloInstruction* const replacement_operand = new_operands[0];
+  return MakeUnique<HloGetTupleElementInstruction>(shape, replacement_operand,
+                                                   tuple_index());
+}
+
+// Constructs a kReducePrecision instruction with result shape `shape`,
+// stores the requested exponent and mantissa bit counts, and appends
+// `operand` as an operand. The bit counts are not validated here.
+HloReducePrecisionInstruction::HloReducePrecisionInstruction(
+ const Shape& shape, HloInstruction* operand, const int exponent_bits,
+ const int mantissa_bits)
+ : HloInstruction(HloOpcode::kReducePrecision, shape),
+ exponent_bits_(exponent_bits),
+ mantissa_bits_(mantissa_bits) {
+ AppendOperand(operand);
+}
+
+// Serializes this instruction: starts from the proto produced by the
+// base-class HloInstruction::ToProto() and then records both bit counts.
+HloInstructionProto HloReducePrecisionInstruction::ToProto() const {
+  HloInstructionProto serialized = HloInstruction::ToProto();
+  serialized.set_exponent_bits(exponent_bits_);
+  serialized.set_mantissa_bits(mantissa_bits_);
+  return serialized;
+}
+
+// Returns the subclass-specific printable attributes, in order:
+// "exponent_bits=<n>" followed by "mantissa_bits=<n>". `options` is not
+// consulted.
+std::vector<string> HloReducePrecisionInstruction::ExtraAttributesToStringImpl(
+    const HloPrintOptions& options) const {
+  std::vector<string> attributes;
+  attributes.push_back(StrCat("exponent_bits=", exponent_bits_));
+  attributes.push_back(StrCat("mantissa_bits=", mantissa_bits_));
+  return attributes;
+}
+
+// Subclass-specific part of the identity comparison: a reduce-precision
+// operation is determined by its bit sizes, so two of them match here when
+// both the exponent and the mantissa bit counts agree. `eq_computations` is
+// not used.
+bool HloReducePrecisionInstruction::IdenticalSlowPath(
+    const HloInstruction& other,
+    const std::function<bool(const HloComputation*, const HloComputation*)>&
+        eq_computations) const {
+  // Unchecked downcast. NOTE(review): presumably the caller has already
+  // established that `other` has the same opcode -- confirm.
+  const auto* other_rp =
+      static_cast<const HloReducePrecisionInstruction*>(&other);
+  if (exponent_bits() != other_rp->exponent_bits()) {
+    return false;
+  }
+  return mantissa_bits() == other_rp->mantissa_bits();
+}
+
+// Builds a new reduce-precision instruction with result shape `shape` over
+// the single replacement operand, carrying over this instruction's exponent
+// and mantissa bit counts. CHECK-fails unless exactly one operand is
+// supplied. `context` is not used.
+std::unique_ptr<HloInstruction>
+HloReducePrecisionInstruction::CloneWithNewOperandsImpl(
+    const Shape& shape,
+    tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+    HloCloneContext* context) const {
+  CHECK_EQ(new_operands.size(), 1);
+  HloInstruction* const replacement_operand = new_operands[0];
+  return MakeUnique<HloReducePrecisionInstruction>(
+      shape, replacement_operand, exponent_bits(), mantissa_bits());
+}
+
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h
index bab2a48166..6749d87555 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.h
+++ b/tensorflow/compiler/xla/service/hlo_instructions.h
@@ -667,6 +667,61 @@ class HloParameterInstruction : public HloInstruction {
int64 parameter_number_ = 0;
};
+// HloInstruction subclass that owns the tuple index, which this change moves
+// out of the HloInstruction base class.
+class HloGetTupleElementInstruction : public HloInstruction {
+ public:
+ // Creates the instruction with result shape `shape`, reading from `operand`
+ // at tuple index `index`.
+ explicit HloGetTupleElementInstruction(const Shape& shape,
+ HloInstruction* operand, int64 index);
+ // Returns the tuple index associated with this instruction.
+ int64 tuple_index() const { return tuple_index_; }
+ // Returns a serialized representation of this instruction.
+ HloInstructionProto ToProto() const override;
+
+ private:
+ // Returns the printable attributes specific to this subclass.
+ std::vector<string> ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const override;
+ // Subclass-specific part of the instruction identity comparison against
+ // `other`.
+ bool IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const override;
+ // Implementation for non-common logic of CloneWithNewOperands.
+ std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const override;
+
+ // Tuple index for this instruction; the declared constructor takes it as
+ // `index`. -1 is only the in-class default.
+ int64 tuple_index_ = -1;
+};
+
+// HloInstruction subclass that owns the exponent and mantissa bit counts,
+// which this change moves out of the HloInstruction base class.
+class HloReducePrecisionInstruction : public HloInstruction {
+ public:
+ // Creates the instruction with result shape `shape`, reading from `operand`
+ // with the given exponent and mantissa bit counts.
+ explicit HloReducePrecisionInstruction(const Shape& shape,
+ HloInstruction* operand,
+ const int exponent_bits,
+ const int mantissa_bits);
+ // Returns the number of exponent bits for a reduce-precision node.
+ int32 exponent_bits() const { return exponent_bits_; }
+ // Returns the number of mantissa bits for a reduce-precision node.
+ int32 mantissa_bits() const { return mantissa_bits_; }
+ // Returns a serialized representation of this instruction.
+ HloInstructionProto ToProto() const override;
+
+ private:
+ // Returns the printable attributes specific to this subclass.
+ std::vector<string> ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const override;
+ // Subclass-specific part of the instruction identity comparison against
+ // `other`.
+ bool IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const override;
+ // Implementation for non-common logic of CloneWithNewOperands.
+ std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const override;
+
+ // The bit sizes for a reduce-precision operation.
+ int32 exponent_bits_ = 0;
+ int32 mantissa_bits_ = 0;
+};
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_