diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-05-08 10:25:30 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-05-08 14:10:44 -0700 |
commit | 074d2901e2f6b9807394f300e5ccbc65defcf161 (patch) | |
tree | 207afc0f18ceae327507323588bb9044487bdba0 /tensorflow/core/grappler/costs/op_level_cost_estimator.cc | |
parent | 77bb984c23aa7ec347c981c31f650598c9624304 (diff) |
Add a cost model for DepthwiseConv2dNative. TensorFlow computes depthwise separable convolutions as DepthwiseConv2dNative followed by a 1x1 Conv2D.
PiperOrigin-RevId: 195838887
Diffstat (limited to 'tensorflow/core/grappler/costs/op_level_cost_estimator.cc')
-rw-r--r-- | tensorflow/core/grappler/costs/op_level_cost_estimator.cc | 68 |
1 files changed, 55 insertions, 13 deletions
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc index 199b69452f..2542fa2d67 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc @@ -32,6 +32,11 @@ constexpr char kConv2d[] = "Conv2D"; constexpr char kConv2dBackpropFilter[] = "Conv2DBackpropFilter"; constexpr char kConv2dBackpropInput[] = "Conv2DBackpropInput"; constexpr char kFusedConv2dBiasActivation[] = "FusedConv2DBiasActivation"; +constexpr char kDepthwiseConv2dNative[] = "DepthwiseConv2dNative"; +constexpr char kDepthwiseConv2dNativeBackpropFilter[] = + "DepthwiseConv2dNativeBackpropFilter"; +constexpr char kDepthwiseConv2dNativeBackpropInput[] = + "DepthwiseConv2dNativeBackpropInput"; constexpr char kMatMul[] = "MatMul"; constexpr char kSparseMatMul[] = "SparseMatMul"; constexpr char kPlaceholder[] = "Placeholder"; @@ -201,6 +206,14 @@ OpLevelCostEstimator::OpLevelCostEstimator() { wrap(&OpLevelCostEstimator::PredictConv2DBackpropInput)}, {kFusedConv2dBiasActivation, wrap(&OpLevelCostEstimator::PredictFusedConv2DBiasActivation)}, + // reuse Conv2D for DepthwiseConv2dNative because the caculation is the + // same although the actual meaning of the parameters are different. 
See + // comments in PredictConv2D and related functions + {kDepthwiseConv2dNative, wrap(&OpLevelCostEstimator::PredictConv2D)}, + {kDepthwiseConv2dNativeBackpropFilter, + wrap(&OpLevelCostEstimator::PredictConv2DBackpropFilter)}, + {kDepthwiseConv2dNativeBackpropInput, + wrap(&OpLevelCostEstimator::PredictConv2DBackpropInput)}, {kMatMul, wrap(&OpLevelCostEstimator::PredictMatMul)}, {kSparseMatMul, wrap(&OpLevelCostEstimator::PredictMatMul)}, {kBatchMatMul, wrap(&OpLevelCostEstimator::PredictBatchMatMul)}, @@ -539,18 +552,30 @@ OpLevelCostEstimator::ConvolutionDimensionsFromInputs( int64 OpLevelCostEstimator::CountConv2DOperations( const OpInfo& op_features, ConvolutionDimensions* conv_info, bool* found_unknown_shapes) const { - if (op_features.op() != kConv2d) { - LOG(ERROR) << "Invalid Operation"; - return 0; - } + DCHECK(op_features.op() == kConv2d || + op_features.op() == kDepthwiseConv2dNative) + << "Invalid Operation: not Conv2D nor DepthwiseConv2dNative"; + ConvolutionDimensions conv_dims = ConvolutionDimensionsFromInputs( op_features.inputs(0).shape(), op_features.inputs(1).shape(), op_features, found_unknown_shapes); + // in DepthwiseConv2dNative conv_dims.oz is actually the channel depth + // multiplier; The effective output channel depth oz_effective is + // conv_dims.iz * conv_dims.oz. thus # ops = N x H x W x oz_effective x 2RS. + // Compare to Conv2D where # ops = N x H x W x iz x oz x 2RS, + // oz = oz_effective, then Conv2D_ops / Depthwise_conv2d_native_ops = iz. int64 ops = conv_dims.batch; ops *= conv_dims.ox * conv_dims.oy; ops *= conv_dims.kx * conv_dims.ky; - ops *= conv_dims.iz * conv_dims.oz; + if (op_features.op() == kConv2d) { + ops *= conv_dims.iz * conv_dims.oz; + } else { + // To ensure output tensor dims to be correct for DepthwiseConv2DNative, + // although ops are the same as Conv2D. 
+ conv_dims.oz *= conv_dims.iz; + ops *= conv_dims.oz; + } ops *= kOpsPerMac; if (conv_info != nullptr) { @@ -797,7 +822,10 @@ int64 OpLevelCostEstimator::CountConv2DBackpropInputOperations( bool* found_unknown_shapes) const { int64 ops = 0; - DCHECK_EQ(kConv2dBackpropInput, op_features.op()); + DCHECK(op_features.op() == kConv2dBackpropInput || + op_features.op() == kDepthwiseConv2dNativeBackpropInput) + << "Invalid Operation: not kConv2dBackpropInput nor" + "kDepthwiseConv2dNativeBackpropInput"; if (op_features.inputs_size() < 2) { *found_unknown_shapes = true; @@ -830,10 +858,15 @@ int64 OpLevelCostEstimator::CountConv2DBackpropInputOperations( ops = conv_dims.batch; ops *= conv_dims.ox * conv_dims.oy; ops *= conv_dims.kx * conv_dims.ky; - ops *= conv_dims.iz * conv_dims.oz; - ops *= kOpsPerMac; + if (op_features.op() == kConv2dBackpropInput) { + ops *= conv_dims.iz * conv_dims.oz; + } else { + // conv_dims always use forward path definition regardless + conv_dims.oz *= conv_dims.iz; + ops *= conv_dims.oz; + } - VLOG(1) << "Operations for Conv2DBackpropInput " << ops; + VLOG(1) << "Operations for" << op_features.op() << " " << ops; if (returned_conv_dims != nullptr) { *returned_conv_dims = conv_dims; @@ -845,7 +878,11 @@ int64 OpLevelCostEstimator::CountConv2DBackpropFilterOperations( const OpInfo& op_features, ConvolutionDimensions* returned_conv_dims, bool* found_unknown_shapes) const { int64 ops = 0; - DCHECK_EQ(kConv2dBackpropFilter, op_features.op()); + + DCHECK(op_features.op() == kConv2dBackpropFilter || + op_features.op() == kDepthwiseConv2dNativeBackpropFilter) + << "Invalid Operation: not kConv2dBackpropFilter nor" + "kDepthwiseConv2dNativeBackpropFilter"; TensorShapeProto filter_shape; bool shape_found = false; @@ -877,10 +914,15 @@ int64 OpLevelCostEstimator::CountConv2DBackpropFilterOperations( ops = conv_dims.batch; ops *= conv_dims.ox * conv_dims.oy; ops *= conv_dims.kx * conv_dims.ky; - ops *= conv_dims.iz * conv_dims.oz; - ops *= kOpsPerMac; + 
if (op_features.op() == kConv2dBackpropFilter) { + ops *= conv_dims.iz * conv_dims.oz; + } else { + // conv_dims always use forward path definition regardless + conv_dims.oz *= conv_dims.iz; + ops *= conv_dims.oz; + } - VLOG(1) << "Operations for Conv2DBackpropFilter" << ops; + VLOG(1) << "Operations for" << op_features.op() << " " << ops; if (returned_conv_dims != nullptr) { *returned_conv_dims = conv_dims; |