aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/costs
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-07-21 14:07:14 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-07-21 14:11:06 -0700
commit74057361032bc4d9a9fbfced1f433b06c06c09ec (patch)
treebf6453e743bcd002619166a525bcf9907388b1e1 /tensorflow/core/grappler/costs
parent51cbb58ca5147218b3995dc124bd92927d93e913 (diff)
Add option to use compute_memory_overlap; if true, use max of memory_cost and compute_cost, instead of sum for op level cost in analytical cost estimator.
PiperOrigin-RevId: 162782658
Diffstat (limited to 'tensorflow/core/grappler/costs')
-rw-r--r--tensorflow/core/grappler/costs/op_level_cost_estimator.cc9
-rw-r--r--tensorflow/core/grappler/costs/op_level_cost_estimator.h3
-rw-r--r--tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc14
3 files changed, 25 insertions, 1 deletions
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
index 5148a4b99e..f13b426b3c 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
@@ -268,6 +268,9 @@ OpLevelCostEstimator::OpLevelCostEstimator() {
Eigen::internal::scalar_quotient_op<float>>::Cost},
{"TruncateMod", Eigen::internal::functor_traits<
Eigen::internal::scalar_mod_op<float>>::Cost}};
+
+ // By default, use sum of memory_time and compute_time for execution_time.
+ compute_memory_overlap_ = false;
}
Costs OpLevelCostEstimator::PredictCosts(const OpInfo& op_features) const {
@@ -395,7 +398,11 @@ Costs OpLevelCostEstimator::PredictOpCountBasedCost(
Costs costs;
costs.compute_time = compute_cost;
costs.memory_time = memory_cost;
- costs.execution_time = compute_cost + memory_cost;
+ if (compute_memory_overlap_) {
+ costs.execution_time = std::max(compute_cost, memory_cost);
+ } else {
+ costs.execution_time = compute_cost + memory_cost;
+ }
costs.inaccurate = found_unknown_shapes;
return costs;
}
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.h b/tensorflow/core/grappler/costs/op_level_cost_estimator.h
index 59ced70ba6..36ef6a5c61 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.h
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.h
@@ -150,6 +150,9 @@ class OpLevelCostEstimator {
std::map<string, int> elementwise_ops_;
typedef std::function<Costs(const OpInfo& op_feature)> CostImpl;
std::map<string, CostImpl> device_cost_impl_;
+ // If true, assume compute and memory overlap; hence, the op cost is max of
+ // compute_time and memory_time, instead of sum of those two.
+ bool compute_memory_overlap_;
private:
friend class OpLevelCostEstimatorTest;
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc
index 1f0e02c160..0cbfb10017 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc
@@ -157,6 +157,10 @@ class OpLevelCostEstimatorTest : public ::testing::Test {
found_unknown_shapes);
}
+ void SetComputeMemoryOverlap(bool value) {
+ estimator_.compute_memory_overlap_ = value;
+ }
+
OpLevelCostEstimator estimator_;
};
@@ -168,6 +172,16 @@ TEST_F(OpLevelCostEstimatorTest, DummyExecutionTime) {
EXPECT_TRUE(cost.inaccurate);
}
+TEST_F(OpLevelCostEstimatorTest, ExecutionTimeSumOrMax) {
+ SetComputeMemoryOverlap(true);
+ auto cost = PredictCosts(DescribeOp("Dummy", 1000, 1));
+ EXPECT_EQ(Costs::Duration(2000), cost.memory_time);
+ EXPECT_EQ(Costs::Duration(200), cost.compute_time);
+ EXPECT_EQ(Costs::Duration(2000), cost.execution_time); // max(2000, 200)
+ EXPECT_TRUE(cost.inaccurate);
+ SetComputeMemoryOverlap(false); // Set it back to default.
+}
+
TEST_F(OpLevelCostEstimatorTest, MulExecutionTime) {
auto cost = PredictCosts(DescribeOp("Mul", 1000, 1));
EXPECT_EQ(Costs::Duration(2000), cost.memory_time);