about summary refs log tree commit diff homepage
path: root/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
diff options
context:
space:
mode:
author: Rob Sloan <varomodt@google.com> 2018-04-06 21:55:10 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-04-06 21:57:46 -0700
commit 30e2b97897d05e47b457ab1d5d0d9c4227b87845 (patch)
tree 19ccb01faec4fc451cfe45867d130277a9116fe7 /tensorflow/core/grappler/costs/op_level_cost_estimator.cc
parent 273495dc2c957402f832cae31a438e550db2b7f0 (diff)
Add analytical cost model for FusedConv2DBiasActivation.
PiperOrigin-RevId: 191978272
Diffstat (limited to 'tensorflow/core/grappler/costs/op_level_cost_estimator.cc')
-rw-r--r-- tensorflow/core/grappler/costs/op_level_cost_estimator.cc | 165
1 file changed, 163 insertions, 2 deletions
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
index 79735e6cc2..087190ad2a 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
@@ -30,6 +30,7 @@ constexpr char kConst[] = "Const";
constexpr char kConv2d[] = "Conv2D";
constexpr char kConv2dBackpropFilter[] = "Conv2DBackpropFilter";
constexpr char kConv2dBackpropInput[] = "Conv2DBackpropInput";
+constexpr char kFusedConv2dBiasActivation[] = "FusedConv2DBiasActivation";
constexpr char kMatMul[] = "MatMul";
constexpr char kSparseMatMul[] = "SparseMatMul";
constexpr char kPlaceholder[] = "Placeholder";
@@ -196,6 +197,8 @@ OpLevelCostEstimator::OpLevelCostEstimator() {
wrap(&OpLevelCostEstimator::PredictConv2DBackpropFilter)},
{kConv2dBackpropInput,
wrap(&OpLevelCostEstimator::PredictConv2DBackpropInput)},
+ {kFusedConv2dBiasActivation,
+ wrap(&OpLevelCostEstimator::PredictFusedConv2DBiasActivation)},
{kMatMul, wrap(&OpLevelCostEstimator::PredictMatMul)},
{kSparseMatMul, wrap(&OpLevelCostEstimator::PredictMatMul)},
{kBatchMatMul, wrap(&OpLevelCostEstimator::PredictBatchMatMul)},
@@ -545,7 +548,6 @@ int64 OpLevelCostEstimator::CountConv2DOperations(
ops *= conv_dims.kx * conv_dims.ky;
ops *= conv_dims.iz * conv_dims.oz;
ops *= kOpsPerMac;
- VLOG(1) << "Operations for Conv2D " << ops;
if (conv_info != nullptr) {
*conv_info = conv_dims;
@@ -983,6 +985,91 @@ Costs OpLevelCostEstimator::PredictConv2DBackpropFilter(
return costs;
}
+Costs OpLevelCostEstimator::PredictFusedConv2DBiasActivation(
+    const OpContext& op_context) const {
+  // FusedConv2DBiasActivation computes a fused kernel which implements:
+  // 2D convolution, adds side input with separate scaling on convolution and
+  // side inputs, then adds bias, and finally applies the ReLU activation
+  // function to the result:
+  //
+  // Input -> Conv2D -> Add -> BiasAdd -> ReLU
+  //           ^          ^         ^
+  //         Filter  Side Input  Bias
+  //
+  // Note that when adding the side input, the operation multiplies the output
+  // of Conv2D by conv_input_scale, confusingly, and the side_input by
+  // side_input_scale.
+  //
+  // Note that in the special case that side_input_scale is 0, which we infer
+  // from side_input having dimensions [], we skip that addition operation.
+  //
+  // For more information, see
+  // contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
+  //
+  // Inputs are taken positionally from the fused op's signature; a node with
+  // fewer than six inputs would fail these accessors — assumed well-formed.
+  auto& conv_input = op_context.op_info.inputs(0);
+  auto& filter = op_context.op_info.inputs(1);
+  auto& bias = op_context.op_info.inputs(2);
+  auto& side_input = op_context.op_info.inputs(3);
+  auto& conv_input_scale = op_context.op_info.inputs(4);
+  auto& side_input_scale = op_context.op_info.inputs(5);
+
+  // Manually compute our convolution dimensions.
+  bool found_unknown_shapes = false;
+  auto dims = ConvolutionDimensionsFromInputs(
+      conv_input.shape(), filter.shape(), op_context.op_info,
+      &found_unknown_shapes);
+
+  // Construct the shape of our output tensor from our convolution dimensions
+  // and format, as it may not be available yet.
+  //
+  // NOTE(review): the synthesized output dtype is always DT_FLOAT, even if the
+  // fused op runs on quantized (e.g. qint8) inputs — confirm this is
+  // acceptable for cost purposes (only shapes appear to matter downstream).
+  //
+  // TODO(varomodt): should we centralize the Conv2D input/output shapes?
+  bool unknown_conv_format = false;
+  OpInfo::TensorProperties output;
+  switch (GetConvolutionFormat(op_context)) {
+    case NCHW:
+      output =
+          DescribeTensor(DT_FLOAT, {dims.batch, dims.oz, dims.ox, dims.oy});
+      break;
+    case NHWC:
+      output =
+          DescribeTensor(DT_FLOAT, {dims.batch, dims.ox, dims.oy, dims.oz});
+      break;
+    default:
+      // TODO(b/77722245): support cost estimation for NCHW_VECT_C.
+      LOG(WARNING) << "unsupported data format: "
+                   << GetDataFormat(op_context.op_info)
+                   << " Defaulting to NHWC.";
+      output =
+          DescribeTensor(DT_FLOAT, {dims.batch, dims.ox, dims.oy, dims.oz});
+      unknown_conv_format = true;
+      break;
+  }
+
+  // Add the operations the fused op always computes.
+  // `output` is reused as both the input and output of the chained component
+  // ops: it is a shape/dtype stand-in for each intermediate tensor, which all
+  // share the convolution's output shape.
+  std::vector<OpContext> component_ops = {
+      FusedChildContext(op_context, "Conv2D", output, {conv_input, filter}),
+      FusedChildContext(op_context, "Mul", output, {output, conv_input_scale}),
+      FusedChildContext(op_context, "BiasAdd", output, {output, bias}),
+      FusedChildContext(op_context, "Relu", output, {output})};
+
+  // Add our side_input iff it's non-empty.
+  if (side_input.shape().dim_size() > 0) {
+    component_ops.push_back(FusedChildContext(op_context, "Mul", side_input,
+                                              {side_input, side_input_scale}));
+    component_ops.push_back(
+        FusedChildContext(op_context, "Add", output, {side_input, output}));
+  }
+
+  // Construct an op_context which definitely has our output shape.
+  auto op_context_with_output = op_context;
+  op_context_with_output.op_info.mutable_outputs()->Clear();
+  *op_context_with_output.op_info.mutable_outputs()->Add() = output;
+
+  // Construct component operations and run the cost computation.
+  // The result is flagged inaccurate if any input shape was unknown or the
+  // data format was unrecognized (NCHW_VECT_C falls back to NHWC above).
+  auto costs = PredictFusedOp(op_context_with_output, component_ops);
+  costs.inaccurate |= found_unknown_shapes || unknown_conv_format;
+  return costs;
+}
+
Costs OpLevelCostEstimator::PredictMatMul(const OpContext& op_context) const {
const auto& op_features = op_context.op_info;
bool found_unknown_shapes = false;
@@ -1086,6 +1173,66 @@ Costs OpLevelCostEstimator::PredictGatherOrSlice(
return costs;
}
+Costs OpLevelCostEstimator::PredictFusedOp(
+    const OpContext& op_context,
+    const std::vector<OpContext>& fused_op_contexts) const {
+  // Predicts the cost of a fused operation as:
+  //   memory_time  — from the fused node's own inputs/outputs only, and
+  //   compute_time — the sum of the component ops' compute times.
+  //
+  // Note that PredictOpCountBasedCost will get the correct memory_time from
+  // the node's inputs and outputs; but we don't want to have to re-implement
+  // the logic for computing the operation count of each of our component
+  // operations here; so we simply add the compute times of each component
+  // operation, then update the execution time.
+  //
+  // NOTE(review): each component's memory_time is intentionally discarded —
+  // presumably because a fused kernel does not materialize intermediates in
+  // memory; confirm this matches the intended cost model.
+  Costs fused_cost = PredictOpCountBasedCost(0, op_context.op_info);
+  fused_cost.compute_time = 0;
+  fused_cost.inaccurate = false;
+  for (auto& fused_op : fused_op_contexts) {
+    auto op_cost = PredictCosts(fused_op);
+    fused_cost.compute_time += op_cost.compute_time;
+    // Propagate inaccuracy from any component into the fused estimate.
+    fused_cost.inaccurate |= op_cost.inaccurate;
+  }
+
+  CombineCostsAndUpdateExecutionTime(&fused_cost);
+  return fused_cost;
+}
+
+/* static */
+OpContext OpLevelCostEstimator::FusedChildContext(
+    const OpContext& parent, const string& op_name,
+    const OpInfo::TensorProperties& output,
+    const std::vector<OpInfo::TensorProperties>& inputs) {
+  // Builds an OpContext describing one component of a fused op: the parent's
+  // context is copied wholesale (so device and node attributes such as
+  // data_format/strides are inherited), then the op name, inputs, and single
+  // output are overridden.
+  //
+  // Setup the base parameters of our new context.
+  OpContext new_context;
+  new_context.name = op_name;
+  new_context.device_name = parent.device_name;
+  new_context.op_info = parent.op_info;
+  new_context.op_info.set_op(op_name);
+
+  // Setup the inputs of our new context.
+  // Replace the parent's inputs with the caller-provided list, in order.
+  new_context.op_info.mutable_inputs()->Clear();
+  for (const auto& input : inputs) {
+    *new_context.op_info.mutable_inputs()->Add() = input;
+  }
+
+  // Setup the output of our new context.
+  // Component ops modeled here always have exactly one output.
+  new_context.op_info.mutable_outputs()->Clear();
+  *new_context.op_info.mutable_outputs()->Add() = output;
+
+  return new_context;
+}
+
+/* static */
+OpInfo::TensorProperties OpLevelCostEstimator::DescribeTensor(
+    DataType type, const std::vector<int64>& dims) {
+  // Builds a TensorProperties proto with the given dtype and a shape whose
+  // dimension sizes are `dims`, in order.
+  OpInfo::TensorProperties ret;
+  ret.set_dtype(type);
+
+  auto shape = ret.mutable_shape();
+  // Fix: iterate as int64 — `const int` narrowed the 64-bit dimension sizes,
+  // silently truncating any dimension >= 2^31.
+  for (const int64 dim : dims) {
+    shape->add_dim()->set_size(dim);
+  }
+
+  return ret;
+}
+
/* static */
OpLevelCostEstimator::ConvolutionDimensions
OpLevelCostEstimator::OpDimensionsFromInputs(
@@ -1371,6 +1518,21 @@ Costs OpLevelCostEstimator::PredictFusedBatchNormGrad(
return costs;
}
+/* static */
+OpLevelCostEstimator::ConvolutionFormat
+OpLevelCostEstimator::GetConvolutionFormat(const OpContext& op_context) {
+  // Maps the node's "data_format" attribute string onto the internal
+  // ConvolutionFormat enum; any unrecognized string (including a missing
+  // attribute's default, whatever GetDataFormat returns for it) yields
+  // UNKNOWN_CONVOLUTION_FORMAT.
+  auto data_format = GetDataFormat(op_context.op_info);
+  if (data_format == "NCHW") {
+    return NCHW;
+  } else if (data_format == "NHWC") {
+    return NHWC;
+  } else if (data_format == "NCHW_VECT_C") {
+    return NCHW_VECT_C;
+  }
+
+  return UNKNOWN_CONVOLUTION_FORMAT;
+}
+
void OpLevelCostEstimator::CombineCostsAndUpdateExecutionTime(
Costs* costs) const {
if (compute_memory_overlap_) {
@@ -1379,6 +1541,5 @@ void OpLevelCostEstimator::CombineCostsAndUpdateExecutionTime(
costs->execution_time = costs->compute_time + costs->memory_time;
}
}
-
} // end namespace grappler
} // end namespace tensorflow