diff options
author | 2017-11-09 14:48:37 -0800 | |
---|---|---|
committer | 2017-11-10 16:14:40 -0800 | |
commit | 51895becce83ef4dc8bac263377d158fc50e4d53 (patch) | |
tree | 81ba187df17b1c2a3b0784f776f68fc43060f544 /tensorflow/compiler/xla/service/tuple_points_to_analysis.cc | |
parent | 2bb46f6376c35ec86279e80a24dc06b068b41556 (diff) |
Change for asynchronous Send and Recv by splitting Send into {Send, SendDone}
and Recv into {Recv, RecvDone}. See operation_semantics.md for the updated
semantics.
PiperOrigin-RevId: 175216012
Diffstat (limited to 'tensorflow/compiler/xla/service/tuple_points_to_analysis.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/tuple_points_to_analysis.cc | 58 |
1 files changed, 58 insertions, 0 deletions
```diff
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc
index df537bd7c1..a1f9451dd4 100644
--- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc
+++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc
@@ -253,6 +253,64 @@ Status TuplePointsToAnalysis::HandleBitcast(HloInstruction* bitcast) {
   return Status::OK();
 }
 
+Status TuplePointsToAnalysis::HandleRecvDone(HloInstruction* recv_done) {
+  // RecvDone aliases its input (Recv) tuple element {0} to its output.
+  PointsToSet& points_to_set = CreateEmptyPointsToSet(recv_done);
+  const PointsToSet& operand_points_to_set =
+      GetPointsToSet(recv_done->operand(0));
+
+  // Recursively copy the points to set of the operand tuple {0}.
+  points_to_set.ForEachMutableElement(
+      [this, &points_to_set, &operand_points_to_set](
+          const ShapeIndex& index, PointsToSet::BufferList* buffers) {
+        ShapeIndex src_index({0});
+        for (auto element : index) {
+          src_index.push_back(element);
+        }
+        *buffers = operand_points_to_set.element(src_index);
+        for (auto& tuple_source :
+             operand_points_to_set.tuple_sources(src_index)) {
+          points_to_set.add_tuple_source(index, tuple_source);
+        }
+      });
+  return Status::OK();
+}
+
+Status TuplePointsToAnalysis::HandleSend(HloInstruction* send) {
+  // Send creates a tuple of {aliased operand, U32 context}.
+  PointsToSet& points_to_set = CreateEmptyPointsToSet(send);
+
+  // Creates the points to set for the tuple and its element at {1}.
+  auto top_buffer = points_to_set.mutable_element(ShapeIndex({}));
+  top_buffer->push_back(
+      &logical_buffer_analysis_->GetBuffer(send, ShapeIndex({})));
+  points_to_set.add_tuple_source({}, send);
+
+  auto context_buffer = points_to_set.mutable_element(ShapeIndex({1}));
+  context_buffer->push_back(
+      &logical_buffer_analysis_->GetBuffer(send, ShapeIndex({1})));
+
+  // Recursively copy the points to set of the operand to output tuple {0}.
+  const PointsToSet& operand_points_to_set = GetPointsToSet(send->operand(0));
+  operand_points_to_set.ForEachElement(
+      [&points_to_set, &operand_points_to_set](
+          const ShapeIndex& src_index,
+          const PointsToSet::BufferList& points_to) {
+        ShapeIndex target_index({0});
+        for (auto element : src_index) {
+          target_index.push_back(element);
+        }
+        *points_to_set.mutable_element(target_index) = points_to;
+
+        for (HloInstruction* tuple :
+             operand_points_to_set.tuple_sources(src_index)) {
+          points_to_set.add_tuple_source(target_index, tuple);
+        }
+      });
+
+  return Status::OK();
+}
+
 Status TuplePointsToAnalysis::HandleTuple(HloInstruction* tuple) {
   tensorflow::gtl::ArraySlice<HloInstruction*> operands(tuple->operands());
   PointsToSet& points_to_set = CreateEmptyPointsToSet(tuple);
```