diff options
Diffstat (limited to 'tensorflow/core/grappler/costs/op_level_cost_estimator.h')
-rw-r--r-- | tensorflow/core/grappler/costs/op_level_cost_estimator.h | 6 |
1 files changed, 6 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.h b/tensorflow/core/grappler/costs/op_level_cost_estimator.h index 5e0a97653c..d234880919 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator.h +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.h @@ -80,6 +80,8 @@ class OpLevelCostEstimator { int64 CountMatMulOperations(const OpInfo& op_features, MatMulDimensions* mat_mul, bool* found_unknown_shapes) const; + int64 CountBatchMatMulOperations(const OpInfo& op_features, + bool* found_unknown_shapes) const; int64 CountConv2DBackPropInputOperations(const OpInfo& op_features, ConvolutionDimensions* conv_info, bool* found_unknown_shapes) const; @@ -116,6 +118,7 @@ class OpLevelCostEstimator { Costs PredictConv2DBackPropFilter(const OpInfo& op_features) const; Costs PredictMatMul(const OpInfo& op_features) const; Costs PredictNoOp(const OpInfo& op_features) const; + Costs PredictBatchMatMul(const OpInfo& op_features) const; // Utility function for safe division. Returns 0 // if rhs is 0 or negative. @@ -135,6 +138,9 @@ class OpLevelCostEstimator { protected: typedef std::function<Costs(const OpInfo& op_feature)> CostImpl; std::map<string, CostImpl> device_cost_impl_; + + private: + friend class OpLevelCostEstimatorTest; }; } // end namespace grappler |