diff options
author | 2018-04-06 21:55:10 -0700 | |
---|---|---|
committer | 2018-04-06 21:57:46 -0700 | |
commit | 30e2b97897d05e47b457ab1d5d0d9c4227b87845 (patch) | |
tree | 19ccb01faec4fc451cfe45867d130277a9116fe7 /tensorflow/core/grappler/costs/op_level_cost_estimator.h | |
parent | 273495dc2c957402f832cae31a438e550db2b7f0 (diff) |
Add analytical cost model for FusedConv2DBiasActivation.
PiperOrigin-RevId: 191978272
Diffstat (limited to 'tensorflow/core/grappler/costs/op_level_cost_estimator.h')
-rw-r--r-- | tensorflow/core/grappler/costs/op_level_cost_estimator.h | 26 |
1 files changed, 26 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.h b/tensorflow/core/grappler/costs/op_level_cost_estimator.h index 7080264698..35649f7ee9 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator.h +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.h @@ -82,6 +82,13 @@ class OpLevelCostEstimator { int64 sy; // Stride y. Padding padding; // SAME or VALID. }; + enum ConvolutionFormat { + UNKNOWN_CONVOLUTION_FORMAT, + NHWC, + NCHW, + NCHW_VECT_C, + NCHW_VECT_W, + }; int64 CountConv2DOperations(const OpInfo& op_features, bool* found_unknown_shapes) const; int64 CountConv2DOperations(const OpInfo& op_features, @@ -138,6 +145,7 @@ class OpLevelCostEstimator { Costs PredictCwiseOp(const OpContext& op_context) const; Costs PredictConv2DBackpropInput(const OpContext& op_context) const; Costs PredictConv2DBackpropFilter(const OpContext& op_context) const; + Costs PredictFusedConv2DBiasActivation(const OpContext& op_context) const; Costs PredictMatMul(const OpContext& op_context) const; Costs PredictNoOp(const OpContext& op_context) const; Costs PredictIdentity(const OpContext& op_context) const; @@ -152,6 +160,10 @@ class OpLevelCostEstimator { Costs PredictFusedBatchNorm(const OpContext& op_context) const; Costs PredictFusedBatchNormGrad(const OpContext& op_context) const; + // Generic cost prediction method for fused operations. + Costs PredictFusedOp(const OpContext& op_context, + const std::vector<OpContext>& fused_op_contexts) const; + // Utility function for safe division. Returns 0 // if rhs is 0 or negative. static double SafeDiv(const double lhs, const double rhs) { @@ -173,6 +185,20 @@ class OpLevelCostEstimator { const TensorShapeProto& original_image_shape, const OpInfo& op_info, bool* found_unknown_shapes); + // Helper to construct child operation contexts for the component operations + // of fused ops. + static OpContext FusedChildContext( + const OpContext& parent, const string& op_name, + const OpInfo::TensorProperties& output, + const std::vector<OpInfo::TensorProperties>& inputs); + + // Helper to construct tensor shapes. + static OpInfo::TensorProperties DescribeTensor( + DataType type, const std::vector<int64>& dims); + + // Returns the Conv2D format for this operation. + static ConvolutionFormat GetConvolutionFormat(const OpContext& op_context); + // This method calculates the execution time depending on whether IO can // overlap with computation. It assumes the memory and the compute times have // already been calculated. |