aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-05-08 10:25:30 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-08 14:10:44 -0700
commit074d2901e2f6b9807394f300e5ccbc65defcf161 (patch)
tree207afc0f18ceae327507323588bb9044487bdba0 /tensorflow/core/grappler/costs/op_level_cost_estimator.cc
parent77bb984c23aa7ec347c981c31f650598c9624304 (diff)
Add cost model of depthwiseConv2dNative. Tensorflow computes depthwise separable convolutions as depthwiseConv2dNative followed by 1x1 Conv2D
PiperOrigin-RevId: 195838887
Diffstat (limited to 'tensorflow/core/grappler/costs/op_level_cost_estimator.cc')
-rw-r--r--tensorflow/core/grappler/costs/op_level_cost_estimator.cc68
1 files changed, 55 insertions, 13 deletions
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
index 199b69452f..2542fa2d67 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
@@ -32,6 +32,11 @@ constexpr char kConv2d[] = "Conv2D";
constexpr char kConv2dBackpropFilter[] = "Conv2DBackpropFilter";
constexpr char kConv2dBackpropInput[] = "Conv2DBackpropInput";
constexpr char kFusedConv2dBiasActivation[] = "FusedConv2DBiasActivation";
+constexpr char kDepthwiseConv2dNative[] = "DepthwiseConv2dNative";
+constexpr char kDepthwiseConv2dNativeBackpropFilter[] =
+ "DepthwiseConv2dNativeBackpropFilter";
+constexpr char kDepthwiseConv2dNativeBackpropInput[] =
+ "DepthwiseConv2dNativeBackpropInput";
constexpr char kMatMul[] = "MatMul";
constexpr char kSparseMatMul[] = "SparseMatMul";
constexpr char kPlaceholder[] = "Placeholder";
@@ -201,6 +206,14 @@ OpLevelCostEstimator::OpLevelCostEstimator() {
wrap(&OpLevelCostEstimator::PredictConv2DBackpropInput)},
{kFusedConv2dBiasActivation,
wrap(&OpLevelCostEstimator::PredictFusedConv2DBiasActivation)},
+ // reuse Conv2D for DepthwiseConv2dNative because the caculation is the
+ // same although the actual meaning of the parameters are different. See
+ // comments in PredictConv2D and related functions
+ {kDepthwiseConv2dNative, wrap(&OpLevelCostEstimator::PredictConv2D)},
+ {kDepthwiseConv2dNativeBackpropFilter,
+ wrap(&OpLevelCostEstimator::PredictConv2DBackpropFilter)},
+ {kDepthwiseConv2dNativeBackpropInput,
+ wrap(&OpLevelCostEstimator::PredictConv2DBackpropInput)},
{kMatMul, wrap(&OpLevelCostEstimator::PredictMatMul)},
{kSparseMatMul, wrap(&OpLevelCostEstimator::PredictMatMul)},
{kBatchMatMul, wrap(&OpLevelCostEstimator::PredictBatchMatMul)},
@@ -539,18 +552,30 @@ OpLevelCostEstimator::ConvolutionDimensionsFromInputs(
int64 OpLevelCostEstimator::CountConv2DOperations(
const OpInfo& op_features, ConvolutionDimensions* conv_info,
bool* found_unknown_shapes) const {
- if (op_features.op() != kConv2d) {
- LOG(ERROR) << "Invalid Operation";
- return 0;
- }
+ DCHECK(op_features.op() == kConv2d ||
+ op_features.op() == kDepthwiseConv2dNative)
+ << "Invalid Operation: not Conv2D nor DepthwiseConv2dNative";
+
ConvolutionDimensions conv_dims = ConvolutionDimensionsFromInputs(
op_features.inputs(0).shape(), op_features.inputs(1).shape(), op_features,
found_unknown_shapes);
+ // in DepthwiseConv2dNative conv_dims.oz is actually the channel depth
+ // multiplier; The effective output channel depth oz_effective is
+ // conv_dims.iz * conv_dims.oz. thus # ops = N x H x W x oz_effective x 2RS.
+ // Compare to Conv2D where # ops = N x H x W x iz x oz x 2RS,
+ // oz = oz_effective, then Conv2D_ops / Depthwise_conv2d_native_ops = iz.
int64 ops = conv_dims.batch;
ops *= conv_dims.ox * conv_dims.oy;
ops *= conv_dims.kx * conv_dims.ky;
- ops *= conv_dims.iz * conv_dims.oz;
+ if (op_features.op() == kConv2d) {
+ ops *= conv_dims.iz * conv_dims.oz;
+ } else {
+ // To ensure output tensor dims to be correct for DepthwiseConv2DNative,
+ // although ops are the same as Conv2D.
+ conv_dims.oz *= conv_dims.iz;
+ ops *= conv_dims.oz;
+ }
ops *= kOpsPerMac;
if (conv_info != nullptr) {
@@ -797,7 +822,10 @@ int64 OpLevelCostEstimator::CountConv2DBackpropInputOperations(
bool* found_unknown_shapes) const {
int64 ops = 0;
- DCHECK_EQ(kConv2dBackpropInput, op_features.op());
+ DCHECK(op_features.op() == kConv2dBackpropInput ||
+ op_features.op() == kDepthwiseConv2dNativeBackpropInput)
+ << "Invalid Operation: not kConv2dBackpropInput nor"
+ "kDepthwiseConv2dNativeBackpropInput";
if (op_features.inputs_size() < 2) {
*found_unknown_shapes = true;
@@ -830,10 +858,15 @@ int64 OpLevelCostEstimator::CountConv2DBackpropInputOperations(
ops = conv_dims.batch;
ops *= conv_dims.ox * conv_dims.oy;
ops *= conv_dims.kx * conv_dims.ky;
- ops *= conv_dims.iz * conv_dims.oz;
- ops *= kOpsPerMac;
+ if (op_features.op() == kConv2dBackpropInput) {
+ ops *= conv_dims.iz * conv_dims.oz;
+ } else {
+ // conv_dims always use forward path definition regardless
+ conv_dims.oz *= conv_dims.iz;
+ ops *= conv_dims.oz;
+ }
- VLOG(1) << "Operations for Conv2DBackpropInput " << ops;
+ VLOG(1) << "Operations for" << op_features.op() << " " << ops;
if (returned_conv_dims != nullptr) {
*returned_conv_dims = conv_dims;
@@ -845,7 +878,11 @@ int64 OpLevelCostEstimator::CountConv2DBackpropFilterOperations(
const OpInfo& op_features, ConvolutionDimensions* returned_conv_dims,
bool* found_unknown_shapes) const {
int64 ops = 0;
- DCHECK_EQ(kConv2dBackpropFilter, op_features.op());
+
+ DCHECK(op_features.op() == kConv2dBackpropFilter ||
+ op_features.op() == kDepthwiseConv2dNativeBackpropFilter)
+ << "Invalid Operation: not kConv2dBackpropFilter nor"
+ "kDepthwiseConv2dNativeBackpropFilter";
TensorShapeProto filter_shape;
bool shape_found = false;
@@ -877,10 +914,15 @@ int64 OpLevelCostEstimator::CountConv2DBackpropFilterOperations(
ops = conv_dims.batch;
ops *= conv_dims.ox * conv_dims.oy;
ops *= conv_dims.kx * conv_dims.ky;
- ops *= conv_dims.iz * conv_dims.oz;
- ops *= kOpsPerMac;
+ if (op_features.op() == kConv2dBackpropFilter) {
+ ops *= conv_dims.iz * conv_dims.oz;
+ } else {
+ // conv_dims always use forward path definition regardless
+ conv_dims.oz *= conv_dims.iz;
+ ops *= conv_dims.oz;
+ }
- VLOG(1) << "Operations for Conv2DBackpropFilter" << ops;
+ VLOG(1) << "Operations for" << op_features.op() << " " << ops;
if (returned_conv_dims != nullptr) {
*returned_conv_dims = conv_dims;