aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h
diff options
context:
space:
mode:
authorGravatar Mark Heffernan <meheff@google.com>2017-11-02 22:12:33 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-11-02 22:16:19 -0700
commit7bb2d57b0b051d1cf8dd74d3276bf5a452774172 (patch)
treed5b07beacebcc425454978eb87ffecfe728d4281 /tensorflow/compiler/xla/service/hlo_dataflow_analysis.h
parent8a7f5c47dcb71deb71df4a72f3cf829904c5a28e (diff)
Rewrite CopyInsertion to use module-scoped HloAliasAnalysis. The net effect (number of copies inserted) is roughly similar to the existing implementation, but the new implementation is much more general. The new implementation can handle entry argument buffer reuse with minimal modification, for example.
Some unnecessary copies are still added due to deficiencies in buffer assignment (b/62548313), but these can be removed when buffer assignment also uses HloAliasAnalysis. Also address a few issues uncovered with this cl: (1) For inplace dynamic slice in llvm backends, truncate do not wrap the slice. This matches the behavior of the non-inplace variant. (2) Disable SelectBetweenPredTuples test on GPU. The test introduces top-level buffer ambiguity which is not tolerated by the gpu backend. (3) When deserializing HLO form a proto, do not uniquify instruction names in fused computations. (4) In dataflow analysis, don't deallocate deleted HloValues during propagation. (5) In dataflow analysis, fix issue with live_out_of_computation property. PiperOrigin-RevId: 174423881
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_dataflow_analysis.h')
-rw-r--r--tensorflow/compiler/xla/service/hlo_dataflow_analysis.h22
1 files changed, 14 insertions, 8 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h
index 207e553bf7..49b1343873 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h
@@ -126,13 +126,16 @@ class HloDataflowAnalysis {
HloValue* NewHloValue(HloInstruction* instruction, const ShapeIndex& index,
bool is_phi = false);
- // Delete the HloValue with the given ID.
- void DeleteHloValue(HloValue::Id value_id);
+ // Mark the HloValue with the given ID for deletion.
+ void MarkValueForDeletion(HloValue::Id value_id);
+
+ // Delete all HloValues marked for deletion. Should be called after
+ // propagation is complete.
+ void DeleteMarkedValues();
// Constructs and initializes the InstructionValueSets of all instructions to
// contain exactly the HloValues defined by each instruction. These values can
- // then propagated throughout the HLO graph by calling
- // UpdateInstructionsAndPropagate.
+ // then propagated throughout the HLO graph by calling Propagate.
Status InitializeInstructionValueSets();
// Updates the value set of the given instruction based on the values flowing
@@ -150,10 +153,8 @@ class HloDataflowAnalysis {
bool UpdateTupleValueSet(HloInstruction* tuple);
bool UpdateWhileValueSet(HloInstruction* xla_while);
- // Update the value sets of the given instructions and propagate the
- // changes to fixed point.
- void UpdateInstructionsAndPropagate(
- tensorflow::gtl::ArraySlice<HloInstruction*> instructions);
+ // Propagate the dataflow through the module.
+ void Propagate();
// Return the result of the SSA Phi function applied to the given inputs at
// the given instruction. If skip_top_level is true, then the top level of the
@@ -189,6 +190,11 @@ class HloDataflowAnalysis {
// A map from instruction to InstructionValueSet.
std::unordered_map<const HloInstruction*, InstructionValueSet> value_sets_;
+ // Values marked for deletion during construction. We don't delete them
+ // immediately because references to them may still remain in ValueSets. After
+ // construction, these values are deleted.
+ std::vector<HloValue::Id> value_ids_to_delete_;
+
// A vector containing all HloValues sorted by HloValue::Id.
std::vector<const HloValue*> values_vector_;