aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
author Rob Sloan <varomodt@google.com> 2018-04-06 21:55:10 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-04-06 21:57:46 -0700
commit 30e2b97897d05e47b457ab1d5d0d9c4227b87845 (patch)
tree 19ccb01faec4fc451cfe45867d130277a9116fe7
parent 273495dc2c957402f832cae31a438e550db2b7f0 (diff)
Add analytical cost model for FusedConv2DBiasActivation.
PiperOrigin-RevId: 191978272
-rw-r--r-- tensorflow/core/grappler/costs/op_level_cost_estimator.cc 165
-rw-r--r-- tensorflow/core/grappler/costs/op_level_cost_estimator.h 26
-rw-r--r-- tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc 64
3 files changed, 249 insertions, 6 deletions
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
index 79735e6cc2..087190ad2a 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
@@ -30,6 +30,7 @@ constexpr char kConst[] = "Const";
constexpr char kConv2d[] = "Conv2D";
constexpr char kConv2dBackpropFilter[] = "Conv2DBackpropFilter";
constexpr char kConv2dBackpropInput[] = "Conv2DBackpropInput";
+constexpr char kFusedConv2dBiasActivation[] = "FusedConv2DBiasActivation";
constexpr char kMatMul[] = "MatMul";
constexpr char kSparseMatMul[] = "SparseMatMul";
constexpr char kPlaceholder[] = "Placeholder";
@@ -196,6 +197,8 @@ OpLevelCostEstimator::OpLevelCostEstimator() {
wrap(&OpLevelCostEstimator::PredictConv2DBackpropFilter)},
{kConv2dBackpropInput,
wrap(&OpLevelCostEstimator::PredictConv2DBackpropInput)},
+ {kFusedConv2dBiasActivation,
+ wrap(&OpLevelCostEstimator::PredictFusedConv2DBiasActivation)},
{kMatMul, wrap(&OpLevelCostEstimator::PredictMatMul)},
{kSparseMatMul, wrap(&OpLevelCostEstimator::PredictMatMul)},
{kBatchMatMul, wrap(&OpLevelCostEstimator::PredictBatchMatMul)},
@@ -545,7 +548,6 @@ int64 OpLevelCostEstimator::CountConv2DOperations(
ops *= conv_dims.kx * conv_dims.ky;
ops *= conv_dims.iz * conv_dims.oz;
ops *= kOpsPerMac;
- VLOG(1) << "Operations for Conv2D " << ops;
if (conv_info != nullptr) {
*conv_info = conv_dims;
@@ -983,6 +985,91 @@ Costs OpLevelCostEstimator::PredictConv2DBackpropFilter(
return costs;
}
+// Predicts the cost of a FusedConv2DBiasActivation node. Compute time is the
+// sum of the compute times of the component operations listed below; the
+// remaining cost fields come from PredictFusedOp, evaluated on a copy of the
+// node whose output shape has been filled in. The result is flagged
+// inaccurate when input shapes are unknown, the data format is unsupported,
+// or the node does not carry all six expected inputs.
+Costs OpLevelCostEstimator::PredictFusedConv2DBiasActivation(
+    const OpContext& op_context) const {
+  // FusedConv2DBiasActivation computes a fused kernel which implements:
+  // 2D convolution, adds side input with separate scaling on convolution and
+  // side inputs, then adds bias, and finally applies the ReLU activation
+  // function to the result:
+  //
+  //   Input -> Conv2D -> Add -> BiasAdd -> ReLU
+  //            ^         ^      ^
+  //            Filter    Side   Bias
+  //                      Input
+  //
+  // Note that when adding the side input, the operation multiplies the output
+  // of Conv2D by conv_input_scale, confusingly, and the side_input by
+  // side_input_scale.
+  //
+  // Note that in the special case that side_input_scale is 0, which we infer
+  // from side_input having dimensions [], we skip that addition operation.
+  //
+  // For more information, see
+  // contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
+
+  // Inputs 0..5 are read below. Indexing a repeated field out of range is not
+  // a valid access, so hand back an explicitly inaccurate estimate for a
+  // malformed node instead of reading inputs that are not there.
+  constexpr int kNumInputs = 6;
+  if (op_context.op_info.inputs_size() < kNumInputs) {
+    LOG(WARNING) << op_context.op_info.op() << " expects " << kNumInputs
+                 << " inputs but has " << op_context.op_info.inputs_size();
+    Costs costs = PredictOpCountBasedCost(0, op_context.op_info);
+    costs.inaccurate = true;
+    return costs;
+  }
+
+  const auto& conv_input = op_context.op_info.inputs(0);
+  const auto& filter = op_context.op_info.inputs(1);
+  const auto& bias = op_context.op_info.inputs(2);
+  const auto& side_input = op_context.op_info.inputs(3);
+  const auto& conv_input_scale = op_context.op_info.inputs(4);
+  const auto& side_input_scale = op_context.op_info.inputs(5);
+
+  // Manually compute our convolution dimensions.
+  bool found_unknown_shapes = false;
+  auto dims = ConvolutionDimensionsFromInputs(
+      conv_input.shape(), filter.shape(), op_context.op_info,
+      &found_unknown_shapes);
+
+  // Construct the shape of our output tensor from our convolution dimensions
+  // and format, as it may not be available yet.
+  //
+  // TODO(varomodt): should we centralize the Conv2D input/output shapes?
+  bool unknown_conv_format = false;
+  OpInfo::TensorProperties output;
+  switch (GetConvolutionFormat(op_context)) {
+    case NCHW:
+      output =
+          DescribeTensor(DT_FLOAT, {dims.batch, dims.oz, dims.ox, dims.oy});
+      break;
+    case NHWC:
+      output =
+          DescribeTensor(DT_FLOAT, {dims.batch, dims.ox, dims.oy, dims.oz});
+      break;
+    default:
+      // TODO(b/77722245): support cost estimation for NCHW_VECT_C.
+      LOG(WARNING) << "unsupported data format: "
+                   << GetDataFormat(op_context.op_info)
+                   << " Defaulting to NHWC.";
+      output =
+          DescribeTensor(DT_FLOAT, {dims.batch, dims.ox, dims.oy, dims.oz});
+      unknown_conv_format = true;
+      break;
+  }
+
+  // Add the operations the fused op always computes.
+  std::vector<OpContext> component_ops = {
+      FusedChildContext(op_context, "Conv2D", output, {conv_input, filter}),
+      FusedChildContext(op_context, "Mul", output, {output, conv_input_scale}),
+      FusedChildContext(op_context, "BiasAdd", output, {output, bias}),
+      FusedChildContext(op_context, "Relu", output, {output})};
+
+  // Add our side_input iff it's non-empty.
+  if (side_input.shape().dim_size() > 0) {
+    component_ops.push_back(FusedChildContext(op_context, "Mul", side_input,
+                                              {side_input, side_input_scale}));
+    component_ops.push_back(
+        FusedChildContext(op_context, "Add", output, {side_input, output}));
+  }
+
+  // Construct an op_context which definitely has our output shape.
+  auto op_context_with_output = op_context;
+  op_context_with_output.op_info.mutable_outputs()->Clear();
+  *op_context_with_output.op_info.mutable_outputs()->Add() = output;
+
+  // Construct component operations and run the cost computation.
+  auto costs = PredictFusedOp(op_context_with_output, component_ops);
+  costs.inaccurate |= found_unknown_shapes || unknown_conv_format;
+  return costs;
+}
+
Costs OpLevelCostEstimator::PredictMatMul(const OpContext& op_context) const {
const auto& op_features = op_context.op_info;
bool found_unknown_shapes = false;
@@ -1086,6 +1173,66 @@ Costs OpLevelCostEstimator::PredictGatherOrSlice(
return costs;
}
+Costs OpLevelCostEstimator::PredictFusedOp(
+    const OpContext& op_context,
+    const std::vector<OpContext>& fused_op_contexts) const {
+  // Seed the result with a zero-operation estimate of the fused node itself:
+  // PredictOpCountBasedCost derives memory_time from the node's inputs and
+  // outputs. Rather than duplicating each component's operation-count logic
+  // here, the compute time is rebuilt from scratch by asking PredictCosts
+  // about every component and summing what it reports.
+  Costs total = PredictOpCountBasedCost(0, op_context.op_info);
+  total.compute_time = 0;
+  total.inaccurate = false;
+
+  for (const OpContext& component : fused_op_contexts) {
+    const Costs component_cost = PredictCosts(component);
+    total.compute_time += component_cost.compute_time;
+    // One inaccurate component makes the whole fused estimate inaccurate.
+    total.inaccurate |= component_cost.inaccurate;
+  }
+
+  // Recompute execution_time from the updated compute and memory times.
+  CombineCostsAndUpdateExecutionTime(&total);
+  return total;
+}
+
+/* static */
+OpContext OpLevelCostEstimator::FusedChildContext(
+    const OpContext& parent, const string& op_name,
+    const OpInfo::TensorProperties& output,
+    const std::vector<OpInfo::TensorProperties>& inputs) {
+  // Copy the parent's op_info wholesale and keep its device name, then
+  // relabel the copy as the component operation.
+  OpContext child;
+  child.name = op_name;
+  child.device_name = parent.device_name;
+  child.op_info = parent.op_info;
+  child.op_info.set_op(op_name);
+
+  // Swap the inherited inputs for the ones supplied by the caller.
+  child.op_info.clear_inputs();
+  for (const OpInfo::TensorProperties& tensor : inputs) {
+    *child.op_info.add_inputs() = tensor;
+  }
+
+  // The child produces exactly one output.
+  child.op_info.clear_outputs();
+  *child.op_info.add_outputs() = output;
+
+  return child;
+}
+
+/* static */
+// Builds a TensorProperties with the given dtype and one shape dimension per
+// entry of `dims`, in order.
+OpInfo::TensorProperties OpLevelCostEstimator::DescribeTensor(
+    DataType type, const std::vector<int64>& dims) {
+  OpInfo::TensorProperties ret;
+  ret.set_dtype(type);
+
+  auto shape = ret.mutable_shape();
+  // Iterate as int64: binding each element to `int` would silently truncate
+  // dimensions that do not fit in 32 bits.
+  for (const int64 dim : dims) {
+    shape->add_dim()->set_size(dim);
+  }
+
+  return ret;
+}
+
/* static */
OpLevelCostEstimator::ConvolutionDimensions
OpLevelCostEstimator::OpDimensionsFromInputs(
@@ -1371,6 +1518,21 @@ Costs OpLevelCostEstimator::PredictFusedBatchNormGrad(
return costs;
}
+/* static */
+OpLevelCostEstimator::ConvolutionFormat
+OpLevelCostEstimator::GetConvolutionFormat(const OpContext& op_context) {
+  // Translate the op's data-format string into the matching enum value; any
+  // string not listed here is reported as UNKNOWN_CONVOLUTION_FORMAT.
+  const auto format_name = GetDataFormat(op_context.op_info);
+  if (format_name == "NCHW") return NCHW;
+  if (format_name == "NHWC") return NHWC;
+  if (format_name == "NCHW_VECT_C") return NCHW_VECT_C;
+  return UNKNOWN_CONVOLUTION_FORMAT;
+}
+
void OpLevelCostEstimator::CombineCostsAndUpdateExecutionTime(
Costs* costs) const {
if (compute_memory_overlap_) {
@@ -1379,6 +1541,5 @@ void OpLevelCostEstimator::CombineCostsAndUpdateExecutionTime(
costs->execution_time = costs->compute_time + costs->memory_time;
}
}
-
} // end namespace grappler
} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.h b/tensorflow/core/grappler/costs/op_level_cost_estimator.h
index 7080264698..35649f7ee9 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.h
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.h
@@ -82,6 +82,13 @@ class OpLevelCostEstimator {
int64 sy; // Stride y.
Padding padding; // SAME or VALID.
};
+ // Data layouts reported by GetConvolutionFormat(). That function returns
+ // UNKNOWN_CONVOLUTION_FORMAT for any data-format string it does not match.
+ // NOTE(review): none of the visible code returns NCHW_VECT_W; confirm
+ // whether it is intentionally reserved.
+ enum ConvolutionFormat {
+ UNKNOWN_CONVOLUTION_FORMAT,
+ NHWC,
+ NCHW,
+ NCHW_VECT_C,
+ NCHW_VECT_W,
+ };
int64 CountConv2DOperations(const OpInfo& op_features,
bool* found_unknown_shapes) const;
int64 CountConv2DOperations(const OpInfo& op_features,
@@ -138,6 +145,7 @@ class OpLevelCostEstimator {
Costs PredictCwiseOp(const OpContext& op_context) const;
Costs PredictConv2DBackpropInput(const OpContext& op_context) const;
Costs PredictConv2DBackpropFilter(const OpContext& op_context) const;
+ Costs PredictFusedConv2DBiasActivation(const OpContext& op_context) const;
Costs PredictMatMul(const OpContext& op_context) const;
Costs PredictNoOp(const OpContext& op_context) const;
Costs PredictIdentity(const OpContext& op_context) const;
@@ -152,6 +160,10 @@ class OpLevelCostEstimator {
Costs PredictFusedBatchNorm(const OpContext& op_context) const;
Costs PredictFusedBatchNormGrad(const OpContext& op_context) const;
+ // Generic cost prediction method for fused operations.
+ Costs PredictFusedOp(const OpContext& op_context,
+ const std::vector<OpContext>& fused_op_contexts) const;
+
// Utility function for safe division. Returns 0
// if rhs is 0 or negative.
static double SafeDiv(const double lhs, const double rhs) {
@@ -173,6 +185,20 @@ class OpLevelCostEstimator {
const TensorShapeProto& original_image_shape, const OpInfo& op_info,
bool* found_unknown_shapes);
+ // Helper to construct child operation contexts for the component operations
+ // of fused ops.
+ static OpContext FusedChildContext(
+ const OpContext& parent, const string& op_name,
+ const OpInfo::TensorProperties& output,
+ const std::vector<OpInfo::TensorProperties>& inputs);
+
+ // Helper to construct tensor shapes.
+ static OpInfo::TensorProperties DescribeTensor(
+ DataType type, const std::vector<int64>& dims);
+
+ // Returns the Conv2D format for this operation.
+ static ConvolutionFormat GetConvolutionFormat(const OpContext& op_context);
+
// This method calculates the execution time depending on whether IO can
// overlap with computation. It assumes the memory and the compute times have
// already been calculated.
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc
index d797a8a8c1..13ea43bed6 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc
@@ -93,6 +93,14 @@ OpContext DescribeBatchMatMul(const std::vector<int>& dims_a,
return op_context;
}
+// Wrangles the minimum number of proto fields to set up a 1D Tensor for cost
+// estimation purposes.
+void DescribeTensor1D(int dim0, OpInfo::TensorProperties* tensor) {
+  // Mark the tensor as float and append a single dimension of size dim0.
+  tensor->set_dtype(DT_FLOAT);
+  tensor->mutable_shape()->add_dim()->set_size(dim0);
+}
+
// Wrangles the minimum number of proto fields to set up a 4D Tensor for cost
// estimation purposes.
void DescribeTensor4D(int dim0, int dim1, int dim2, int dim3,
@@ -120,6 +128,38 @@ OpContext DescribeConvolution(int batch, int ix, int iy, int iz1, int iz2,
return op_context;
}
+// DescribeFusedConv2DBiasActivation constructs an OpContext for a
+// FusedConv2DBiasActivation applied to a convolution input tensor with shape
+// (batch, ix, iy, iz1), a kernel tensor with shape (kx, ky, iz2, oz), a
+// bias tensor with shape (oz), a side input tensor with shape
+// (batch, ox, oy, oz) if has_side_input is set, and two scaling tensors with
+// shape (1).
+//
+// Note that this assumes the NHWC data format.
+OpContext DescribeFusedConv2DBiasActivation(int batch, int ix, int iy, int iz1,
+                                            int iz2, int kx, int ky, int ox,
+                                            int oy, int oz,
+                                            bool has_side_input) {
+  OpContext fused;
+  auto& info = fused.op_info;
+  SetCpuDevice(&info);
+  info.set_op("FusedConv2DBiasActivation");
+
+  // Inputs 0-2: convolution input, filter and bias.
+  DescribeTensor4D(batch, ix, iy, iz1, info.add_inputs());
+  DescribeTensor4D(kx, ky, iz2, oz, info.add_inputs());
+  DescribeTensor1D(oz, info.add_inputs());
+
+  // Input 3: the side-input slot is always added, but it is only given a
+  // shape and dtype when a side input was requested.
+  auto* side_input_slot = info.add_inputs();
+  if (has_side_input) {
+    DescribeTensor4D(batch, ox, oy, oz, side_input_slot);
+  }
+
+  // Inputs 4-5: the two scaling tensors, each of shape (1).
+  DescribeTensor1D(1, info.add_inputs());
+  DescribeTensor1D(1, info.add_inputs());
+
+  return fused;
+}
+
// DescribeUnaryOp constructs an OpContext for the given operation applied to
// a 4-tensor with shape (size1, 1, 1, 1).
OpContext DescribeUnaryOp(const string& op, int size1) {
@@ -162,12 +202,9 @@ OpContext DescribeBiasAdd(int size1, int size2) {
op_context.op_info.set_op("BiasAdd");
DescribeTensor4D(1, 1, size2, size1, op_context.op_info.add_inputs());
+ DescribeTensor1D(size1, op_context.op_info.add_inputs());
DescribeTensor4D(1, 1, size2, size1, op_context.op_info.add_outputs());
- auto bias = op_context.op_info.add_inputs();
- bias->mutable_shape()->add_dim()->set_size(size1);
- bias->set_dtype(DT_FLOAT);
-
return op_context;
}
@@ -486,6 +523,25 @@ TEST_F(OpLevelCostEstimatorTest, ExecutionTimeSumOrMax) {
SetComputeMemoryOverlap(false); // Set it back to default.
}
+TEST_F(OpLevelCostEstimatorTest, FusedConv2DBiasActivationExecutionTime) {
+  // Fused op with a side input present.
+  const OpContext fused_op = DescribeFusedConv2DBiasActivation(
+      16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true);
+  const auto predicted = PredictCosts(fused_op);
+
+  EXPECT_EQ(Costs::Duration(1416808), predicted.memory_time);
+  EXPECT_EQ(Costs::Duration(355616770), predicted.compute_time);
+  EXPECT_EQ(Costs::Duration(357033578), predicted.execution_time);
+  EXPECT_FALSE(predicted.inaccurate);
+}
+
+TEST_F(OpLevelCostEstimatorTest,
+       FusedConv2DBiasActivationNoSideInputExecutionTime) {
+  // Same convolution as the side-input test, but with the side input left
+  // empty.
+  const OpContext fused_op = DescribeFusedConv2DBiasActivation(
+      16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ false);
+  const auto predicted = PredictCosts(fused_op);
+
+  EXPECT_EQ(Costs::Duration(825345), predicted.memory_time);
+  EXPECT_EQ(Costs::Duration(355321038), predicted.compute_time);
+  EXPECT_EQ(Costs::Duration(356146383), predicted.execution_time);
+  EXPECT_FALSE(predicted.inaccurate);
+}
+
TEST_F(OpLevelCostEstimatorTest, MulExecutionTime) {
auto cost = PredictCosts(DescribeBinaryOp("Mul", 1000, 1));
EXPECT_EQ(Costs::Duration(2000), cost.memory_time);