diff options
author | Rob Sloan <varomodt@google.com> | 2018-04-06 21:55:10 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-06 21:57:46 -0700 |
commit | 30e2b97897d05e47b457ab1d5d0d9c4227b87845 (patch) | |
tree | 19ccb01faec4fc451cfe45867d130277a9116fe7 /tensorflow/core/grappler/costs/op_level_cost_estimator.cc | |
parent | 273495dc2c957402f832cae31a438e550db2b7f0 (diff) |
Add analytical cost model for FusedConv2DBiasActivation.
PiperOrigin-RevId: 191978272
Diffstat (limited to 'tensorflow/core/grappler/costs/op_level_cost_estimator.cc')
-rw-r--r-- | tensorflow/core/grappler/costs/op_level_cost_estimator.cc | 165 |
1 file changed, 163 insertions, 2 deletions
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc index 79735e6cc2..087190ad2a 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc @@ -30,6 +30,7 @@ constexpr char kConst[] = "Const"; constexpr char kConv2d[] = "Conv2D"; constexpr char kConv2dBackpropFilter[] = "Conv2DBackpropFilter"; constexpr char kConv2dBackpropInput[] = "Conv2DBackpropInput"; +constexpr char kFusedConv2dBiasActivation[] = "FusedConv2DBiasActivation"; constexpr char kMatMul[] = "MatMul"; constexpr char kSparseMatMul[] = "SparseMatMul"; constexpr char kPlaceholder[] = "Placeholder"; @@ -196,6 +197,8 @@ OpLevelCostEstimator::OpLevelCostEstimator() { wrap(&OpLevelCostEstimator::PredictConv2DBackpropFilter)}, {kConv2dBackpropInput, wrap(&OpLevelCostEstimator::PredictConv2DBackpropInput)}, + {kFusedConv2dBiasActivation, + wrap(&OpLevelCostEstimator::PredictFusedConv2DBiasActivation)}, {kMatMul, wrap(&OpLevelCostEstimator::PredictMatMul)}, {kSparseMatMul, wrap(&OpLevelCostEstimator::PredictMatMul)}, {kBatchMatMul, wrap(&OpLevelCostEstimator::PredictBatchMatMul)}, @@ -545,7 +548,6 @@ int64 OpLevelCostEstimator::CountConv2DOperations( ops *= conv_dims.kx * conv_dims.ky; ops *= conv_dims.iz * conv_dims.oz; ops *= kOpsPerMac; - VLOG(1) << "Operations for Conv2D " << ops; if (conv_info != nullptr) { *conv_info = conv_dims; @@ -983,6 +985,91 @@ Costs OpLevelCostEstimator::PredictConv2DBackpropFilter( return costs; } +Costs OpLevelCostEstimator::PredictFusedConv2DBiasActivation( + const OpContext& op_context) const { + // FusedConv2DBiasActivation computes a fused kernel which implements: + // 2D convolution, adds side input with separate scaling on convolution and + // side inputs, then adds bias, and finally applies the ReLU activation + // function to the result: + // + // Input -> Conv2D -> Add -> BiasAdd -> ReLU + // ^ ^ ^ + // Filter Side 
Input Bias + // + // Note that when adding the side input, the operation multiplies the output + // of Conv2D by conv_input_scale, confusingly, and the side_input by + // side_input_scale. + // + // Note that in the special case that side_input_scale is 0, which we infer + // from side_input having dimensions [], we skip that addition operation. + // + // For more information, see + // contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc + auto& conv_input = op_context.op_info.inputs(0); + auto& filter = op_context.op_info.inputs(1); + auto& bias = op_context.op_info.inputs(2); + auto& side_input = op_context.op_info.inputs(3); + auto& conv_input_scale = op_context.op_info.inputs(4); + auto& side_input_scale = op_context.op_info.inputs(5); + + // Manually compute our convolution dimensions. + bool found_unknown_shapes = false; + auto dims = ConvolutionDimensionsFromInputs( + conv_input.shape(), filter.shape(), op_context.op_info, + &found_unknown_shapes); + + // Construct the shape of our output tensor from our convolution dimensions + // and format, as it may not be available yet. + // + // TODO(varomodt): should we centralize the Conv2D input/output shapes? + bool unknown_conv_format = false; + OpInfo::TensorProperties output; + switch (GetConvolutionFormat(op_context)) { + case NCHW: + output = + DescribeTensor(DT_FLOAT, {dims.batch, dims.oz, dims.ox, dims.oy}); + break; + case NHWC: + output = + DescribeTensor(DT_FLOAT, {dims.batch, dims.ox, dims.oy, dims.oz}); + break; + default: + // TODO(b/77722245): support cost estimation for NCHW_VECT_C. + LOG(WARNING) << "unsupported data format: " + << GetDataFormat(op_context.op_info) + << " Defaulting to NHWC."; + output = + DescribeTensor(DT_FLOAT, {dims.batch, dims.ox, dims.oy, dims.oz}); + unknown_conv_format = true; + break; + } + + // Add the operations the fused op always computes. 
+ std::vector<OpContext> component_ops = { + FusedChildContext(op_context, "Conv2D", output, {conv_input, filter}), + FusedChildContext(op_context, "Mul", output, {output, conv_input_scale}), + FusedChildContext(op_context, "BiasAdd", output, {output, bias}), + FusedChildContext(op_context, "Relu", output, {output})}; + + // Add our side_input iff it's non-empty. + if (side_input.shape().dim_size() > 0) { + component_ops.push_back(FusedChildContext(op_context, "Mul", side_input, + {side_input, side_input_scale})); + component_ops.push_back( + FusedChildContext(op_context, "Add", output, {side_input, output})); + } + + // Construct an op_context which definitely has our output shape. + auto op_context_with_output = op_context; + op_context_with_output.op_info.mutable_outputs()->Clear(); + *op_context_with_output.op_info.mutable_outputs()->Add() = output; + + // Construct component operations and run the cost computation. + auto costs = PredictFusedOp(op_context_with_output, component_ops); + costs.inaccurate |= found_unknown_shapes || unknown_conv_format; + return costs; +} + Costs OpLevelCostEstimator::PredictMatMul(const OpContext& op_context) const { const auto& op_features = op_context.op_info; bool found_unknown_shapes = false; @@ -1086,6 +1173,66 @@ Costs OpLevelCostEstimator::PredictGatherOrSlice( return costs; } +Costs OpLevelCostEstimator::PredictFusedOp( + const OpContext& op_context, + const std::vector<OpContext>& fused_op_contexts) const { + // Note that PredictOpCountBasedCost will get the correct memory_time from + // the node's inputs and outputs; but we don't want to have to re-implement + // the logic for computing the operation count of each of our component + // operations here; so we simply add the compute times of each component + // operation, then update the execution time. 
+ Costs fused_cost = PredictOpCountBasedCost(0, op_context.op_info); + fused_cost.compute_time = 0; + fused_cost.inaccurate = false; + for (auto& fused_op : fused_op_contexts) { + auto op_cost = PredictCosts(fused_op); + fused_cost.compute_time += op_cost.compute_time; + fused_cost.inaccurate |= op_cost.inaccurate; + } + + CombineCostsAndUpdateExecutionTime(&fused_cost); + return fused_cost; +} + +/* static */ +OpContext OpLevelCostEstimator::FusedChildContext( + const OpContext& parent, const string& op_name, + const OpInfo::TensorProperties& output, + const std::vector<OpInfo::TensorProperties>& inputs) { + // Setup the base parameters of our new context. + OpContext new_context; + new_context.name = op_name; + new_context.device_name = parent.device_name; + new_context.op_info = parent.op_info; + new_context.op_info.set_op(op_name); + + // Setup the inputs of our new context. + new_context.op_info.mutable_inputs()->Clear(); + for (const auto& input : inputs) { + *new_context.op_info.mutable_inputs()->Add() = input; + } + + // Setup the output of our new context. 
+ new_context.op_info.mutable_outputs()->Clear(); + *new_context.op_info.mutable_outputs()->Add() = output; + + return new_context; +} + +/* static */ +OpInfo::TensorProperties OpLevelCostEstimator::DescribeTensor( + DataType type, const std::vector<int64>& dims) { + OpInfo::TensorProperties ret; + ret.set_dtype(type); + + auto shape = ret.mutable_shape(); + for (const int dim : dims) { + shape->add_dim()->set_size(dim); + } + + return ret; +} + /* static */ OpLevelCostEstimator::ConvolutionDimensions OpLevelCostEstimator::OpDimensionsFromInputs( @@ -1371,6 +1518,21 @@ Costs OpLevelCostEstimator::PredictFusedBatchNormGrad( return costs; } +/* static */ +OpLevelCostEstimator::ConvolutionFormat +OpLevelCostEstimator::GetConvolutionFormat(const OpContext& op_context) { + auto data_format = GetDataFormat(op_context.op_info); + if (data_format == "NCHW") { + return NCHW; + } else if (data_format == "NHWC") { + return NHWC; + } else if (data_format == "NCHW_VECT_C") { + return NCHW_VECT_C; + } + + return UNKNOWN_CONVOLUTION_FORMAT; +} + void OpLevelCostEstimator::CombineCostsAndUpdateExecutionTime( Costs* costs) const { if (compute_memory_overlap_) { @@ -1379,6 +1541,5 @@ void OpLevelCostEstimator::CombineCostsAndUpdateExecutionTime( costs->execution_time = costs->compute_time + costs->memory_time; } } - } // end namespace grappler } // end namespace tensorflow |