author    Benoit Steiner <bsteiner@google.com>            2018-01-26 14:02:33 -0800
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-01-26 14:06:25 -0800
commit    ea25bf9558a79747d82220e493d4347901853976 (patch)
tree      eda168812562c1233f7e939d113d03bad94fff16
parent    1bd5b2e6fada5334b04e2db87cf246d1a83fa533 (diff)
Add op level memory usage estimation to the op_level_cost_estimator
PiperOrigin-RevId: 183441321
-rw-r--r--  tensorflow/core/grappler/costs/analytical_cost_estimator_test.cc |  2
-rw-r--r--  tensorflow/core/grappler/costs/cost_estimator.h                   |  6
-rw-r--r--  tensorflow/core/grappler/costs/op_level_cost_estimator.cc         | 58
-rw-r--r--  tensorflow/core/grappler/costs/op_level_cost_estimator.h          |  2
4 files changed, 52 insertions(+), 16 deletions(-)
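
The persistent_memory and temporary_memory fields and the PredictIdentity/PredictVariable handlers introduced below can be consumed roughly as in the following sketch. It is illustrative only and not part of this commit: MemorySummary and SummarizeOp are made-up names, and the OpContext is assumed to be already populated; only OpLevelCostEstimator, OpContext, and Costs come from the files changed here.

// Illustrative sketch: read the per-op memory estimates this commit adds.
#include "tensorflow/core/grappler/costs/op_level_cost_estimator.h"

namespace tensorflow {
namespace grappler {

struct MemorySummary {
  int64 persistent = 0;  // bytes held for the whole run (variables, constants)
  int64 temporary = 0;   // bytes for short-lived buffers
};

MemorySummary SummarizeOp(const OpLevelCostEstimator& estimator,
                          const OpContext& op_context) {
  MemorySummary summary;
  const Costs costs = estimator.PredictCosts(op_context);
  // Both fields default to kMemoryUnknown, so only count known (positive) values.
  if (costs.persistent_memory > 0) summary.persistent += costs.persistent_memory;
  if (costs.temporary_memory > 0) summary.temporary += costs.temporary_memory;
  return summary;
}

}  // namespace grappler
}  // namespace tensorflow
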
diff --git a/tensorflow/core/grappler/costs/analytical_cost_estimator_test.cc b/tensorflow/core/grappler/costs/analytical_cost_estimator_test.cc
index 1c2c171383..f241922471 100644
--- a/tensorflow/core/grappler/costs/analytical_cost_estimator_test.cc
+++ b/tensorflow/core/grappler/costs/analytical_cost_estimator_test.cc
@@ -102,7 +102,7 @@ TEST_F(AnalyticalCostEstimatorTest, SimpleTest) {
Costs summary;
TF_ASSERT_OK(estimator.PredictCosts(item.graph, &cost_graph, &summary));
- EXPECT_EQ(Costs::NanoSeconds(9150), summary.execution_time);
+ EXPECT_EQ(Costs::NanoSeconds(9151), summary.execution_time);
// Make this estimate accurate:
// TODO(http://b/70031255): Accurate estimator for RandomUniform op needed
diff --git a/tensorflow/core/grappler/costs/cost_estimator.h b/tensorflow/core/grappler/costs/cost_estimator.h
index d442861339..9e01ec5ff5 100644
--- a/tensorflow/core/grappler/costs/cost_estimator.h
+++ b/tensorflow/core/grappler/costs/cost_estimator.h
@@ -100,6 +100,8 @@ struct Costs {
// requirements of a graph. For example, it might assume that all activations
// are live for all of a graph's execution.
int64 max_memory; // Maximum main memory requirement in bytes over all ops.
+ int64 persistent_memory;
+ int64 temporary_memory;
// These fields are used for TPU-related estimations. They are per-op
// maximums, so each op is evaluated independently, but we want the maximum of
@@ -132,6 +134,8 @@ Costs::Costs() {
compute_time = Duration::zero();
memory_time = Duration::zero();
max_memory = kMemoryUnknown;
+ persistent_memory = kMemoryUnknown;
+ temporary_memory = kMemoryUnknown;
max_per_op_buffers = kMemoryUnknown;
max_per_op_streaming = kMemoryUnknown;
}
@@ -142,6 +146,8 @@ Costs Costs::ZeroCosts() {
costs.compute_time = Duration::zero();
costs.memory_time = Duration::zero();
costs.max_memory = kZeroMemory;
+ costs.persistent_memory = kZeroMemory;
+ costs.temporary_memory = kZeroMemory;
costs.max_per_op_buffers = kZeroMemory;
costs.max_per_op_streaming = kZeroMemory;
return costs;
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
index 6bc136a3f8..cf317374cf 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
@@ -47,6 +47,8 @@ constexpr char kSize[] = "Size";
constexpr char kStopGradient[] = "StopGradient";
constexpr char kPreventGradient[] = "PreventGradient";
+static const Costs::Duration kMinComputeTime(1);
+
namespace {
string GetDataFormat(const OpInfo& op_features) {
@@ -163,18 +165,20 @@ OpLevelCostEstimator::OpLevelCostEstimator() {
{kSparseMatMul, wrap(&OpLevelCostEstimator::PredictMatMul)},
{kBatchMatMul, wrap(&OpLevelCostEstimator::PredictBatchMatMul)},
- {kPlaceholder, wrap(&OpLevelCostEstimator::PredictNoOp)},
- {kIdentity, wrap(&OpLevelCostEstimator::PredictNoOp)},
- {kRefIdentity, wrap(&OpLevelCostEstimator::PredictNoOp)},
- {kStopGradient, wrap(&OpLevelCostEstimator::PredictNoOp)},
- {kPreventGradient, wrap(&OpLevelCostEstimator::PredictNoOp)},
{kNoOp, wrap(&OpLevelCostEstimator::PredictNoOp)},
- {kReshape, wrap(&OpLevelCostEstimator::PredictNoOp)},
- {kRecv, wrap(&OpLevelCostEstimator::PredictNoOp)},
- {kSend, wrap(&OpLevelCostEstimator::PredictNoOp)},
- {kConst, wrap(&OpLevelCostEstimator::PredictNoOp)},
- {kVariable, wrap(&OpLevelCostEstimator::PredictNoOp)},
- {kVariableV2, wrap(&OpLevelCostEstimator::PredictNoOp)},
+
+ {kPlaceholder, wrap(&OpLevelCostEstimator::PredictIdentity)},
+ {kIdentity, wrap(&OpLevelCostEstimator::PredictIdentity)},
+ {kRefIdentity, wrap(&OpLevelCostEstimator::PredictIdentity)},
+ {kStopGradient, wrap(&OpLevelCostEstimator::PredictIdentity)},
+ {kPreventGradient, wrap(&OpLevelCostEstimator::PredictIdentity)},
+ {kReshape, wrap(&OpLevelCostEstimator::PredictIdentity)},
+ {kRecv, wrap(&OpLevelCostEstimator::PredictIdentity)},
+ {kSend, wrap(&OpLevelCostEstimator::PredictIdentity)},
+
+ {kConst, wrap(&OpLevelCostEstimator::PredictVariable)},
+ {kVariable, wrap(&OpLevelCostEstimator::PredictVariable)},
+ {kVariableV2, wrap(&OpLevelCostEstimator::PredictVariable)},
{kRank, wrap(&OpLevelCostEstimator::PredictMetadata)},
{kShape, wrap(&OpLevelCostEstimator::PredictMetadata)},
@@ -429,6 +433,7 @@ Costs OpLevelCostEstimator::PredictOpCountBasedCost(
costs.execution_time = compute_cost + memory_cost;
}
costs.inaccurate = found_unknown_shapes;
+ costs.max_memory = total_output_size;
return costs;
}
@@ -885,6 +890,30 @@ Costs OpLevelCostEstimator::PredictNoOp(const OpContext& op_context) const {
return Costs::ZeroCosts();
}
+Costs OpLevelCostEstimator::PredictIdentity(const OpContext& op_context) const {
+ const auto& op_features = op_context.op_info;
+ VLOG(1) << "Op:" << op_features.op() << " Execution Time 0 (ns)";
+ Costs result = Costs::ZeroCosts();
+ result.max_memory = CalculateOutputSize(op_features, &result.inaccurate);
+ // Assign the minimum amount of time we can represent to the identity op since
+ // it tends to be really cheap.
+ result.compute_time = kMinComputeTime;
+ result.execution_time = result.compute_time;
+ return result;
+}
+
+Costs OpLevelCostEstimator::PredictVariable(const OpContext& op_context) const {
+ const auto& op_features = op_context.op_info;
+ VLOG(1) << "Op:" << op_features.op() << " Execution Time 0 (ns)";
+ Costs result = Costs::ZeroCosts();
+ result.persistent_memory =
+ CalculateOutputSize(op_features, &result.inaccurate);
+
+ result.compute_time = kMinComputeTime;
+ result.execution_time = result.compute_time;
+ return result;
+}
+
Costs OpLevelCostEstimator::PredictBatchMatMul(
const OpContext& op_context) const {
const auto& op_features = op_context.op_info;
@@ -898,13 +927,12 @@ Costs OpLevelCostEstimator::PredictBatchMatMul(
Costs OpLevelCostEstimator::PredictMetadata(const OpContext& op_context) const {
const auto& op_features = op_context.op_info;
- Costs costs;
+ Costs costs = Costs::ZeroCosts();
costs.max_memory = CalculateOutputSize(op_features, &costs.inaccurate);
// Metadata operations are so cheap we assume they take the minimum amount of
// time we can represent (1 ns).
- costs.execution_time = 1;
- costs.compute_time = 1;
- costs.memory_time = 0;
+ costs.compute_time = kMinComputeTime;
+ costs.execution_time = costs.compute_time;
return costs;
}
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.h b/tensorflow/core/grappler/costs/op_level_cost_estimator.h
index 5f541ccf04..a292e5e97f 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.h
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.h
@@ -132,6 +132,8 @@ class OpLevelCostEstimator {
Costs PredictConv2DBackpropFilter(const OpContext& op_context) const;
Costs PredictMatMul(const OpContext& op_context) const;
Costs PredictNoOp(const OpContext& op_context) const;
+ Costs PredictIdentity(const OpContext& op_context) const;
+ Costs PredictVariable(const OpContext& op_context) const;
Costs PredictBatchMatMul(const OpContext& op_context) const;
Costs PredictMetadata(const OpContext& op_context) const;
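
For reference, a test-style sketch (not part of this commit) of the behaviour the new handlers give: identity-like ops now cost the minimum representable compute time and report their output size as max_memory, while variable-like ops charge their output size to persistent_memory. DescribeOpContext is a hypothetical helper that fills an OpContext for the named op type; it is not a real TensorFlow function.

// Hypothetical usage sketch; DescribeOpContext is assumed, not provided by TensorFlow.
OpLevelCostEstimator estimator;

OpContext identity = DescribeOpContext("Identity");    // hypothetical helper
Costs identity_costs = estimator.PredictCosts(identity);
// Identity-like ops are no longer free: they take the minimum representable time.
EXPECT_EQ(Costs::NanoSeconds(1), identity_costs.compute_time);

OpContext variable = DescribeOpContext("VariableV2");  // hypothetical helper
Costs variable_costs = estimator.PredictCosts(variable);
// Variable-like ops account for their output as persistent rather than temporary memory.
EXPECT_GE(variable_costs.persistent_memory, 0);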