diff options
author | Yao Zhang <yaozhang@google.com> | 2018-06-20 18:36:13 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-20 18:39:22 -0700 |
commit | 96dfcc2fdc9f3a7419d3d5c5a64489e757de624e (patch) | |
tree | 8c684731bde1643158037bf1d4ed17e58c95096a /tensorflow/core/grappler/costs/op_level_cost_estimator.cc | |
parent | e8b18a6f0c02d364ff47ba5fa3dc61458d273674 (diff) |
Support filter format for FusedConv2DBiasActivation.
PiperOrigin-RevId: 201454730
Diffstat (limited to 'tensorflow/core/grappler/costs/op_level_cost_estimator.cc')
-rw-r--r-- | tensorflow/core/grappler/costs/op_level_cost_estimator.cc | 93 |
1 files changed, 52 insertions, 41 deletions
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc index b994d26397..d34eecd009 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc @@ -78,6 +78,14 @@ string GetDataFormat(const OpInfo& op_features) { return data_format; } +string GetFilterFormat(const OpInfo& op_features) { + string filter_format = "HWIO"; // Default format. + if (op_features.attr().find("filter_format") != op_features.attr().end()) { + filter_format = op_features.attr().at("filter_format").s(); + } + return filter_format; +} + Padding GetPadding(const OpInfo& op_features) { if (op_features.attr().find("padding") != op_features.attr().end() && op_features.attr().at("padding").s() == "VALID") { @@ -513,29 +521,44 @@ OpLevelCostEstimator::ConvolutionDimensionsFromInputs( y_index = 3; channel_index = 1; } else { + // Use NHWC. x_index = 1; y_index = 2; channel_index = 3; } + const string& filter_format = GetFilterFormat(op_features); + int filter_x_index, filter_y_index, in_channel_index, out_channel_index; + if (filter_format == "HWIO") { + filter_x_index = 0; + filter_y_index = 1; + in_channel_index = 2; + out_channel_index = 3; + } else { + // Use OIHW + filter_x_index = 2; + filter_y_index = 3; + in_channel_index = 1; + out_channel_index = 0; + } int64 batch = image_shape.dim(0).size(); int64 ix = image_shape.dim(x_index).size(); int64 iy = image_shape.dim(y_index).size(); int64 iz = image_shape.dim(channel_index).size(); - int64 kx = filter_shape.dim(0).size(); - int64 ky = filter_shape.dim(1).size(); + int64 kx = filter_shape.dim(filter_x_index).size(); + int64 ky = filter_shape.dim(filter_y_index).size(); std::vector<int64> strides = GetStrides(op_features); const auto padding = GetPadding(op_features); int64 sx = strides[x_index]; int64 sy = strides[y_index]; int64 ox = GetOutputSize(ix, kx, sx, padding); int64 oy = GetOutputSize(iy, ky, sy, padding); - int64 oz = filter_shape.dim(3).size(); + int64 oz = filter_shape.dim(out_channel_index).size(); // Only check equality when both sizes are known (in other words, when // neither is set to a minimum dimension size of 1). - if (iz != 1 && filter_shape.dim(2).size() != 1) { - CHECK_EQ(iz, filter_shape.dim(2).size()); + if (iz != 1 && filter_shape.dim(in_channel_index).size() != 1) { + CHECK_EQ(iz, filter_shape.dim(in_channel_index).size()); } else { - iz = std::max<int64>(iz, filter_shape.dim(2).size()); + iz = std::max<int64>(iz, filter_shape.dim(in_channel_index).size()); } OpLevelCostEstimator::ConvolutionDimensions conv_dims = { batch, ix, iy, iz, kx, ky, oz, ox, oy, sx, sy, padding}; @@ -1054,6 +1077,24 @@ Costs OpLevelCostEstimator::PredictFusedConv2DBiasActivation( // // For more information, see // contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc + + // TODO(yaozhang): Support other data formats (NCHW_VECT_C, NHWC_VECT_W) and + // filter formats (OIHW_VECT_I). + string data_format = GetDataFormat(op_context.op_info); + if (data_format != "NCHW" && data_format != "NHWC") { + LOG(WARNING) << "unsupported data format: " << data_format; + Costs cost = Costs::ZeroCosts(); + cost.inaccurate = true; + return cost; + } + string filter_format = GetFilterFormat(op_context.op_info); + if (filter_format != "HWIO" && filter_format != "OIHW") { + LOG(WARNING) << "unsupported filter format: " << filter_format; + Costs cost = Costs::ZeroCosts(); + cost.inaccurate = true; + return cost; + } + auto& conv_input = op_context.op_info.inputs(0); auto& filter = op_context.op_info.inputs(1); auto& bias = op_context.op_info.inputs(2); @@ -1069,28 +1110,12 @@ Costs OpLevelCostEstimator::PredictFusedConv2DBiasActivation( // Construct the shape of our output tensor from our convolution dimensions // and format, as it may not be available yet. - // // TODO(varomodt): should we centralize the Conv2D input/output shapes? - bool unknown_conv_format = false; OpInfo::TensorProperties output; - switch (GetConvolutionFormat(op_context)) { - case NCHW: - output = - DescribeTensor(DT_FLOAT, {dims.batch, dims.oz, dims.ox, dims.oy}); - break; - case NHWC: - output = - DescribeTensor(DT_FLOAT, {dims.batch, dims.ox, dims.oy, dims.oz}); - break; - default: - // TODO(b/77722245): support cost estimation for NCHW_VECT_C. - LOG(WARNING) << "unsupported data format: " - << GetDataFormat(op_context.op_info) - << " Defaulting to NHWC."; - output = - DescribeTensor(DT_FLOAT, {dims.batch, dims.ox, dims.oy, dims.oz}); - unknown_conv_format = true; - break; + if (data_format == "NCHW") { + output = DescribeTensor(DT_FLOAT, {dims.batch, dims.oz, dims.ox, dims.oy}); + } else if (data_format == "NHWC") { + output = DescribeTensor(DT_FLOAT, {dims.batch, dims.ox, dims.oy, dims.oz}); } // Add the operations the fused op always computes. @@ -1115,7 +1140,7 @@ Costs OpLevelCostEstimator::PredictFusedConv2DBiasActivation( // Construct component operations and run the cost computation. auto costs = PredictFusedOp(op_context_with_output, component_ops); - costs.inaccurate |= found_unknown_shapes || unknown_conv_format; + costs.inaccurate |= found_unknown_shapes; return costs; } @@ -1568,20 +1593,6 @@ Costs OpLevelCostEstimator::PredictFusedBatchNormGrad( } /* static */ -OpLevelCostEstimator::ConvolutionFormat -OpLevelCostEstimator::GetConvolutionFormat(const OpContext& op_context) { - auto data_format = GetDataFormat(op_context.op_info); - if (data_format == "NCHW") { - return NCHW; - } else if (data_format == "NHWC") { - return NHWC; - } else if (data_format == "NCHW_VECT_C") { - return NCHW_VECT_C; - } - - return UNKNOWN_CONVOLUTION_FORMAT; -} - void OpLevelCostEstimator::CombineCostsAndUpdateExecutionTime( Costs* costs) const { if (compute_memory_overlap_) { |