about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
diff options
context:
space:
mode:
author: Yao Zhang <yaozhang@google.com>, 2018-06-20 18:36:13 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org>, 2018-06-20 18:39:22 -0700
commit: 96dfcc2fdc9f3a7419d3d5c5a64489e757de624e (patch)
tree: 8c684731bde1643158037bf1d4ed17e58c95096a /tensorflow/core/grappler/costs/op_level_cost_estimator.cc
parent: e8b18a6f0c02d364ff47ba5fa3dc61458d273674 (diff)
Support filter format for FusedConv2DBiasActivation.
PiperOrigin-RevId: 201454730
Diffstat (limited to 'tensorflow/core/grappler/costs/op_level_cost_estimator.cc')
-rw-r--r-- tensorflow/core/grappler/costs/op_level_cost_estimator.cc 93
1 files changed, 52 insertions, 41 deletions
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
index b994d26397..d34eecd009 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
@@ -78,6 +78,14 @@ string GetDataFormat(const OpInfo& op_features) {
return data_format;
}
+string GetFilterFormat(const OpInfo& op_features) {
+ string filter_format = "HWIO"; // Default format.
+ if (op_features.attr().find("filter_format") != op_features.attr().end()) {
+ filter_format = op_features.attr().at("filter_format").s();
+ }
+ return filter_format;
+}
+
Padding GetPadding(const OpInfo& op_features) {
if (op_features.attr().find("padding") != op_features.attr().end() &&
op_features.attr().at("padding").s() == "VALID") {
@@ -513,29 +521,44 @@ OpLevelCostEstimator::ConvolutionDimensionsFromInputs(
y_index = 3;
channel_index = 1;
} else {
+ // Use NHWC.
x_index = 1;
y_index = 2;
channel_index = 3;
}
+ const string& filter_format = GetFilterFormat(op_features);
+ int filter_x_index, filter_y_index, in_channel_index, out_channel_index;
+ if (filter_format == "HWIO") {
+ filter_x_index = 0;
+ filter_y_index = 1;
+ in_channel_index = 2;
+ out_channel_index = 3;
+ } else {
+ // Use OIHW
+ filter_x_index = 2;
+ filter_y_index = 3;
+ in_channel_index = 1;
+ out_channel_index = 0;
+ }
int64 batch = image_shape.dim(0).size();
int64 ix = image_shape.dim(x_index).size();
int64 iy = image_shape.dim(y_index).size();
int64 iz = image_shape.dim(channel_index).size();
- int64 kx = filter_shape.dim(0).size();
- int64 ky = filter_shape.dim(1).size();
+ int64 kx = filter_shape.dim(filter_x_index).size();
+ int64 ky = filter_shape.dim(filter_y_index).size();
std::vector<int64> strides = GetStrides(op_features);
const auto padding = GetPadding(op_features);
int64 sx = strides[x_index];
int64 sy = strides[y_index];
int64 ox = GetOutputSize(ix, kx, sx, padding);
int64 oy = GetOutputSize(iy, ky, sy, padding);
- int64 oz = filter_shape.dim(3).size();
+ int64 oz = filter_shape.dim(out_channel_index).size();
// Only check equality when both sizes are known (in other words, when
// neither is set to a minimum dimension size of 1).
- if (iz != 1 && filter_shape.dim(2).size() != 1) {
- CHECK_EQ(iz, filter_shape.dim(2).size());
+ if (iz != 1 && filter_shape.dim(in_channel_index).size() != 1) {
+ CHECK_EQ(iz, filter_shape.dim(in_channel_index).size());
} else {
- iz = std::max<int64>(iz, filter_shape.dim(2).size());
+ iz = std::max<int64>(iz, filter_shape.dim(in_channel_index).size());
}
OpLevelCostEstimator::ConvolutionDimensions conv_dims = {
batch, ix, iy, iz, kx, ky, oz, ox, oy, sx, sy, padding};
@@ -1054,6 +1077,24 @@ Costs OpLevelCostEstimator::PredictFusedConv2DBiasActivation(
//
// For more information, see
// contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
+
+ // TODO(yaozhang): Support other data formats (NCHW_VECT_C, NHWC_VECT_W) and
+ // filter formats (OIHW_VECT_I).
+ string data_format = GetDataFormat(op_context.op_info);
+ if (data_format != "NCHW" && data_format != "NHWC") {
+ LOG(WARNING) << "unsupported data format: " << data_format;
+ Costs cost = Costs::ZeroCosts();
+ cost.inaccurate = true;
+ return cost;
+ }
+ string filter_format = GetFilterFormat(op_context.op_info);
+ if (filter_format != "HWIO" && filter_format != "OIHW") {
+ LOG(WARNING) << "unsupported filter format: " << filter_format;
+ Costs cost = Costs::ZeroCosts();
+ cost.inaccurate = true;
+ return cost;
+ }
+
auto& conv_input = op_context.op_info.inputs(0);
auto& filter = op_context.op_info.inputs(1);
auto& bias = op_context.op_info.inputs(2);
@@ -1069,28 +1110,12 @@ Costs OpLevelCostEstimator::PredictFusedConv2DBiasActivation(
// Construct the shape of our output tensor from our convolution dimensions
// and format, as it may not be available yet.
- //
// TODO(varomodt): should we centralize the Conv2D input/output shapes?
- bool unknown_conv_format = false;
OpInfo::TensorProperties output;
- switch (GetConvolutionFormat(op_context)) {
- case NCHW:
- output =
- DescribeTensor(DT_FLOAT, {dims.batch, dims.oz, dims.ox, dims.oy});
- break;
- case NHWC:
- output =
- DescribeTensor(DT_FLOAT, {dims.batch, dims.ox, dims.oy, dims.oz});
- break;
- default:
- // TODO(b/77722245): support cost estimation for NCHW_VECT_C.
- LOG(WARNING) << "unsupported data format: "
- << GetDataFormat(op_context.op_info)
- << " Defaulting to NHWC.";
- output =
- DescribeTensor(DT_FLOAT, {dims.batch, dims.ox, dims.oy, dims.oz});
- unknown_conv_format = true;
- break;
+ if (data_format == "NCHW") {
+ output = DescribeTensor(DT_FLOAT, {dims.batch, dims.oz, dims.ox, dims.oy});
+ } else if (data_format == "NHWC") {
+ output = DescribeTensor(DT_FLOAT, {dims.batch, dims.ox, dims.oy, dims.oz});
}
// Add the operations the fused op always computes.
@@ -1115,7 +1140,7 @@ Costs OpLevelCostEstimator::PredictFusedConv2DBiasActivation(
// Construct component operations and run the cost computation.
auto costs = PredictFusedOp(op_context_with_output, component_ops);
- costs.inaccurate |= found_unknown_shapes || unknown_conv_format;
+ costs.inaccurate |= found_unknown_shapes;
return costs;
}
@@ -1568,20 +1593,6 @@ Costs OpLevelCostEstimator::PredictFusedBatchNormGrad(
}
/* static */
-OpLevelCostEstimator::ConvolutionFormat
-OpLevelCostEstimator::GetConvolutionFormat(const OpContext& op_context) {
- auto data_format = GetDataFormat(op_context.op_info);
- if (data_format == "NCHW") {
- return NCHW;
- } else if (data_format == "NHWC") {
- return NHWC;
- } else if (data_format == "NCHW_VECT_C") {
- return NCHW_VECT_C;
- }
-
- return UNKNOWN_CONVOLUTION_FORMAT;
-}
-
void OpLevelCostEstimator::CombineCostsAndUpdateExecutionTime(
Costs* costs) const {
if (compute_memory_overlap_) {