diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-08-22 12:56:31 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-22 13:01:33 -0700 |
commit | 2325b1e1979694de07439fae7b4585eb6ed4f99a (patch) | |
tree | 29bb40f997a80d1ee13d871b7e4817c3b46b8aaa /tensorflow/compiler/xla/service/hlo_instructions.cc | |
parent | c73964210ced86791c9231768315fa4652abc9ba (diff) |
[XLA] Cleanup Alltoall.
- Remove unused field 'cross_replica_sum_barrier' for Alltoall.
- Update cost analysis. There's no computation in Alltoall.
- Cleanup stale TODOs.
PiperOrigin-RevId: 209814190
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instructions.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instructions.cc | 19 |
1 files changed, 5 insertions, 14 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index 36fac4a266..345ca0053a 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -371,11 +371,9 @@ HloAllReduceInstruction::CloneWithNewOperandsImpl( HloAllToAllInstruction::HloAllToAllInstruction( const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands, - const std::vector<ReplicaGroup>& replica_groups, - tensorflow::StringPiece barrier) + const std::vector<ReplicaGroup>& replica_groups) : HloInstruction(HloOpcode::kAllToAll, shape), - replica_groups_(replica_groups), - cross_replica_sum_barrier_(barrier.begin(), barrier.end()) { + replica_groups_(replica_groups) { for (auto operand : operands) { AppendOperand(operand); } @@ -390,9 +388,7 @@ bool HloAllToAllInstruction::IdenticalSlowPath( [](const ReplicaGroup& a, const ReplicaGroup& b) { return ContainersEqual(a.replica_ids(), b.replica_ids()); - }) && - cross_replica_sum_barrier() == - casted_other.cross_replica_sum_barrier(); + }); } std::unique_ptr<HloInstruction> @@ -400,8 +396,8 @@ HloAllToAllInstruction::CloneWithNewOperandsImpl( const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, HloCloneContext* /*context*/) const { - return absl::make_unique<HloAllToAllInstruction>( - shape, new_operands, replica_groups(), cross_replica_sum_barrier()); + return absl::make_unique<HloAllToAllInstruction>(shape, new_operands, + replica_groups()); } std::vector<string> HloAllToAllInstruction::ExtraAttributesToStringImpl( @@ -415,10 +411,6 @@ std::vector<string> HloAllToAllInstruction::ExtraAttributesToStringImpl( result.push_back( StrCat("replica_groups={", Join(replica_group_str, ","), "}")); - if (!cross_replica_sum_barrier().empty()) { - result.push_back(StrCat("barrier=\"", cross_replica_sum_barrier(), "\"")); - } - return result; } @@ -426,7 +418,6 @@ HloInstructionProto HloAllToAllInstruction::ToProto() const { HloInstructionProto proto = HloInstruction::ToProto(); *proto.mutable_replica_groups() = {replica_groups_.begin(), replica_groups_.end()}; - proto.set_cross_replica_sum_barrier(cross_replica_sum_barrier_); return proto; } |