diff options
author | Benoit Steiner <bsteiner@google.com> | 2017-06-29 16:14:52 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-06-29 16:19:27 -0700 |
commit | 953a6db9b6c3473f40e6ba2db207c62ef0b19097 (patch) | |
tree | 6bf20f0a838f05d7833b13d88e32746cde0d0b2b /tensorflow/core | |
parent | 8280e0ae9083a65b23608b34723f07e028a56dc8 (diff) |
Improve the accuracy of the cost estimates for the size, shape, and rank ops.
PiperOrigin-RevId: 160587845
Diffstat (limited to 'tensorflow/core')
-rw-r--r-- | tensorflow/core/grappler/costs/op_level_cost_estimator.cc | 20 | ||||
-rw-r--r-- | tensorflow/core/grappler/costs/op_level_cost_estimator.h | 1 |
2 files changed, 20 insertions, 1 deletions
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc index 7d3298ded4..7f4cc95f31 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc @@ -37,6 +37,9 @@ constexpr char kRecv[] = "_Recv"; constexpr char kBatchMatMul[] = "BatchMatMul"; constexpr char kVariable[] = "Variable"; constexpr char kVariableV2[] = "VariableV2"; +constexpr char kRank[] = "Rank"; +constexpr char kShape[] = "Shape"; +constexpr char kSize[] = "Size"; namespace { @@ -157,7 +160,10 @@ OpLevelCostEstimator::OpLevelCostEstimator() { {kRecv, wrap(&OpLevelCostEstimator::PredictNoOp)}, {kVariable, wrap(&OpLevelCostEstimator::PredictNoOp)}, {kVariableV2, wrap(&OpLevelCostEstimator::PredictNoOp)}, - {kBatchMatMul, wrap(&OpLevelCostEstimator::PredictBatchMatMul)}}; + {kBatchMatMul, wrap(&OpLevelCostEstimator::PredictBatchMatMul)}, + {kRank, wrap(&OpLevelCostEstimator::PredictMetadata)}, + {kShape, wrap(&OpLevelCostEstimator::PredictMetadata)}, + {kSize, wrap(&OpLevelCostEstimator::PredictMetadata)}}; elementwise_ops_ = { // Unary ops alphabetically sorted @@ -846,5 +852,17 @@ Costs OpLevelCostEstimator::PredictBatchMatMul( return costs; } +Costs OpLevelCostEstimator::PredictMetadata(const OpInfo& op_features) const { + Costs costs; + costs.max_memory = CalculateOutputSize(op_features, &costs.inaccurate); + // Metadata operations are so cheap we assume they take the minimum amount of + // time we can represent (1 ns). + costs.execution_time = 1; + costs.compute_time = 1; + costs.memory_time = 0; + + return costs; +} + } // end namespace grappler } // end namespace tensorflow diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.h b/tensorflow/core/grappler/costs/op_level_cost_estimator.h index 28d49a7703..59ced70ba6 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator.h +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.h @@ -129,6 +129,7 @@ class OpLevelCostEstimator { Costs PredictMatMul(const OpInfo& op_features) const; Costs PredictNoOp(const OpInfo& op_features) const; Costs PredictBatchMatMul(const OpInfo& op_features) const; + Costs PredictMetadata(const OpInfo& op_features) const; // Utility function for safe division. Returns 0 // if rhs is 0 or negative. |