aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_cost_analysis.h
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-11-02 18:32:09 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-11-02 18:36:33 -0700
commit274e9ed51ea6cc09a0b5fc1cee4756ac0e9aa525 (patch)
tree35b43ee92bfc1689c3deeec03fa13c61ab5c8b1f /tensorflow/compiler/xla/service/hlo_cost_analysis.h
parentfbc5460b0a5c2daa477c68477b9330424054ba25 (diff)
[TF:XLA] Add a const HLO visitor.
Use it in the HLO cost analysis pass. PiperOrigin-RevId: 174411043
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_cost_analysis.h')
-rw-r--r--tensorflow/compiler/xla/service/hlo_cost_analysis.h98
1 files changed, 50 insertions, 48 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h
index 93b1b3eb20..8074868e37 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h
@@ -34,7 +34,7 @@ namespace xla {
// the computation cost of the instruction, and the values are accumulated
// during the traversal for the entire graph. We treat normal floating point
// operations separately from transcendental operations.
-class HloCostAnalysis : public DfsHloVisitor {
+class HloCostAnalysis : public ConstDfsHloVisitor {
public:
// Each HLO is associated to a vector of properties with the indices given
// below. Sub-classes can add further properties.
@@ -49,54 +49,56 @@ class HloCostAnalysis : public DfsHloVisitor {
using ShapeSizeFunction = std::function<int64(const Shape&)>;
explicit HloCostAnalysis(const ShapeSizeFunction& shape_size);
- Status HandleElementwiseUnary(HloInstruction* hlo) override;
- Status HandleElementwiseBinary(HloInstruction* hlo) override;
- Status HandleConstant(HloInstruction* constant) override;
- Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
- Status HandleSelect(HloInstruction* select) override;
- Status HandleCompare(HloInstruction* compare) override;
- Status HandleClamp(HloInstruction* clamp) override;
- Status HandleReducePrecision(HloInstruction* hlo) override;
- Status HandleConcatenate(HloInstruction* concatenate) override;
- Status HandleSend(HloInstruction* send) override;
- Status HandleRecv(HloInstruction* recv) override;
- Status HandleConvert(HloInstruction* convert) override;
- Status HandleCopy(HloInstruction* copy) override;
- Status HandleDot(HloInstruction* dot) override;
- Status HandleConvolution(HloInstruction* convolution) override;
- Status HandleCrossReplicaSum(HloInstruction* crs) override;
- Status HandleInfeed(HloInstruction* infeed) override;
- Status HandleOutfeed(HloInstruction* outfeed) override;
- Status HandleRng(HloInstruction* random) override;
- Status HandleReverse(HloInstruction* reverse) override;
- Status HandleSort(HloInstruction* sort) override;
- Status HandleParameter(HloInstruction* parameter) override;
- Status HandleReduce(HloInstruction* reduce) override;
- Status HandleBatchNormTraining(HloInstruction* batch_norm_training) override;
+ Status HandleElementwiseUnary(const HloInstruction* hlo) override;
+ Status HandleElementwiseBinary(const HloInstruction* hlo) override;
+ Status HandleConstant(const HloInstruction* constant) override;
+ Status HandleGetTupleElement(
+ const HloInstruction* get_tuple_element) override;
+ Status HandleSelect(const HloInstruction* select) override;
+ Status HandleCompare(const HloInstruction* compare) override;
+ Status HandleClamp(const HloInstruction* clamp) override;
+ Status HandleReducePrecision(const HloInstruction* hlo) override;
+ Status HandleConcatenate(const HloInstruction* concatenate) override;
+ Status HandleSend(const HloInstruction* send) override;
+ Status HandleRecv(const HloInstruction* recv) override;
+ Status HandleConvert(const HloInstruction* convert) override;
+ Status HandleCopy(const HloInstruction* copy) override;
+ Status HandleDot(const HloInstruction* dot) override;
+ Status HandleConvolution(const HloInstruction* convolution) override;
+ Status HandleCrossReplicaSum(const HloInstruction* crs) override;
+ Status HandleInfeed(const HloInstruction* infeed) override;
+ Status HandleOutfeed(const HloInstruction* outfeed) override;
+ Status HandleRng(const HloInstruction* random) override;
+ Status HandleReverse(const HloInstruction* reverse) override;
+ Status HandleSort(const HloInstruction* sort) override;
+ Status HandleParameter(const HloInstruction* parameter) override;
+ Status HandleReduce(const HloInstruction* reduce) override;
+ Status HandleBatchNormTraining(
+ const HloInstruction* batch_norm_training) override;
Status HandleBatchNormInference(
- HloInstruction* batch_norm_inference) override;
- Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override;
- Status HandleFusion(HloInstruction* fusion) override;
- Status HandleCall(HloInstruction* call) override;
- Status HandleCustomCall(HloInstruction* custom_call) override;
- Status HandleSlice(HloInstruction* slice) override;
- Status HandleDynamicSlice(HloInstruction* dynamic_slice) override;
+ const HloInstruction* batch_norm_inference) override;
+ Status HandleBatchNormGrad(const HloInstruction* batch_norm_grad) override;
+ Status HandleFusion(const HloInstruction* fusion) override;
+ Status HandleCall(const HloInstruction* call) override;
+ Status HandleCustomCall(const HloInstruction* custom_call) override;
+ Status HandleSlice(const HloInstruction* slice) override;
+ Status HandleDynamicSlice(const HloInstruction* dynamic_slice) override;
Status HandleDynamicUpdateSlice(
- HloInstruction* dynamic_update_slice) override;
- Status HandleTuple(HloInstruction* tuple) override;
- Status HandleMap(HloInstruction* map) override;
- Status HandleReduceWindow(HloInstruction* reduce_window) override;
- Status HandleSelectAndScatter(HloInstruction* instruction) override;
- Status HandleBitcast(HloInstruction* bitcast) override;
- Status HandleBroadcast(HloInstruction* broadcast) override;
- Status HandlePad(HloInstruction* pad) override;
- Status HandleReshape(HloInstruction* reshape) override;
- Status HandleTranspose(HloInstruction* transpose) override;
- Status HandleWhile(HloInstruction* xla_while) override;
- Status FinishVisit(HloInstruction* root) override;
-
- Status Preprocess(HloInstruction* hlo) override;
- Status Postprocess(HloInstruction* hlo) override;
+ const HloInstruction* dynamic_update_slice) override;
+ Status HandleTuple(const HloInstruction* tuple) override;
+ Status HandleMap(const HloInstruction* map) override;
+ Status HandleReduceWindow(const HloInstruction* reduce_window) override;
+ Status HandleSelectAndScatter(const HloInstruction* instruction) override;
+ Status HandleBitcast(const HloInstruction* bitcast) override;
+ Status HandleBroadcast(const HloInstruction* broadcast) override;
+ Status HandlePad(const HloInstruction* pad) override;
+ Status HandleReshape(const HloInstruction* reshape) override;
+ Status HandleTranspose(const HloInstruction* transpose) override;
+ Status HandleWhile(const HloInstruction* xla_while) override;
+ Status FinishVisit(const HloInstruction* root) override;
+
+ Status Preprocess(const HloInstruction* hlo) override;
+ Status Postprocess(const HloInstruction* hlo) override;
// Set the rates used to calculate the time taken by the computation. These
// need to be set before visiting starts.
@@ -145,7 +147,7 @@ class HloCostAnalysis : public DfsHloVisitor {
const ShapeSizeFunction* shape_size = nullptr);
// Utility function to handle all element-wise operations.
- Status HandleElementwiseOp(HloInstruction* hlo_instruction);
+ Status HandleElementwiseOp(const HloInstruction* hlo_instruction);
// Returns the default value if the key is not present in the
// properties. Otherwise, returns the value that the key maps to from the