aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/costs/op_level_cost_estimator.h
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-03-28 13:11:12 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-28 13:13:48 -0700
commit480ac84aa8390e19a54bd2feef3a6069d959bb4e (patch)
tree615d3b34bae3554d50c3f6c64c25d9001e342a59 /tensorflow/core/grappler/costs/op_level_cost_estimator.h
parent560ef036727c871bab57faa9942ccaff977ef88a (diff)
Add op cost model for MaxPool, AvgPool, FusedBatchNorm, their grad ops, and
ReluGrad. PiperOrigin-RevId: 190821116
Diffstat (limited to 'tensorflow/core/grappler/costs/op_level_cost_estimator.h')
-rw-r--r--tensorflow/core/grappler/costs/op_level_cost_estimator.h14
1 files changed, 13 insertions, 1 deletions
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.h b/tensorflow/core/grappler/costs/op_level_cost_estimator.h
index 1b3babb206..fcbecbb6dc 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.h
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.h
@@ -145,6 +145,12 @@ class OpLevelCostEstimator {
Costs PredictBatchMatMul(const OpContext& op_context) const;
Costs PredictMetadata(const OpContext& op_context) const;
Costs PredictGatherOrSlice(const OpContext& op_context) const;
+ Costs PredictMaxPool(const OpContext& op_context) const;
+ Costs PredictMaxPoolGrad(const OpContext& op_context) const;
+ Costs PredictAvgPool(const OpContext& op_context) const;
+ Costs PredictAvgPoolGrad(const OpContext& op_context) const;
+ Costs PredictFusedBatchNorm(const OpContext& op_context) const;
+ Costs PredictFusedBatchNormGrad(const OpContext& op_context) const;
// Utility function for safe division. Returns 0
// if rhs is 0 or negative.
@@ -156,9 +162,15 @@ class OpLevelCostEstimator {
}
}
+ // For convolution and its grad ops.
static ConvolutionDimensions ConvolutionDimensionsFromInputs(
const TensorShapeProto& original_image_shape,
- const TensorShapeProto& original_filter_shape, const OpInfo& op_features,
+ const TensorShapeProto& original_filter_shape, const OpInfo& op_info,
+ bool* found_unknown_shapes);
+
+ // For Pooling, FusedBatchNorm, and their grad ops.
+ static ConvolutionDimensions OpDimensionsFromInputs(
+ const TensorShapeProto& original_image_shape, const OpInfo& op_info,
bool* found_unknown_shapes);
protected: