diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-11-02 18:32:09 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-11-02 18:36:33 -0700 |
commit | 274e9ed51ea6cc09a0b5fc1cee4756ac0e9aa525 (patch) | |
tree | 35b43ee92bfc1689c3deeec03fa13c61ab5c8b1f /tensorflow/compiler/xla/service/hlo_cost_analysis.h | |
parent | fbc5460b0a5c2daa477c68477b9330424054ba25 (diff) |
[TF:XLA] Add a const HLO visitor.
Use it in the HLO cost analysis pass.
PiperOrigin-RevId: 174411043
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_cost_analysis.h')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_cost_analysis.h | 98 |
1 files changed, 50 insertions, 48 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h index 93b1b3eb20..8074868e37 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h @@ -34,7 +34,7 @@ namespace xla { // the computation cost of the instruction, and the values are accumulated // during the traversal for the entire graph. We treat normal floating point // operations separately from transcendental operations. -class HloCostAnalysis : public DfsHloVisitor { +class HloCostAnalysis : public ConstDfsHloVisitor { public: // Each HLO is associated to a vector of properties with the indices given // below. Sub-classes can add further properties. @@ -49,54 +49,56 @@ class HloCostAnalysis : public DfsHloVisitor { using ShapeSizeFunction = std::function<int64(const Shape&)>; explicit HloCostAnalysis(const ShapeSizeFunction& shape_size); - Status HandleElementwiseUnary(HloInstruction* hlo) override; - Status HandleElementwiseBinary(HloInstruction* hlo) override; - Status HandleConstant(HloInstruction* constant) override; - Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; - Status HandleSelect(HloInstruction* select) override; - Status HandleCompare(HloInstruction* compare) override; - Status HandleClamp(HloInstruction* clamp) override; - Status HandleReducePrecision(HloInstruction* hlo) override; - Status HandleConcatenate(HloInstruction* concatenate) override; - Status HandleSend(HloInstruction* send) override; - Status HandleRecv(HloInstruction* recv) override; - Status HandleConvert(HloInstruction* convert) override; - Status HandleCopy(HloInstruction* copy) override; - Status HandleDot(HloInstruction* dot) override; - Status HandleConvolution(HloInstruction* convolution) override; - Status HandleCrossReplicaSum(HloInstruction* crs) override; - Status HandleInfeed(HloInstruction* infeed) override; - Status HandleOutfeed(HloInstruction* outfeed) override; - Status HandleRng(HloInstruction* random) override; - Status HandleReverse(HloInstruction* reverse) override; - Status HandleSort(HloInstruction* sort) override; - Status HandleParameter(HloInstruction* parameter) override; - Status HandleReduce(HloInstruction* reduce) override; - Status HandleBatchNormTraining(HloInstruction* batch_norm_training) override; + Status HandleElementwiseUnary(const HloInstruction* hlo) override; + Status HandleElementwiseBinary(const HloInstruction* hlo) override; + Status HandleConstant(const HloInstruction* constant) override; + Status HandleGetTupleElement( + const HloInstruction* get_tuple_element) override; + Status HandleSelect(const HloInstruction* select) override; + Status HandleCompare(const HloInstruction* compare) override; + Status HandleClamp(const HloInstruction* clamp) override; + Status HandleReducePrecision(const HloInstruction* hlo) override; + Status HandleConcatenate(const HloInstruction* concatenate) override; + Status HandleSend(const HloInstruction* send) override; + Status HandleRecv(const HloInstruction* recv) override; + Status HandleConvert(const HloInstruction* convert) override; + Status HandleCopy(const HloInstruction* copy) override; + Status HandleDot(const HloInstruction* dot) override; + Status HandleConvolution(const HloInstruction* convolution) override; + Status HandleCrossReplicaSum(const HloInstruction* crs) override; + Status HandleInfeed(const HloInstruction* infeed) override; + Status HandleOutfeed(const HloInstruction* outfeed) override; + Status HandleRng(const HloInstruction* random) override; + Status HandleReverse(const HloInstruction* reverse) override; + Status HandleSort(const HloInstruction* sort) override; + Status HandleParameter(const HloInstruction* parameter) override; + Status HandleReduce(const HloInstruction* reduce) override; + Status HandleBatchNormTraining( + const HloInstruction* batch_norm_training) override; Status HandleBatchNormInference( - HloInstruction* batch_norm_inference) override; - Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override; - Status HandleFusion(HloInstruction* fusion) override; - Status HandleCall(HloInstruction* call) override; - Status HandleCustomCall(HloInstruction* custom_call) override; - Status HandleSlice(HloInstruction* slice) override; - Status HandleDynamicSlice(HloInstruction* dynamic_slice) override; + const HloInstruction* batch_norm_inference) override; + Status HandleBatchNormGrad(const HloInstruction* batch_norm_grad) override; + Status HandleFusion(const HloInstruction* fusion) override; + Status HandleCall(const HloInstruction* call) override; + Status HandleCustomCall(const HloInstruction* custom_call) override; + Status HandleSlice(const HloInstruction* slice) override; + Status HandleDynamicSlice(const HloInstruction* dynamic_slice) override; Status HandleDynamicUpdateSlice( - HloInstruction* dynamic_update_slice) override; - Status HandleTuple(HloInstruction* tuple) override; - Status HandleMap(HloInstruction* map) override; - Status HandleReduceWindow(HloInstruction* reduce_window) override; - Status HandleSelectAndScatter(HloInstruction* instruction) override; - Status HandleBitcast(HloInstruction* bitcast) override; - Status HandleBroadcast(HloInstruction* broadcast) override; - Status HandlePad(HloInstruction* pad) override; - Status HandleReshape(HloInstruction* reshape) override; - Status HandleTranspose(HloInstruction* transpose) override; - Status HandleWhile(HloInstruction* xla_while) override; - Status FinishVisit(HloInstruction* root) override; - - Status Preprocess(HloInstruction* hlo) override; - Status Postprocess(HloInstruction* hlo) override; + const HloInstruction* dynamic_update_slice) override; + Status HandleTuple(const HloInstruction* tuple) override; + Status HandleMap(const HloInstruction* map) override; + Status HandleReduceWindow(const HloInstruction* reduce_window) override; + Status HandleSelectAndScatter(const HloInstruction* instruction) override; + Status HandleBitcast(const HloInstruction* bitcast) override; + Status HandleBroadcast(const HloInstruction* broadcast) override; + Status HandlePad(const HloInstruction* pad) override; + Status HandleReshape(const HloInstruction* reshape) override; + Status HandleTranspose(const HloInstruction* transpose) override; + Status HandleWhile(const HloInstruction* xla_while) override; + Status FinishVisit(const HloInstruction* root) override; + + Status Preprocess(const HloInstruction* hlo) override; + Status Postprocess(const HloInstruction* hlo) override; // Set the rates used to calculate the time taken by the computation. These // need to be set before visiting starts. @@ -145,7 +147,7 @@ class HloCostAnalysis : public DfsHloVisitor { const ShapeSizeFunction* shape_size = nullptr); // Utility function to handle all element-wise operations. - Status HandleElementwiseOp(HloInstruction* hlo_instruction); + Status HandleElementwiseOp(const HloInstruction* hlo_instruction); // Returns the default value if the key is not present in the // properties. Otherwise, returns the value that the key maps to from the |