aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/costs/op_level_cost_estimator.h
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-05-30 11:22:34 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-05-30 11:26:16 -0700
commita3ba225d5b327013709a1732688bfd4346b3c86e (patch)
tree0e7e34b7dd509d5a1665617560c7e7effc47a7a8 /tensorflow/core/grappler/costs/op_level_cost_estimator.h
parent34a29fc3b216b5dbcc0a36b76f54bd29e0f7d433 (diff)
Add BatchMatMul execution cost prediction
PiperOrigin-RevId: 157487507
Diffstat (limited to 'tensorflow/core/grappler/costs/op_level_cost_estimator.h')
-rw-r--r--tensorflow/core/grappler/costs/op_level_cost_estimator.h6
1 files changed, 6 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.h b/tensorflow/core/grappler/costs/op_level_cost_estimator.h
index 5e0a97653c..d234880919 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.h
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.h
@@ -80,6 +80,8 @@ class OpLevelCostEstimator {
int64 CountMatMulOperations(const OpInfo& op_features,
MatMulDimensions* mat_mul,
bool* found_unknown_shapes) const;
+ int64 CountBatchMatMulOperations(const OpInfo& op_features,
+ bool* found_unknown_shapes) const;
int64 CountConv2DBackPropInputOperations(const OpInfo& op_features,
ConvolutionDimensions* conv_info,
bool* found_unknown_shapes) const;
@@ -116,6 +118,7 @@ class OpLevelCostEstimator {
Costs PredictConv2DBackPropFilter(const OpInfo& op_features) const;
Costs PredictMatMul(const OpInfo& op_features) const;
Costs PredictNoOp(const OpInfo& op_features) const;
+ Costs PredictBatchMatMul(const OpInfo& op_features) const;
// Utility function for safe division. Returns 0
// if rhs is 0 or negative.
@@ -135,6 +138,9 @@ class OpLevelCostEstimator {
protected:
typedef std::function<Costs(const OpInfo& op_feature)> CostImpl;
std::map<string, CostImpl> device_cost_impl_;
+
+ private:
+ friend class OpLevelCostEstimatorTest;
};
} // end namespace grappler