aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <bsteiner@google.com>2017-06-29 16:14:52 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-06-29 16:19:27 -0700
commit953a6db9b6c3473f40e6ba2db207c62ef0b19097 (patch)
tree6bf20f0a838f05d7833b13d88e32746cde0d0b2b /tensorflow/core
parent8280e0ae9083a65b23608b34723f07e028a56dc8 (diff)
Improve the accuracy of the cost estimates for the size, shape, and rank ops.
PiperOrigin-RevId: 160587845
Diffstat (limited to 'tensorflow/core')
-rw-r--r--tensorflow/core/grappler/costs/op_level_cost_estimator.cc20
-rw-r--r--tensorflow/core/grappler/costs/op_level_cost_estimator.h1
2 files changed, 20 insertions, 1 deletions
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
index 7d3298ded4..7f4cc95f31 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
@@ -37,6 +37,9 @@ constexpr char kRecv[] = "_Recv";
constexpr char kBatchMatMul[] = "BatchMatMul";
constexpr char kVariable[] = "Variable";
constexpr char kVariableV2[] = "VariableV2";
+constexpr char kRank[] = "Rank";
+constexpr char kShape[] = "Shape";
+constexpr char kSize[] = "Size";
namespace {
@@ -157,7 +160,10 @@ OpLevelCostEstimator::OpLevelCostEstimator() {
{kRecv, wrap(&OpLevelCostEstimator::PredictNoOp)},
{kVariable, wrap(&OpLevelCostEstimator::PredictNoOp)},
{kVariableV2, wrap(&OpLevelCostEstimator::PredictNoOp)},
- {kBatchMatMul, wrap(&OpLevelCostEstimator::PredictBatchMatMul)}};
+ {kBatchMatMul, wrap(&OpLevelCostEstimator::PredictBatchMatMul)},
+ {kRank, wrap(&OpLevelCostEstimator::PredictMetadata)},
+ {kShape, wrap(&OpLevelCostEstimator::PredictMetadata)},
+ {kSize, wrap(&OpLevelCostEstimator::PredictMetadata)}};
elementwise_ops_ = {
// Unary ops alphabetically sorted
@@ -846,5 +852,17 @@ Costs OpLevelCostEstimator::PredictBatchMatMul(
return costs;
}
+Costs OpLevelCostEstimator::PredictMetadata(const OpInfo& op_features) const {
+ Costs costs;
+ costs.max_memory = CalculateOutputSize(op_features, &costs.inaccurate);
+ // Metadata operations are so cheap we assume they take the minimum amount of
+ // time we can represent (1 ns).
+ costs.execution_time = 1;
+ costs.compute_time = 1;
+ costs.memory_time = 0;
+
+ return costs;
+}
+
} // end namespace grappler
} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.h b/tensorflow/core/grappler/costs/op_level_cost_estimator.h
index 28d49a7703..59ced70ba6 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.h
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.h
@@ -129,6 +129,7 @@ class OpLevelCostEstimator {
Costs PredictMatMul(const OpInfo& op_features) const;
Costs PredictNoOp(const OpInfo& op_features) const;
Costs PredictBatchMatMul(const OpInfo& op_features) const;
+ Costs PredictMetadata(const OpInfo& op_features) const;
// Utility function for safe division. Returns 0
// if rhs is 0 or negative.