aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/layout_assignment.h
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-04-20 15:19:59 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-20 15:22:40 -0700
commit60a0e2f5261cf72da4e4d8e65b56b695d611b984 (patch)
treefa80d75b322e70f969ba4ae8e9bdfc49da6550ea /tensorflow/compiler/xla/service/layout_assignment.h
parentb133f8c70622e52f19631fd93d4b87ee21c52ac6 (diff)
Do not force default layout when there is no need to.
Allow the inner computations to negotiate root and parameter layouts different from the default. PiperOrigin-RevId: 193731341
Diffstat (limited to 'tensorflow/compiler/xla/service/layout_assignment.h')
-rw-r--r--tensorflow/compiler/xla/service/layout_assignment.h65
1 files changed, 53 insertions, 12 deletions
diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h
index ae4986d6ad..8b4e07995a 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.h
+++ b/tensorflow/compiler/xla/service/layout_assignment.h
@@ -39,6 +39,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
@@ -362,12 +363,15 @@ class LayoutAssignment : public HloPassInterface {
int64 operand_no);
private:
+ // Initializes the layout assignment object for a new Run() call.
+ Status Init();
+
// Adds constraints which must be satisfied for correctness on all
// backends. Called once prior to propagating constraints.
- Status AddMandatoryConstraints(
- const ComputationLayout& computation_layout,
- const ChannelLayoutConstraints* channel_constraints,
- HloComputation* computation, LayoutConstraints* constraints);
+ Status AddMandatoryConstraints(const ComputationLayout* computation_layout,
+ ChannelLayoutConstraints* channel_constraints,
+ HloComputation* computation,
+ LayoutConstraints* constraints);
// This method can be overridden to add backend-specific constraints to the
// layout of the instructions of a computation. This method is called after
@@ -378,10 +382,12 @@ class LayoutAssignment : public HloPassInterface {
}
// Construct contraints and assign layouts to all instructions in the
- // computation satisfying the given ComputationLayout. Layouts constraints are
- // added, then propagated until all LogicalBuffers in the computation are
- // constrained.
- Status RunOnComputation(const ComputationLayout& computation_layout,
+ // computation satisfying the given ComputationLayout, if not nullptr.
+ // Otherwise the ComputationLayout will be calculated by propagating the
+  // computation instruction constraints.
+  // Layout constraints are added, then propagated until all LogicalBuffers in
+ // the computation are constrained.
+ Status RunOnComputation(ComputationLayout* computation_layout,
const TuplePointsToAnalysis& points_to_analysis,
HloComputation* computation,
ChannelLayoutConstraints* channel_constraints);
@@ -402,6 +408,25 @@ class LayoutAssignment : public HloPassInterface {
// necessary conditions.
Status CheckLayouts(HloModule* module);
+  // Computes the ComputationLayout of the given computation based on the
+ // layouts assigned to parameters and root instruction, and inserts it to the
+ // computation_layouts_ map.
+ Status CalculateComputationLayout(HloComputation* computation);
+
+ // Clears all the layouts which can be cleared within a computation.
+ Status ClearComputationLayouts(HloComputation* computation);
+
+ // Clears the side effects of a previous pass, like added copy instructions.
+ Status ClearPreviousPassSideEffects(HloModule* module);
+
+ // Propagates the layouts computed by the layout assignment pass on the given
+ // computation, to the computation layout passed in to this API.
+  // This API propagates missing layouts, and also checks that the layouts the
+  // caller specified have been respected, by comparing them with the
+  // parameters and root computation instruction.
+ Status PropagateComputationLayouts(HloComputation* computation,
+ ComputationLayout* computation_layout);
+
ComputationLayout* entry_computation_layout_;
protected:
@@ -418,21 +443,37 @@ class LayoutAssignment : public HloPassInterface {
// Creates and returns a copy of the given instruction with a different
// layout. Tuple-shaped instructions will be deep-copied, and the last Tuple
// instruction producing the copy is returned.
- static StatusOr<HloInstruction*> CreateCopyWithNewLayout(
+ StatusOr<HloInstruction*> CreateCopyWithNewLayout(
const Shape& shape_with_layout, HloInstruction* instruction);
// Creates a copy of the given operand if the operand's layout does not match
// the given layout. This copy replaces the use in the given instruction.
// Tuple operands will be deep-copied.
- static Status CopyOperandIfLayoutsDiffer(const ShapeLayout& operand_layout,
- HloInstruction* instruction,
- int64 operand_no);
+ Status CopyOperandIfLayoutsDiffer(const ShapeLayout& operand_layout,
+ HloInstruction* instruction,
+ int64 operand_no);
+
+ // Registers a copy instruction added by the layout assignment pass.
+ void RegisterAddedCopy(HloInstruction* copy) {
+ CHECK_EQ(copy->opcode(), HloOpcode::kCopy);
+ added_copies_.insert(copy);
+ }
+
+ // Adds a copy for the operand of an instruction, unless such operand is
+  // already a copy, and has a single user (which is necessarily the instruction
+ // itself).
+ Status AddCopyForOperand(HloInstruction* instruction, int64 operand_number);
// Map containing the layouts of all computations assigned so
// far. Computations are handled in a topological sort where computations are
// handled before their caller instructions so the layouts of caller
// instructions can be set to match the computation.
std::map<HloComputation*, ComputationLayout> computation_layouts_;
+
+ // Every copy added to the module by the layout assignment pass is registered
+ // here.
+ tensorflow::gtl::FlatSet<HloInstruction*> added_copies_;
+
ChannelLayoutConstraints* channel_layout_constraints_;
};