aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_computation.h
diff options
context:
space:
mode:
authorGravatar Yuanzhong Xu <yuanzx@google.com>2018-07-09 16:07:37 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-09 16:11:41 -0700
commit3da378059e7595700a89e4d4cf29a48c5748ea18 (patch)
tree462e759f8affcaec3ee83fe317d6dd03ba930c9e /tensorflow/compiler/xla/service/hlo_computation.h
parentbc2674d09efbd87ae81ae41b81f1d152f37fac2a (diff)
[XLA] BFloat16 propagation: add de-aliasing copies before while loop inputs.
The while loop input and output alias each other, so as long as an input is also used by other ops that could not use BF16, the propagation pass could not change such an input/ouput to BF16 even if all uses in the while loop could use BF16. Add copies for each while loop operand. This increases the chance to propagate BF16 through the while loop; if some of these copies do not help, they will remain same-shape copies and be removed at the end. This can sometimes increase HBM usage because both BF16 and F32 copies are alive, and can sometimes reduce HBM usage. PiperOrigin-RevId: 203848348
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_computation.h')
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation.h15
1 files changed, 13 insertions, 2 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h
index c1c3e79ebc..abc1da4da3 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.h
+++ b/tensorflow/compiler/xla/service/hlo_computation.h
@@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COMPUTATION_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COMPUTATION_H_
+#include <functional>
#include <list>
#include <memory>
#include <string>
@@ -254,6 +255,14 @@ class HloComputation {
const ShapeTree<bool>* indices_to_copy = nullptr,
ShapeTree<HloInstruction*>* copies_added = nullptr);
+ // As above, but uses a custom function to copy the leaf nodes, which could
+ // create alternative HLOs other than kCopy, or even pass-throughs.
+ StatusOr<HloInstruction*> DeepCopyInstructionWithCustomCopier(
+ HloInstruction* instruction,
+ const std::function<
+ HloInstruction*(HloInstruction* leaf, const ShapeIndex& leaf_index,
+ HloComputation* computation)>& copy_leaf);
+
// Computes and returns the ProgramShape of this computation (shape of
// parameters and result with layout).
ProgramShape ComputeProgramShape() const;
@@ -378,8 +387,10 @@ class HloComputation {
// Internal helper for recursive copying of an instruction. Creates and
// returns a deep copy of the given instruction.
StatusOr<HloInstruction*> DeepCopyHelper(
- HloInstruction* instruction, const ShapeTree<bool>* indices_to_copy,
- ShapeTree<HloInstruction*>* copies_added, ShapeIndex* index);
+ HloInstruction* instruction, ShapeIndex* index,
+ const std::function<
+ HloInstruction*(HloInstruction* leaf, const ShapeIndex& leaf_index,
+ HloComputation* computation)>& copy_leaf);
// Internal helper to collect unreachable roots.
std::vector<HloInstruction*> CollectUnreachableRoots() const;