aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/costs/cost_estimator.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/grappler/costs/cost_estimator.h')
-rw-r--r--tensorflow/core/grappler/costs/cost_estimator.h38
1 files changed, 36 insertions, 2 deletions
diff --git a/tensorflow/core/grappler/costs/cost_estimator.h b/tensorflow/core/grappler/costs/cost_estimator.h
index 569d9da683..811e923b87 100644
--- a/tensorflow/core/grappler/costs/cost_estimator.h
+++ b/tensorflow/core/grappler/costs/cost_estimator.h
@@ -31,8 +31,37 @@ constexpr int64 kMemoryUnknown = -1ll;
constexpr int64 kZeroMemory = 0ll;
struct DeviceInfo {
- double gigaops; // Billions of operations executed per second.
- double gb_per_sec; // Bandwidth to main memory in GB per second.
+ // Billions of operations executed per second.
+ double gigaops;
+
+ // Bandwidth to main memory in GB per second.
+ double gb_per_sec;
+
+ // Read bandwidth to intermediate memory in GB per second.
+ double intermediate_read_gb_per_sec;
+
+ // Read bandwidth to intermediate memory in GB per second.
+ double intermediate_write_gb_per_sec;
+
+ DeviceInfo()
+ : gigaops(INFINITY),
+ gb_per_sec(INFINITY),
+ intermediate_read_gb_per_sec(INFINITY),
+ intermediate_write_gb_per_sec(INFINITY) {}
+
+ DeviceInfo(const DeviceInfo& input)
+ : gigaops(input.gigaops),
+ gb_per_sec(input.gb_per_sec),
+ intermediate_read_gb_per_sec(input.intermediate_read_gb_per_sec),
+ intermediate_write_gb_per_sec(input.intermediate_write_gb_per_sec) {}
+
+ DeviceInfo(double gigaops, double gb_per_sec,
+ double intermediate_read_gb_per_sec = INFINITY,
+ double intermediate_write_gb_per_sec = INFINITY)
+ : gigaops(gigaops),
+ gb_per_sec(gb_per_sec),
+ intermediate_read_gb_per_sec(intermediate_read_gb_per_sec),
+ intermediate_write_gb_per_sec(intermediate_write_gb_per_sec) {}
};
// Holds the set of things we might want to estimate or measure in Grappler.
@@ -101,6 +130,9 @@ struct Costs {
// Memory access cost of running the graph.
Duration memory_time;
+ // Intermediate memory access cost of running the graph
+ Duration intermediate_memory_time;
+
// This field can be a very pessimistic estimate of the main memory
// requirements of a graph. For example, it might assume that all activations
// are live for all of a graph's execution.
@@ -146,6 +178,7 @@ Costs::Costs() {
execution_time = Duration::zero();
compute_time = Duration::zero();
memory_time = Duration::zero();
+ intermediate_memory_time = Duration::zero();
max_memory = kMemoryUnknown;
persistent_memory = kMemoryUnknown;
temporary_memory = kMemoryUnknown;
@@ -158,6 +191,7 @@ Costs Costs::ZeroCosts() {
costs.execution_time = Duration::zero();
costs.compute_time = Duration::zero();
costs.memory_time = Duration::zero();
+ costs.intermediate_memory_time = Duration::zero();
costs.max_memory = kZeroMemory;
costs.persistent_memory = kZeroMemory;
costs.temporary_memory = kZeroMemory;