author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-03-28 13:11:12 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2018-03-28 13:13:48 -0700
commit    480ac84aa8390e19a54bd2feef3a6069d959bb4e (patch)
tree      615d3b34bae3554d50c3f6c64c25d9001e342a59
parent    560ef036727c871bab57faa9942ccaff977ef88a (diff)

Add op cost model for MaxPool, AvgPool, FusedBatchNorm, their grad ops, and ReluGrad.

PiperOrigin-RevId: 190821116
-rw-r--r--  tensorflow/core/grappler/costs/op_level_cost_estimator.cc       306
-rw-r--r--  tensorflow/core/grappler/costs/op_level_cost_estimator.h         14
-rw-r--r--  tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc  391
3 files changed, 709 insertions, 2 deletions
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
index 905cc2a215..0f6307cfdf 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
@@ -50,6 +50,12 @@ constexpr char kPreventGradient[] = "PreventGradient";
constexpr char kGather[] = "Gather";
constexpr char kGatherV2[] = "GatherV2";
constexpr char kSlice[] = "Slice";
+constexpr char kMaxPool[] = "MaxPool";
+constexpr char kMaxPoolGrad[] = "MaxPoolGrad";
+constexpr char kAvgPool[] = "AvgPool";
+constexpr char kAvgPoolGrad[] = "AvgPoolGrad";
+constexpr char kFusedBatchNorm[] = "FusedBatchNorm";
+constexpr char kFusedBatchNormGrad[] = "FusedBatchNormGrad";
static const Costs::Duration kMinComputeTime(1);
@@ -71,14 +77,39 @@ Padding GetPadding(const OpInfo& op_features) {
return Padding::SAME; // Default padding.
}
+bool IsTraining(const OpInfo& op_info) {
+ if (op_info.attr().find("is_training") != op_info.attr().end() &&
+ op_info.attr().at("is_training").b()) {
+ return true;
+ }
+ return false;
+}
+
+// TODO(dyoon): support non-4D tensors in the cost functions of convolution
+// related ops (Conv, Pool, BatchNorm, and their backprops) and the related
+// helper functions.
std::vector<int64> GetStrides(const OpInfo& op_features) {
if (op_features.attr().find("strides") != op_features.attr().end()) {
const auto strides = op_features.attr().at("strides").list().i();
+ CHECK(strides.size() == 4) << "Attr strides is not a length-4 vector: "
+ << op_features.DebugString();
return {strides[0], strides[1], strides[2], strides[3]};
}
return {1, 1, 1, 1};
}
+std::vector<int64> GetKernelSize(const OpInfo& op_info) {
+ if (op_info.attr().find("ksize") != op_info.attr().end()) {
+ const auto ksize = op_info.attr().at("ksize").list().i();
+ CHECK(ksize.size() == 4)
+ << "Attr ksize is not a length-4 vector: " << op_info.DebugString();
+ return {ksize[0], ksize[1], ksize[2], ksize[3]};
+ }
+ // Note that FusedBatchNorm doesn't have ksize attr, but GetKernelSize returns
+ // {1, 1, 1, 1} in that case.
+ return {1, 1, 1, 1};
+}
+
int64 GetOutputSize(const int64 input, const int64 filter, const int64 stride,
const Padding& padding) {
// Logic for calculating output shape is from GetWindowedOutputSizeVerbose()
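
As a reference for the hunk above, a standalone sketch of the windowed-output-size arithmetic that GetOutputSize implements (illustrative only, not part of this change; the function name is made up):

#include <cassert>
#include <cstdint>

// SAME: ceil(input / stride); VALID: ceil((input - filter + 1) / stride),
// both written as integer ceiling divisions.
int64_t WindowedOutputSize(int64_t input, int64_t filter, int64_t stride,
                           bool same_padding) {
  return same_padding ? (input + stride - 1) / stride
                      : (input - filter + stride) / stride;
}

int main() {
  assert(WindowedOutputSize(20, 3, 2, /*same_padding=*/true) == 10);
  assert(WindowedOutputSize(20, 2, 3, /*same_padding=*/false) == 7);
  return 0;
}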
@@ -193,7 +224,15 @@ OpLevelCostEstimator::OpLevelCostEstimator() {
{kRank, wrap(&OpLevelCostEstimator::PredictMetadata)},
{kShape, wrap(&OpLevelCostEstimator::PredictMetadata)},
- {kSize, wrap(&OpLevelCostEstimator::PredictMetadata)}};
+ {kSize, wrap(&OpLevelCostEstimator::PredictMetadata)},
+ {kMaxPool, wrap(&OpLevelCostEstimator::PredictMaxPool)},
+ {kMaxPoolGrad, wrap(&OpLevelCostEstimator::PredictMaxPoolGrad)},
+ {kAvgPool, wrap(&OpLevelCostEstimator::PredictAvgPool)},
+ {kAvgPoolGrad, wrap(&OpLevelCostEstimator::PredictAvgPoolGrad)},
+ {kFusedBatchNorm, wrap(&OpLevelCostEstimator::PredictFusedBatchNorm)},
+ {kFusedBatchNormGrad,
+ wrap(&OpLevelCostEstimator::PredictFusedBatchNormGrad)},
+ };
#define EIGEN_COST(X) Eigen::internal::functor_traits<Eigen::internal::X>::Cost
@@ -258,6 +297,7 @@ OpLevelCostEstimator::OpLevelCostEstimator() {
{"QuantizedAdd", EIGEN_COST(scalar_sum_op<float>)},
{"QuantizedMul", EIGEN_COST(scalar_product_op<float>)},
{"RealDiv", EIGEN_COST(scalar_quotient_op<float>)},
+ {"ReluGrad", EIGEN_COST(scalar_max_op<float>)},
{"SquareDifference", 1},
{"Sub", EIGEN_COST(scalar_difference_op<float>)},
{"TruncateDiv", EIGEN_COST(scalar_quotient_op<float>)},
@@ -1044,5 +1084,269 @@ Costs OpLevelCostEstimator::PredictGatherOrSlice(
return costs;
}
+/* static */
+OpLevelCostEstimator::ConvolutionDimensions
+OpLevelCostEstimator::OpDimensionsFromInputs(
+ const TensorShapeProto& original_image_shape, const OpInfo& op_info,
+ bool* found_unknown_shapes) {
+ VLOG(2) << "op features: " << op_info.DebugString();
+ VLOG(2) << "Original image shape: " << original_image_shape.DebugString();
+ auto image_shape =
+ MaybeGetMinimumShape(original_image_shape, 4, found_unknown_shapes);
+ VLOG(2) << "Image shape: " << image_shape.DebugString();
+
+ int x_index, y_index, channel_index;
+ const string& data_format = GetDataFormat(op_info);
+ if (data_format == "NCHW") {
+ x_index = 2;
+ y_index = 3;
+ channel_index = 1;
+ } else {
+ x_index = 1;
+ y_index = 2;
+ channel_index = 3;
+ }
+ int64 batch = image_shape.dim(0).size();
+ int64 ix = image_shape.dim(x_index).size();
+ int64 iy = image_shape.dim(y_index).size();
+ int64 iz = image_shape.dim(channel_index).size();
+
+ // Note that FusedBatchNorm doesn't have ksize attr, but GetKernelSize returns
+ // {1, 1, 1, 1} in that case.
+ std::vector<int64> ksize = GetKernelSize(op_info);
+ int64 kx = ksize[x_index];
+ int64 ky = ksize[y_index];
+
+ std::vector<int64> strides = GetStrides(op_info);
+ int64 sx = strides[x_index];
+ int64 sy = strides[y_index];
+ const auto padding = GetPadding(op_info);
+
+ int64 ox = GetOutputSize(ix, kx, sx, padding);
+ int64 oy = GetOutputSize(iy, ky, sy, padding);
+ int64 oz = iz;
+
+ OpLevelCostEstimator::ConvolutionDimensions conv_dims = {
+ batch, ix, iy, iz, kx, ky, oz, ox, oy, sx, sy, padding};
+ return conv_dims;
+}
+
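To make the NHWC/NCHW index mapping concrete, here is the same computation traced by hand for one of the shapes used in the new tests (a sketch, not part of this change):

#include <cassert>
#include <cstdint>

int main() {
  // MaxPool on NHWC input 10x20x20x384 with ksize {1, 3, 3, 1},
  // strides {1, 2, 2, 1}, SAME padding. For NHWC: x_index = 1,
  // y_index = 2, channel_index = 3.
  const int64_t ix = 20, iy = 20, iz = 384, sx = 2, sy = 2;
  const int64_t ox = (ix + sx - 1) / sx;  // SAME: ceil(20 / 2) = 10
  const int64_t oy = (iy + sy - 1) / sy;  // 10
  const int64_t oz = iz;                  // pooling preserves channels
  assert(ox == 10 && oy == 10 && oz == 384);
  return 0;
}
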
+Costs OpLevelCostEstimator::PredictMaxPool(const OpContext& op_context) const {
+ bool found_unknown_shapes = false;
+ const auto& op_info = op_context.op_info;
+ // x: op_info.inputs(0)
+ ConvolutionDimensions dims = OpDimensionsFromInputs(
+ op_info.inputs(0).shape(), op_info, &found_unknown_shapes);
+  // kx * ky - 1 comparisons per output (kx * ky > 1)
+  // or 1 copy per output (kx * ky = 1).
+ int per_output_ops = dims.kx * dims.ky == 1 ? 1 : dims.kx * dims.ky - 1;
+ int64 ops = dims.batch * dims.ox * dims.oy * dims.oz * per_output_ops;
+
+ double total_input_size = 0;
+ if (dims.ky >= dims.sy) {
+ total_input_size =
+ CalculateTensorSize(op_info.inputs(0), &found_unknown_shapes);
+ } else { // dims.ky < dims.sy
+ // Vertical stride is larger than vertical kernel; assuming row-major
+    // format, skip unnecessary rows (or read every ky rows per sy rows, as the
+ // others are not used for output).
+ const auto data_size = DataTypeSize(BaseType(op_info.inputs(0).dtype()));
+ total_input_size =
+ data_size * dims.batch * dims.ix * dims.ky * dims.oy * dims.iz;
+ }
+ const double total_output_size =
+ CalculateOutputSize(op_info, &found_unknown_shapes);
+
+ Costs costs = PredictOpCountBasedCost(
+ ops, total_input_size + total_output_size, op_info);
+ costs.inaccurate = found_unknown_shapes;
+ costs.max_memory = total_output_size;
+ return costs;
+}
+
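The op and byte counts above can be checked against the first PredictMaxPool test case added below; in that test the default CPU device works out to 10 ops (or bytes) per time unit (an inference from the expected durations, not something this change states):

#include <cassert>
#include <cstdint>

int main() {
  // NHWC 10x20x20x384, 3x3 window, 2x2 stride, SAME: ox = oy = 10, oz = 384.
  const int64_t ops = 10LL * 10 * 10 * 384 * (3 * 3 - 1);  // 8 comparisons
  // ky >= sy, so the whole float input is read (4 bytes per element).
  const int64_t in_bytes = 10LL * 20 * 20 * 384 * 4;
  const int64_t out_bytes = 10LL * 10 * 10 * 384 * 4;
  assert(ops == 3072000);                   // -> compute_time 307200
  assert(in_bytes + out_bytes == 7680000);  // -> memory_time 768000
  return 0;
}
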
+Costs OpLevelCostEstimator::PredictMaxPoolGrad(
+ const OpContext& op_context) const {
+ bool found_unknown_shapes = false;
+ const auto& op_info = op_context.op_info;
+ // x: op_info.inputs(0)
+ // y: op_info.inputs(1)
+ // y_grad: op_info.inputs(2)
+ ConvolutionDimensions dims = OpDimensionsFromInputs(
+ op_info.inputs(0).shape(), op_info, &found_unknown_shapes);
+
+ int64 ops = 0;
+ if (dims.kx == 1 && dims.ky == 1) {
+ // 1x1 window. No need to know which input was max.
+ ops = dims.batch * dims.ix * dims.iy * dims.iz;
+ } else if (dims.kx <= dims.sx && dims.ky <= dims.sy) {
+ // Non-overlapping window: re-run maxpool, then assign zero or y_grad.
+ ops = dims.batch * dims.iz *
+ (dims.ox * dims.oy * (dims.kx * dims.ky - 1) + dims.ix * dims.iy);
+ } else {
+ // Overlapping window: initialize with zeros, re-run maxpool, then
+    // accumulate y_grad to proper x_grad locations.
+ ops = dims.batch * dims.iz *
+ (dims.ox * dims.oy * (dims.kx * dims.ky - 1) + dims.ix * dims.iy * 2);
+ }
+
+  // Just read x and y_grad; no need to read y, as we assume MaxPoolGrad
+  // re-runs MaxPool internally.
+ double total_input_size =
+ CalculateTensorSize(op_info.inputs(0), &found_unknown_shapes);
+ total_input_size +=
+ CalculateTensorSize(op_info.inputs(2), &found_unknown_shapes);
+ // Write x_grad; size equal to x.
+ const double total_output_size =
+ CalculateTensorSize(op_info.inputs(0), &found_unknown_shapes);
+
+ Costs costs = PredictOpCountBasedCost(
+ ops, total_input_size + total_output_size, op_info);
+ costs.inaccurate = found_unknown_shapes;
+ costs.max_memory = total_output_size;
+ return costs;
+}
+
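Checking the non-overlapping branch against the 2x2-window, 3x3-stride VALID test case added below (sketch only):

#include <cassert>
#include <cstdint>

int main() {
  // NHWC 10x20x20x384, kx = ky = 2, sx = sy = 3, VALID: ox = oy = 7.
  // kx <= sx and ky <= sy, so the non-overlapping branch applies.
  const int64_t ops = 10LL * 384 * (7 * 7 * (2 * 2 - 1) + 20 * 20);
  // Reads x and y_grad, writes x_grad (same size as x); 4 bytes per element.
  const int64_t bytes = 10LL * 20 * 20 * 384 * 4     // x
                        + 10LL * 7 * 7 * 384 * 4     // y_grad
                        + 10LL * 20 * 20 * 384 * 4;  // x_grad
  assert(ops == 2100480);     // -> compute_time 210048
  assert(bytes == 13040640);  // -> memory_time 1304064
  return 0;
}
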
+Costs OpLevelCostEstimator::PredictAvgPool(const OpContext& op_context) const {
+ bool found_unknown_shapes = false;
+ const auto& op_info = op_context.op_info;
+ // x: op_info.inputs(0)
+ ConvolutionDimensions dims = OpDimensionsFromInputs(
+ op_info.inputs(0).shape(), op_info, &found_unknown_shapes);
+
+ // kx * ky - 1 additions and 1 multiplication per output.
+ int64 ops = dims.batch * dims.ox * dims.oy * dims.oz * dims.kx * dims.ky;
+
+ double total_input_size = 0;
+ if (dims.ky >= dims.sy) {
+ total_input_size =
+ CalculateTensorSize(op_info.inputs(0), &found_unknown_shapes);
+ } else { // dims.ky < dims.sy
+    // Vertical stride is larger than vertical kernel; assuming row-major
+    // format, skip unnecessary rows (or read every ky rows per sy rows, as the
+ // others are not used for output).
+ const auto data_size = DataTypeSize(BaseType(op_info.inputs(0).dtype()));
+ total_input_size =
+ data_size * dims.batch * dims.ix * dims.ky * dims.oy * dims.iz;
+ }
+ const double total_output_size =
+ CalculateOutputSize(op_info, &found_unknown_shapes);
+
+ Costs costs = PredictOpCountBasedCost(
+ ops, total_input_size + total_output_size, op_info);
+ costs.inaccurate = found_unknown_shapes;
+ costs.max_memory = total_output_size;
+ return costs;
+}
+
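The kx * ky factor above is kx * ky - 1 additions plus the final scaling multiply; against the first AvgPool test case added below (sketch only):

#include <cassert>
#include <cstdint>

int main() {
  // NHWC 10x20x20x384, 3x3 window, 2x2 stride, SAME: ox = oy = 10.
  const int64_t ops = 10LL * 10 * 10 * 384 * (3 * 3);  // 8 adds + 1 multiply
  assert(ops == 3456000);  // -> compute_time 345600 in the test
  return 0;
}
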
+Costs OpLevelCostEstimator::PredictAvgPoolGrad(
+ const OpContext& op_context) const {
+ bool found_unknown_shapes = false;
+ const auto& op_info = op_context.op_info;
+ // x: op_info.inputs(0)
+ // y_grad: op_info.inputs(1)
+ ConvolutionDimensions dims = OpDimensionsFromInputs(
+ op_info.inputs(0).shape(), op_info, &found_unknown_shapes);
+
+ int64 ops = 0;
+ if (dims.kx <= dims.sx && dims.ky <= dims.sy) {
+ // Non-overlapping window.
+ ops = dims.batch * dims.iz * (dims.ix * dims.iy + dims.ox * dims.oy);
+ } else {
+ // Overlapping window.
+ ops = dims.batch * dims.iz *
+ (dims.ix * dims.iy + dims.ox * dims.oy * (dims.kx * dims.ky + 1));
+ }
+
+ const double total_input_size =
+ CalculateInputSize(op_info, &found_unknown_shapes);
+ const double total_output_size =
+ CalculateOutputSize(op_info, &found_unknown_shapes);
+
+ Costs costs = PredictOpCountBasedCost(
+ ops, total_input_size + total_output_size, op_info);
+ costs.inaccurate = found_unknown_shapes;
+ costs.max_memory = total_output_size;
+ return costs;
+}
+
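Both branches can be checked against the AvgPoolGrad test cases added below (sketch only):

#include <cassert>
#include <cstdint>

int main() {
  // Overlapping: 3x3 window, 2x2 stride, SAME (kx > sx), ox = oy = 10.
  const int64_t overlap_ops =
      10LL * 384 * (20 * 20 + 10 * 10 * (3 * 3 + 1));
  // Non-overlapping: 1x1 window, 2x2 stride, SAME, ox = oy = 10.
  const int64_t non_overlap_ops = 10LL * 384 * (20 * 20 + 10 * 10);
  assert(overlap_ops == 5376000);      // -> compute_time 537600
  assert(non_overlap_ops == 1920000);  // -> compute_time 192000
  return 0;
}
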
+Costs OpLevelCostEstimator::PredictFusedBatchNorm(
+ const OpContext& op_context) const {
+ bool found_unknown_shapes = false;
+ const auto& op_info = op_context.op_info;
+ // x: op_info.inputs(0)
+ // scale: op_info.inputs(1)
+ // offset: op_info.inputs(2)
+ // mean: op_info.inputs(3) --> only for inference
+ // variance: op_info.inputs(4) --> only for inference
+ ConvolutionDimensions dims = OpDimensionsFromInputs(
+ op_info.inputs(0).shape(), op_info, &found_unknown_shapes);
+ const bool is_training = IsTraining(op_info);
+
+ int64 ops = 0;
+ const auto rsqrt_cost = Eigen::internal::functor_traits<
+ Eigen::internal::scalar_rsqrt_op<float>>::Cost;
+ if (is_training) {
+ ops = dims.iz * (dims.batch * dims.ix * dims.iy * 4 + 6 + rsqrt_cost);
+ } else {
+ ops = dims.batch * dims.ix * dims.iy * dims.iz * 2;
+ }
+
+ const double size_nhwc =
+ CalculateTensorSize(op_info.inputs(0), &found_unknown_shapes);
+ const double size_c =
+ CalculateTensorSize(op_info.inputs(1), &found_unknown_shapes);
+ double total_input_size = 0.0;
+ double total_internal_read_size = 0.0;
+ double total_output_size = 0.0;
+ if (is_training) {
+ total_input_size = size_nhwc + size_c * 2;
+ total_output_size = size_nhwc + size_c * 4;
+ total_internal_read_size = size_nhwc;
+ } else {
+ total_input_size = size_nhwc + size_c * 4;
+ total_output_size = size_nhwc;
+ }
+
+ Costs costs = PredictOpCountBasedCost(
+ ops, total_input_size + total_output_size + total_internal_read_size,
+ op_info);
+ costs.inaccurate = found_unknown_shapes;
+ costs.max_memory = total_output_size;
+ return costs;
+}
+
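The training/inference split can be traced against the c = 96 test cases added below. Judging by the expected durations, Eigen's scalar_rsqrt_op<float> cost evaluates to 5 in that build (an inference, not something this change pins down), and durations round up. A sketch:

#include <cassert>
#include <cstdint>

int main() {
  // NHWC 10x20x20x96; float tensors, 4 bytes per element.
  const int64_t rsqrt_cost = 5;  // assumed from the expected test durations
  const int64_t train_ops = 96 * (10LL * 20 * 20 * 4 + 6 + rsqrt_cost);
  const int64_t infer_ops = 10LL * 20 * 20 * 96 * 2;
  const int64_t size_nhwc = 10LL * 20 * 20 * 96 * 4;
  const int64_t size_c = 96 * 4;
  // Training: read x + 2 vectors, write y + 4 vectors, re-read x internally.
  const int64_t train_bytes = (size_nhwc + 2 * size_c)    // inputs
                              + (size_nhwc + 4 * size_c)  // outputs
                              + size_nhwc;                // internal read
  assert(train_ops == 1537056);    // -> compute_time 153706 (rounded up)
  assert(infer_ops == 768000);     // -> compute_time 76800
  assert(train_bytes == 4610304);  // -> memory_time 461031 (rounded up)
  return 0;
}
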
+Costs OpLevelCostEstimator::PredictFusedBatchNormGrad(
+ const OpContext& op_context) const {
+ bool found_unknown_shapes = false;
+ const auto& op_info = op_context.op_info;
+ // y_backprop: op_info.inputs(0)
+ // x: op_info.inputs(1)
+ // scale: op_info.inputs(2)
+ // mean: op_info.inputs(3)
+ // variance or inverse of variance: op_info.inputs(4)
+ ConvolutionDimensions dims = OpDimensionsFromInputs(
+ op_info.inputs(1).shape(), op_info, &found_unknown_shapes);
+
+ int64 ops = 0;
+ const auto rsqrt_cost = Eigen::internal::functor_traits<
+ Eigen::internal::scalar_rsqrt_op<float>>::Cost;
+ ops = dims.iz * (dims.batch * dims.ix * dims.iy * 11 + 5 + rsqrt_cost);
+
+ const double size_nhwc =
+ CalculateTensorSize(op_info.inputs(1), &found_unknown_shapes);
+ const double size_c =
+ CalculateTensorSize(op_info.inputs(2), &found_unknown_shapes);
+ double total_input_size = size_nhwc * 2 + size_c * 2;
+ double total_internal_read_size = size_nhwc;
+ double total_output_size = size_nhwc * 1 + size_c * 2;
+
+ Costs costs = PredictOpCountBasedCost(
+ ops, total_input_size + total_output_size + total_internal_read_size,
+ op_info);
+ costs.inaccurate = found_unknown_shapes;
+ costs.max_memory = total_output_size;
+ return costs;
+}
+
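The same exercise for the grad op, against the first FusedBatchNormGrad test case added below (sketch only; rsqrt cost again assumed to be 5):

#include <cassert>
#include <cstdint>

int main() {
  // NHWC x of 10x20x20x96; scale has 96 elements; float tensors.
  const int64_t rsqrt_cost = 5;  // assumed, as above
  const int64_t ops = 96 * (10LL * 20 * 20 * 11 + 5 + rsqrt_cost);
  const int64_t size_nhwc = 10LL * 20 * 20 * 96 * 4;
  const int64_t size_c = 96 * 4;
  const int64_t bytes = (2 * size_nhwc + 2 * size_c)  // y_backprop, x, 2 vecs
                        + size_nhwc                   // internal read
                        + (size_nhwc + 2 * size_c);   // x_grad + 2 vectors
  assert(ops == 4224960);    // -> compute_time 422496
  assert(bytes == 6145536);  // -> memory_time 614554 (rounded up)
  return 0;
}
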
} // end namespace grappler
} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.h b/tensorflow/core/grappler/costs/op_level_cost_estimator.h
index 1b3babb206..fcbecbb6dc 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.h
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.h
@@ -145,6 +145,12 @@ class OpLevelCostEstimator {
Costs PredictBatchMatMul(const OpContext& op_context) const;
Costs PredictMetadata(const OpContext& op_context) const;
Costs PredictGatherOrSlice(const OpContext& op_context) const;
+ Costs PredictMaxPool(const OpContext& op_context) const;
+ Costs PredictMaxPoolGrad(const OpContext& op_context) const;
+ Costs PredictAvgPool(const OpContext& op_context) const;
+ Costs PredictAvgPoolGrad(const OpContext& op_context) const;
+ Costs PredictFusedBatchNorm(const OpContext& op_context) const;
+ Costs PredictFusedBatchNormGrad(const OpContext& op_context) const;
// Utility function for safe division. Returns 0
// if rhs is 0 or negative.
@@ -156,9 +162,15 @@ class OpLevelCostEstimator {
}
}
+ // For convolution and its grad ops.
static ConvolutionDimensions ConvolutionDimensionsFromInputs(
const TensorShapeProto& original_image_shape,
- const TensorShapeProto& original_filter_shape, const OpInfo& op_features,
+ const TensorShapeProto& original_filter_shape, const OpInfo& op_info,
+ bool* found_unknown_shapes);
+
+ // For Pooling, FusedBatchNorm, and their grad ops.
+ static ConvolutionDimensions OpDimensionsFromInputs(
+ const TensorShapeProto& original_image_shape, const OpInfo& op_info,
bool* found_unknown_shapes);
protected:
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc
index 99bf28f21b..56915ed821 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc
@@ -14,6 +14,8 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/costs/op_level_cost_estimator.h"
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
@@ -169,6 +171,130 @@ OpContext DescribeBiasAdd(int size1, int size2) {
return op_context;
}
+int GetOutputSize(const int x, const int k, const int s,
+ const string& padding) {
+ if (padding == "SAME") {
+ return (x + s - 1) / s;
+ } else {
+ return (x - k + s) / s;
+ }
+}
+
+std::vector<int> GetPoolingOutputSize(const std::vector<int>& input,
+ const std::vector<int>& ksize,
+ const std::vector<int>& strides,
+ const string& data_format,
+ const string& padding) {
+ // h, w, and c indices: default with NHWC.
+ int h_index = 1;
+ int w_index = 2;
+ int c_index = 3;
+ if (data_format == "NCHW") {
+ h_index = 2;
+ w_index = 3;
+ c_index = 1;
+ }
+ // Extract parameters.
+ int n = input[0];
+ int h = input[h_index];
+ int w = input[w_index];
+ int c = input[c_index];
+ int sx = strides[h_index];
+ int sy = strides[w_index];
+ int kx = ksize[h_index];
+ int ky = ksize[w_index];
+
+ // Output activation size: default with VALID padding.
+ int ho = GetOutputSize(h, kx, sx, padding);
+ int wo = GetOutputSize(w, ky, sy, padding);
+
+ std::vector<int> output;
+ if (data_format == "NHWC") {
+ output = {n, ho, wo, c};
+ } else {
+ output = {n, c, ho, wo};
+ }
+ return output;
+}
+
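For instance, GetPoolingOutputSize({10, 20, 20, 384}, {1, 3, 3, 1}, {1, 2, 2, 1}, "NHWC", "SAME") returns {10, 10, 10, 384}: ho = (20 + 2 - 1) / 2 = 10, wo = 10, and n and c pass through unchanged.
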
+OpContext DescribePoolingOp(const string& op_name, const std::vector<int>& x,
+ const std::vector<int>& ksize,
+ const std::vector<int>& strides,
+ const string& data_format, const string& padding) {
+ OpContext op_context;
+ auto& op_info = op_context.op_info;
+ SetCpuDevice(&op_info);
+ op_info.set_op(op_name);
+
+ const std::vector<int> y =
+ GetPoolingOutputSize(x, ksize, strides, data_format, padding);
+ if (op_name == "AvgPool" || op_name == "MaxPool") {
+ // input: x, output: y.
+ DescribeTensor4D(x[0], x[1], x[2], x[3], op_info.add_inputs());
+ DescribeTensor4D(y[0], y[1], y[2], y[3], op_info.add_outputs());
+ } else if (op_name == "AvgPoolGrad") {
+ // input: x, y_grad, output: x_grad.
+ DescribeTensor4D(x[0], x[1], x[2], x[3], op_info.add_inputs());
+ DescribeTensor4D(y[0], y[1], y[2], y[3], op_info.add_inputs());
+ DescribeTensor4D(x[0], x[1], x[2], x[3], op_info.add_outputs());
+ } else if (op_name == "MaxPoolGrad") {
+ // input: x, y, y_grad, output: x_grad.
+ DescribeTensor4D(x[0], x[1], x[2], x[3], op_info.add_inputs());
+ DescribeTensor4D(y[0], y[1], y[2], y[3], op_info.add_inputs());
+ DescribeTensor4D(y[0], y[1], y[2], y[3], op_info.add_inputs());
+ DescribeTensor4D(x[0], x[1], x[2], x[3], op_info.add_outputs());
+ }
+ auto* attr = op_info.mutable_attr();
+ SetAttrValue(data_format, &(*attr)["data_format"]);
+ SetAttrValue(padding, &(*attr)["padding"]);
+ SetAttrValue(strides, &(*attr)["strides"]);
+ SetAttrValue(ksize, &(*attr)["ksize"]);
+ return op_context;
+}
+
+OpContext DescribeFusedBatchNorm(const bool is_training, const bool is_grad,
+ const std::vector<int>& x,
+ const string& data_format) {
+ // First, get MaxPool op info with unit stride and unit window.
+ OpContext op_context = DescribePoolingOp("MaxPool", x, {1, 1, 1, 1},
+ {1, 1, 1, 1}, data_format, "SAME");
+ auto& op_info = op_context.op_info;
+ // Override op name.
+ if (is_grad) {
+ op_info.set_op("FusedBatchNormGrad");
+ } else {
+ op_info.set_op("FusedBatchNorm");
+ }
+
+  // Add additional input and output tensors.
+ if (is_grad) {
+ DescribeTensor4D(x[0], x[1], x[2], x[3], op_info.add_inputs());
+ }
+ int num_1d_inputs = is_grad ? 3 : 4;
+ for (int i = 0; i < num_1d_inputs; i++) {
+ auto* tensor = op_info.add_inputs();
+ auto* shape = tensor->mutable_shape();
+ shape->add_dim()->set_size(x[3]);
+ tensor->set_dtype(DT_FLOAT);
+ }
+ for (int i = 0; i < 4; i++) {
+ auto* tensor = op_info.add_outputs();
+ auto* shape = tensor->mutable_shape();
+ shape->add_dim()->set_size(x[3]);
+ tensor->set_dtype(DT_FLOAT);
+ }
+
+ // Delete unnecessary attr.
+ auto* attr = op_context.op_info.mutable_attr();
+ attr->erase("ksize");
+ attr->erase("strides");
+ attr->erase("padding");
+
+ // Additional attrs for FusedBatchNorm.
+ SetAttrValue(is_training, &(*attr)["is_training"]);
+
+ return op_context;
+}
} // namespace
class OpLevelCostEstimatorTest : public ::testing::Test {
@@ -192,6 +318,50 @@ class OpLevelCostEstimatorTest : public ::testing::Test {
estimator_.compute_memory_overlap_ = value;
}
+  void ValidateOpDimensionsFromInputs(const int n, const int h, const int w,
+ const int c, const int kx, const int ky,
+ const int sx, const int sy,
+ const string& data_format,
+ const string& padding) {
+ OpContext op_context;
+ int ho;
+ int wo;
+ if (data_format == "NHWC") {
+ op_context = DescribePoolingOp("MaxPool", {n, h, w, c}, {1, kx, ky, 1},
+ {1, sx, sy, 1}, "NHWC", padding);
+ ho = op_context.op_info.outputs(0).shape().dim(1).size();
+ wo = op_context.op_info.outputs(0).shape().dim(2).size();
+ } else {
+ op_context = DescribePoolingOp("MaxPool", {n, c, h, w}, {1, 1, kx, ky},
+ {1, 1, sx, sy}, "NCHW", padding);
+ ho = op_context.op_info.outputs(0).shape().dim(2).size();
+ wo = op_context.op_info.outputs(0).shape().dim(3).size();
+ }
+
+ bool found_unknown_shapes;
+ auto dims = OpLevelCostEstimator::OpDimensionsFromInputs(
+ op_context.op_info.inputs(0).shape(), op_context.op_info,
+ &found_unknown_shapes);
+ Padding padding_enum;
+ if (padding == "VALID") {
+ padding_enum = Padding::VALID;
+ } else {
+ padding_enum = Padding::SAME;
+ }
+ EXPECT_EQ(n, dims.batch);
+ EXPECT_EQ(h, dims.ix);
+ EXPECT_EQ(w, dims.iy);
+ EXPECT_EQ(c, dims.iz);
+ EXPECT_EQ(kx, dims.kx);
+ EXPECT_EQ(ky, dims.ky);
+ EXPECT_EQ(sx, dims.sx);
+ EXPECT_EQ(sy, dims.sy);
+ EXPECT_EQ(ho, dims.ox);
+ EXPECT_EQ(wo, dims.oy);
+ EXPECT_EQ(c, dims.oz);
+ EXPECT_EQ(padding_enum, dims.padding);
+ }
+
OpLevelCostEstimator estimator_;
};
@@ -443,5 +613,226 @@ TEST_F(OpLevelCostEstimatorTest, GetTensorShapeProtoFromTensorProto) {
}
}
+TEST_F(OpLevelCostEstimatorTest, OpDimensionsFromInputs) {
+ std::vector<string> paddings = {"VALID", "SAME"};
+ std::vector<string> formats = {"NHWC", "NCHW"};
+ for (const auto& p : paddings) {
+ for (const auto& f : formats) {
+ // n, h, w, c, kx, ky, sx, sy, data_format, padding.
+      ValidateOpDimensionsFromInputs(10, 20, 20, 100, 3, 3, 2, 2, f, p);
+      ValidateOpDimensionsFromInputs(10, 20, 20, 100, 1, 1, 3, 3, f, p);
+      ValidateOpDimensionsFromInputs(10, 200, 200, 100, 5, 5, 3, 3, f, p);
+      ValidateOpDimensionsFromInputs(10, 14, 14, 3840, 3, 3, 2, 2, f, p);
+ }
+ }
+}
+
+TEST_F(OpLevelCostEstimatorTest, PredictMaxPool) {
+ auto predict_max_pool = [this](const int n, const int in, const int c,
+ const int k, const int s,
+ const string& padding) -> Costs {
+ OpContext op_context = DescribePoolingOp(
+ "MaxPool", {n, in, in, c}, {1, k, k, 1}, {1, s, s, 1}, "NHWC", padding);
+ return estimator_.PredictCosts(op_context);
+ };
+
+ {
+    // Typical 3x3 window with 2x2 stride.
+ auto costs = predict_max_pool(10, 20, 384, 3, 2, "SAME");
+ EXPECT_EQ(Costs::Duration(1075200), costs.execution_time);
+ EXPECT_EQ(Costs::Duration(307200), costs.compute_time);
+ EXPECT_EQ(Costs::Duration(768000), costs.memory_time);
+ EXPECT_FALSE(costs.inaccurate);
+ }
+ {
+ // 1x1 window with 2x2 stride: used for shortcut in resnet-50.
+ auto costs = predict_max_pool(10, 20, 384, 1, 2, "SAME");
+ EXPECT_EQ(Costs::Duration(499200), costs.execution_time);
+ EXPECT_EQ(Costs::Duration(38400), costs.compute_time);
+ EXPECT_EQ(Costs::Duration(460800), costs.memory_time);
+ EXPECT_FALSE(costs.inaccurate);
+ }
+ {
+ // 2x2 window with 3x3 stride.
+ auto costs = predict_max_pool(10, 20, 384, 2, 3, "VALID");
+ EXPECT_EQ(Costs::Duration(561792), costs.execution_time);
+ EXPECT_EQ(Costs::Duration(56448), costs.compute_time);
+ EXPECT_EQ(Costs::Duration(505344), costs.memory_time);
+ EXPECT_FALSE(costs.inaccurate);
+ }
+}
+
+TEST_F(OpLevelCostEstimatorTest, PredictMaxPoolGrad) {
+ auto predict_max_pool_grad = [this](const int n, const int in, const int c,
+ const int k, const int s,
+ const string& padding) -> Costs {
+ OpContext op_context =
+ DescribePoolingOp("MaxPoolGrad", {n, in, in, c}, {1, k, k, 1},
+ {1, s, s, 1}, "NHWC", padding);
+ return estimator_.PredictCosts(op_context);
+ };
+
+ {
+    // Typical 3x3 window with 2x2 stride.
+ auto costs = predict_max_pool_grad(10, 20, 384, 3, 2, "SAME");
+ EXPECT_EQ(Costs::Duration(1996800), costs.execution_time);
+ EXPECT_EQ(Costs::Duration(614400), costs.compute_time);
+ EXPECT_EQ(Costs::Duration(1382400), costs.memory_time);
+ EXPECT_FALSE(costs.inaccurate);
+ }
+ {
+ // 1x1 window with 2x2 stride: used for shortcut in resnet-50.
+ auto costs = predict_max_pool_grad(10, 20, 384, 1, 2, "SAME");
+ EXPECT_EQ(Costs::Duration(1536000), costs.execution_time);
+ EXPECT_EQ(Costs::Duration(153600), costs.compute_time);
+ EXPECT_EQ(Costs::Duration(1382400), costs.memory_time);
+ EXPECT_FALSE(costs.inaccurate);
+ }
+ {
+ // 2x2 window with 3x3 stride.
+ auto costs = predict_max_pool_grad(10, 20, 384, 2, 3, "VALID");
+ EXPECT_EQ(Costs::Duration(1514112), costs.execution_time);
+ EXPECT_EQ(Costs::Duration(210048), costs.compute_time);
+ EXPECT_EQ(Costs::Duration(1304064), costs.memory_time);
+ EXPECT_FALSE(costs.inaccurate);
+ }
+}
+
+TEST_F(OpLevelCostEstimatorTest, PredictAvgPool) {
+ auto predict_avg_pool = [this](const int n, const int in, const int c,
+ const int k, const int s,
+ const string& padding) -> Costs {
+ OpContext op_context = DescribePoolingOp(
+ "AvgPool", {n, in, in, c}, {1, k, k, 1}, {1, s, s, 1}, "NHWC", padding);
+ return estimator_.PredictCosts(op_context);
+ };
+
+ {
+    // Typical 3x3 window with 2x2 stride.
+ auto costs = predict_avg_pool(10, 20, 384, 3, 2, "SAME");
+ EXPECT_EQ(Costs::Duration(1113600), costs.execution_time);
+ EXPECT_EQ(Costs::Duration(345600), costs.compute_time);
+ EXPECT_EQ(Costs::Duration(768000), costs.memory_time);
+ EXPECT_FALSE(costs.inaccurate);
+ }
+ {
+ // 1x1 window with 2x2 stride: used for shortcut in resnet-50.
+ auto costs = predict_avg_pool(10, 20, 384, 1, 2, "SAME");
+ EXPECT_EQ(Costs::Duration(499200), costs.execution_time);
+ EXPECT_EQ(Costs::Duration(38400), costs.compute_time);
+ EXPECT_EQ(Costs::Duration(460800), costs.memory_time);
+ EXPECT_FALSE(costs.inaccurate);
+ }
+ {
+ // 2x2 window with 3x3 stride.
+ auto costs = predict_avg_pool(10, 20, 384, 2, 3, "VALID");
+ EXPECT_EQ(Costs::Duration(580608), costs.execution_time);
+ EXPECT_EQ(Costs::Duration(75264), costs.compute_time);
+ EXPECT_EQ(Costs::Duration(505344), costs.memory_time);
+ EXPECT_FALSE(costs.inaccurate);
+ }
+}
+
+TEST_F(OpLevelCostEstimatorTest, PredictAvgPoolGrad) {
+ auto predict_avg_pool_grad = [this](const int n, const int in, const int c,
+ const int k, const int s,
+ const string& padding) -> Costs {
+ OpContext op_context =
+ DescribePoolingOp("AvgPoolGrad", {n, in, in, c}, {1, k, k, 1},
+ {1, s, s, 1}, "NHWC", padding);
+ return estimator_.PredictCosts(op_context);
+ };
+
+ {
+    // Typical 3x3 window with 2x2 stride.
+ auto costs = predict_avg_pool_grad(10, 20, 384, 3, 2, "SAME");
+ EXPECT_EQ(Costs::Duration(1920000), costs.execution_time);
+ EXPECT_EQ(Costs::Duration(537600), costs.compute_time);
+ EXPECT_EQ(Costs::Duration(1382400), costs.memory_time);
+ EXPECT_FALSE(costs.inaccurate);
+ }
+ {
+ // 1x1 window with 2x2 stride: used for shortcut in resnet-50.
+ auto costs = predict_avg_pool_grad(10, 20, 384, 1, 2, "SAME");
+ EXPECT_EQ(Costs::Duration(1574400), costs.execution_time);
+ EXPECT_EQ(Costs::Duration(192000), costs.compute_time);
+ EXPECT_EQ(Costs::Duration(1382400), costs.memory_time);
+ EXPECT_FALSE(costs.inaccurate);
+ }
+ {
+ // 2x2 window with 3x3 stride.
+ auto costs = predict_avg_pool_grad(10, 20, 384, 2, 3, "VALID");
+ EXPECT_EQ(Costs::Duration(1476480), costs.execution_time);
+ EXPECT_EQ(Costs::Duration(172416), costs.compute_time);
+ EXPECT_EQ(Costs::Duration(1304064), costs.memory_time);
+ EXPECT_FALSE(costs.inaccurate);
+ }
+}
+
+TEST_F(OpLevelCostEstimatorTest, PredictFusedBatchNorm) {
+ auto predict_fused_bn = [this](const int n, const int in, const int c,
+ const bool is_training) -> Costs {
+ OpContext op_context = DescribeFusedBatchNorm(
+ is_training, /*is_grad=*/false, {n, in, in, c}, "NHWC");
+ return estimator_.PredictCosts(op_context);
+ };
+
+ {
+ auto costs = predict_fused_bn(10, 20, 96, /*is_training=*/true);
+ EXPECT_EQ(Costs::Duration(614737), costs.execution_time);
+ EXPECT_EQ(Costs::Duration(153706), costs.compute_time);
+ EXPECT_EQ(Costs::Duration(461031), costs.memory_time);
+ EXPECT_FALSE(costs.inaccurate);
+ }
+
+ {
+ auto costs = predict_fused_bn(10, 20, 32, /*is_training=*/true);
+ EXPECT_EQ(Costs::Duration(204913), costs.execution_time);
+ EXPECT_EQ(Costs::Duration(51236), costs.compute_time);
+ EXPECT_EQ(Costs::Duration(153677), costs.memory_time);
+ EXPECT_FALSE(costs.inaccurate);
+ }
+
+ {
+ auto costs = predict_fused_bn(10, 20, 96, /*is_training=*/false);
+ EXPECT_EQ(Costs::Duration(384154), costs.execution_time);
+ EXPECT_EQ(Costs::Duration(76800), costs.compute_time);
+ EXPECT_EQ(Costs::Duration(307354), costs.memory_time);
+ EXPECT_FALSE(costs.inaccurate);
+ }
+
+ {
+ auto costs = predict_fused_bn(10, 20, 32, /*is_training=*/false);
+ EXPECT_EQ(Costs::Duration(128052), costs.execution_time);
+ EXPECT_EQ(Costs::Duration(25600), costs.compute_time);
+ EXPECT_EQ(Costs::Duration(102452), costs.memory_time);
+ EXPECT_FALSE(costs.inaccurate);
+ }
+}
+
+TEST_F(OpLevelCostEstimatorTest, PredictFusedBatchNormGrad) {
+ auto predict_fused_bn_grad = [this](const int n, const int in,
+ const int c) -> Costs {
+ OpContext op_context = DescribeFusedBatchNorm(
+ /*is_training=*/false, /*is_grad=*/true, {n, in, in, c}, "NHWC");
+ return estimator_.PredictCosts(op_context);
+ };
+
+ {
+ auto costs = predict_fused_bn_grad(10, 20, 96);
+ EXPECT_EQ(Costs::Duration(1037050), costs.execution_time);
+ EXPECT_EQ(Costs::Duration(422496), costs.compute_time);
+ EXPECT_EQ(Costs::Duration(614554), costs.memory_time);
+ EXPECT_FALSE(costs.inaccurate);
+ }
+
+ {
+ auto costs = predict_fused_bn_grad(128, 7, 384);
+ EXPECT_EQ(Costs::Duration(6503809), costs.execution_time);
+ EXPECT_EQ(Costs::Duration(2649677), costs.compute_time);
+ EXPECT_EQ(Costs::Duration(3854132), costs.memory_time);
+ EXPECT_FALSE(costs.inaccurate);
+ }
+}
} // end namespace grappler
} // end namespace tensorflow