diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-05-30 11:22:34 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-05-30 11:26:16 -0700 |
commit | a3ba225d5b327013709a1732688bfd4346b3c86e (patch) | |
tree | 0e7e34b7dd509d5a1665617560c7e7effc47a7a8 /tensorflow/core/grappler/costs/op_level_cost_estimator.h | |
parent | 34a29fc3b216b5dbcc0a36b76f54bd29e0f7d433 (diff) |
Add BatchMatMul execution cost prediction
PiperOrigin-RevId: 157487507
Diffstat (limited to 'tensorflow/core/grappler/costs/op_level_cost_estimator.h')
-rw-r--r-- | tensorflow/core/grappler/costs/op_level_cost_estimator.h | 6 |
1 files changed, 6 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.h b/tensorflow/core/grappler/costs/op_level_cost_estimator.h index 5e0a97653c..d234880919 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator.h +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.h @@ -80,6 +80,8 @@ class OpLevelCostEstimator { int64 CountMatMulOperations(const OpInfo& op_features, MatMulDimensions* mat_mul, bool* found_unknown_shapes) const; + int64 CountBatchMatMulOperations(const OpInfo& op_features, + bool* found_unknown_shapes) const; int64 CountConv2DBackPropInputOperations(const OpInfo& op_features, ConvolutionDimensions* conv_info, bool* found_unknown_shapes) const; @@ -116,6 +118,7 @@ class OpLevelCostEstimator { Costs PredictConv2DBackPropFilter(const OpInfo& op_features) const; Costs PredictMatMul(const OpInfo& op_features) const; Costs PredictNoOp(const OpInfo& op_features) const; + Costs PredictBatchMatMul(const OpInfo& op_features) const; // Utility function for safe division. Returns 0 // if rhs is 0 or negative. @@ -135,6 +138,9 @@ class OpLevelCostEstimator { protected: typedef std::function<Costs(const OpInfo& op_feature)> CostImpl; std::map<string, CostImpl> device_cost_impl_; + + private: + friend class OpLevelCostEstimatorTest; }; } // end namespace grappler |