aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/costs/op_level_cost_estimator.h
diff options
context:
space:
mode:
authorGravatar Rob Sloan <varomodt@google.com>2018-04-06 21:55:10 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-06 21:57:46 -0700
commit30e2b97897d05e47b457ab1d5d0d9c4227b87845 (patch)
tree19ccb01faec4fc451cfe45867d130277a9116fe7 /tensorflow/core/grappler/costs/op_level_cost_estimator.h
parent273495dc2c957402f832cae31a438e550db2b7f0 (diff)
Add analytical cost model for FusedConv2DBiasActivation.
PiperOrigin-RevId: 191978272
Diffstat (limited to 'tensorflow/core/grappler/costs/op_level_cost_estimator.h')
-rw-r--r--tensorflow/core/grappler/costs/op_level_cost_estimator.h26
1 files changed, 26 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.h b/tensorflow/core/grappler/costs/op_level_cost_estimator.h
index 7080264698..35649f7ee9 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.h
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.h
@@ -82,6 +82,13 @@ class OpLevelCostEstimator {
int64 sy; // Stride y.
Padding padding; // SAME or VALID.
};
+ enum ConvolutionFormat {
+ UNKNOWN_CONVOLUTION_FORMAT,
+ NHWC,
+ NCHW,
+ NCHW_VECT_C,
+ NCHW_VECT_W,
+ };
int64 CountConv2DOperations(const OpInfo& op_features,
bool* found_unknown_shapes) const;
int64 CountConv2DOperations(const OpInfo& op_features,
@@ -138,6 +145,7 @@ class OpLevelCostEstimator {
Costs PredictCwiseOp(const OpContext& op_context) const;
Costs PredictConv2DBackpropInput(const OpContext& op_context) const;
Costs PredictConv2DBackpropFilter(const OpContext& op_context) const;
+ Costs PredictFusedConv2DBiasActivation(const OpContext& op_context) const;
Costs PredictMatMul(const OpContext& op_context) const;
Costs PredictNoOp(const OpContext& op_context) const;
Costs PredictIdentity(const OpContext& op_context) const;
@@ -152,6 +160,10 @@ class OpLevelCostEstimator {
Costs PredictFusedBatchNorm(const OpContext& op_context) const;
Costs PredictFusedBatchNormGrad(const OpContext& op_context) const;
+ // Generic cost prediction method for fused operations.
+ Costs PredictFusedOp(const OpContext& op_context,
+ const std::vector<OpContext>& fused_op_contexts) const;
+
// Utility function for safe division. Returns 0
// if rhs is 0 or negative.
static double SafeDiv(const double lhs, const double rhs) {
@@ -173,6 +185,20 @@ class OpLevelCostEstimator {
const TensorShapeProto& original_image_shape, const OpInfo& op_info,
bool* found_unknown_shapes);
+ // Helper to construct child operation contexts for the component operations
+ // of fused ops.
+ static OpContext FusedChildContext(
+ const OpContext& parent, const string& op_name,
+ const OpInfo::TensorProperties& output,
+ const std::vector<OpInfo::TensorProperties>& inputs);
+
+ // Helper to construct tensor shapes.
+ static OpInfo::TensorProperties DescribeTensor(
+ DataType type, const std::vector<int64>& dims);
+
+ // Returns the Conv2D format for this operation.
+ static ConvolutionFormat GetConvolutionFormat(const OpContext& op_context);
+
// This method calculates the execution time depending on whether IO can
// overlap with computation. It assumes the memory and the compute times have
// already been calculated.