aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-10-09 17:35:41 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-09 17:40:04 -0700
commitf0784e69761ef5b78480e9e8b1fd1aa558186646 (patch)
treef81c110d2c15b2643a24ace54b78ff00009e7691
parenteaebeb1d4d939fb9fd0b75e32a76151cb517bfb6 (diff)
Add support for modeling fast memory close to the processor/gpu
PiperOrigin-RevId: 216453979
-rw-r--r--tensorflow/core/grappler/costs/cost_estimator.h38
-rw-r--r--tensorflow/core/grappler/costs/op_level_cost_estimator.cc76
-rw-r--r--tensorflow/core/grappler/costs/op_level_cost_estimator.h3
-rw-r--r--tensorflow/core/grappler/costs/virtual_scheduler.cc31
4 files changed, 112 insertions, 36 deletions
diff --git a/tensorflow/core/grappler/costs/cost_estimator.h b/tensorflow/core/grappler/costs/cost_estimator.h
index 569d9da683..811e923b87 100644
--- a/tensorflow/core/grappler/costs/cost_estimator.h
+++ b/tensorflow/core/grappler/costs/cost_estimator.h
@@ -31,8 +31,37 @@ constexpr int64 kMemoryUnknown = -1ll;
constexpr int64 kZeroMemory = 0ll;
struct DeviceInfo {
- double gigaops; // Billions of operations executed per second.
- double gb_per_sec; // Bandwidth to main memory in GB per second.
+ // Billions of operations executed per second.
+ double gigaops;
+
+ // Bandwidth to main memory in GB per second.
+ double gb_per_sec;
+
+ // Read bandwidth to intermediate memory in GB per second.
+ double intermediate_read_gb_per_sec;
+
+  // Write bandwidth to intermediate memory in GB per second.
+ double intermediate_write_gb_per_sec;
+
+ DeviceInfo()
+ : gigaops(INFINITY),
+ gb_per_sec(INFINITY),
+ intermediate_read_gb_per_sec(INFINITY),
+ intermediate_write_gb_per_sec(INFINITY) {}
+
+ DeviceInfo(const DeviceInfo& input)
+ : gigaops(input.gigaops),
+ gb_per_sec(input.gb_per_sec),
+ intermediate_read_gb_per_sec(input.intermediate_read_gb_per_sec),
+ intermediate_write_gb_per_sec(input.intermediate_write_gb_per_sec) {}
+
+ DeviceInfo(double gigaops, double gb_per_sec,
+ double intermediate_read_gb_per_sec = INFINITY,
+ double intermediate_write_gb_per_sec = INFINITY)
+ : gigaops(gigaops),
+ gb_per_sec(gb_per_sec),
+ intermediate_read_gb_per_sec(intermediate_read_gb_per_sec),
+ intermediate_write_gb_per_sec(intermediate_write_gb_per_sec) {}
};
// Holds the set of things we might want to estimate or measure in Grappler.
@@ -101,6 +130,9 @@ struct Costs {
// Memory access cost of running the graph.
Duration memory_time;
+  // Intermediate memory access cost of running the graph.
+ Duration intermediate_memory_time;
+
// This field can be a very pessimistic estimate of the main memory
// requirements of a graph. For example, it might assume that all activations
// are live for all of a graph's execution.
@@ -146,6 +178,7 @@ Costs::Costs() {
execution_time = Duration::zero();
compute_time = Duration::zero();
memory_time = Duration::zero();
+ intermediate_memory_time = Duration::zero();
max_memory = kMemoryUnknown;
persistent_memory = kMemoryUnknown;
temporary_memory = kMemoryUnknown;
@@ -158,6 +191,7 @@ Costs Costs::ZeroCosts() {
costs.execution_time = Duration::zero();
costs.compute_time = Duration::zero();
costs.memory_time = Duration::zero();
+ costs.intermediate_memory_time = Duration::zero();
costs.max_memory = kZeroMemory;
costs.persistent_memory = kZeroMemory;
costs.temporary_memory = kZeroMemory;
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
index f363f2915f..76e5c989fc 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
@@ -420,7 +420,7 @@ DeviceInfo OpLevelCostEstimator::GetDeviceInfo(
DCHECK_LT(0, gflops) << device.DebugString();
DCHECK_LT(0, gb_per_sec) << device.DebugString();
- return {gflops, gb_per_sec};
+ return DeviceInfo(gflops, gb_per_sec);
}
Costs OpLevelCostEstimator::PredictCwiseOp(const OpContext& op_context) const {
@@ -478,8 +478,8 @@ Costs OpLevelCostEstimator::PredictOpCountBasedCost(
bool unknown_shapes = false;
const double input_size = CalculateInputSize(op_info, &unknown_shapes);
const double output_size = CalculateOutputSize(op_info, &unknown_shapes);
- const double total_io_bytes = input_size + output_size;
- Costs costs = PredictOpCountBasedCost(operations, total_io_bytes, op_info);
+ Costs costs =
+ PredictOpCountBasedCost(operations, input_size, output_size, op_info);
costs.inaccurate = unknown_shapes;
costs.num_ops_with_unknown_shapes = unknown_shapes;
costs.max_memory = output_size;
@@ -487,9 +487,13 @@ Costs OpLevelCostEstimator::PredictOpCountBasedCost(
}
Costs OpLevelCostEstimator::PredictOpCountBasedCost(
- double operations, double total_io_bytes, const OpInfo& op_info) const {
+ double operations, double input_io_bytes, double output_io_bytes,
+ const OpInfo& op_info) const {
+ double total_io_bytes = input_io_bytes + output_io_bytes;
const DeviceInfo device_info = GetDeviceInfo(op_info.device());
- if (device_info.gigaops <= 0 || device_info.gb_per_sec <= 0) {
+ if (device_info.gigaops <= 0 || device_info.gb_per_sec <= 0 ||
+ device_info.intermediate_read_gb_per_sec <= 0 ||
+ device_info.intermediate_write_gb_per_sec <= 0) {
VLOG(1) << "BAD DEVICE. Op:" << op_info.op()
<< " device type:" << op_info.device().type()
<< " device model:" << op_info.device().model();
@@ -504,9 +508,29 @@ Costs OpLevelCostEstimator::PredictOpCountBasedCost(
VLOG(1) << "Op:" << op_info.op() << " Size (KB):" << (total_io_bytes) / 1e3
<< " Memory Time (ns):" << memory_cost.count();
+ // Check if bytes > 0. If it's not and the bandwidth is set to infinity
+ // then the result would be undefined.
+ double intermediate_read_time =
+ (input_io_bytes > 0)
+ ? std::ceil(input_io_bytes / device_info.intermediate_read_gb_per_sec)
+ : 0;
+
+ double intermediate_write_time =
+ (output_io_bytes > 0)
+ ? std::ceil(output_io_bytes /
+ device_info.intermediate_write_gb_per_sec)
+ : 0;
+
+ Costs::NanoSeconds intermediate_memory_cost(intermediate_read_time +
+ intermediate_write_time);
+ VLOG(1) << "Op:" << op_info.op() << " Size (KB):" << (total_io_bytes) / 1e3
+ << " Intermediate Memory Time (ns):"
+ << intermediate_memory_cost.count();
+
Costs costs;
costs.compute_time = compute_cost;
costs.memory_time = memory_cost;
+ costs.intermediate_memory_time = intermediate_memory_cost;
CombineCostsAndUpdateExecutionTime(&costs);
return costs;
}
@@ -1273,8 +1297,8 @@ Costs OpLevelCostEstimator::PredictGatherOrSlice(
CalculateTensorElementCount(op_info.inputs(1), &unknown_shapes);
}
- const double total_io = input_size + output_size;
- Costs costs = PredictOpCountBasedCost(op_count, total_io, op_info);
+ Costs costs =
+ PredictOpCountBasedCost(op_count, input_size, output_size, op_info);
costs.inaccurate = unknown_shapes;
costs.num_ops_with_unknown_shapes = unknown_shapes;
costs.max_memory = output_size;
@@ -1291,12 +1315,15 @@ Costs OpLevelCostEstimator::PredictFusedOp(
// operations here; so we simply add the compute times of each component
// operation, then update the execution time.
Costs fused_cost = PredictOpCountBasedCost(0, op_context.op_info);
+
fused_cost.compute_time = 0;
fused_cost.inaccurate = false;
for (auto& fused_op : fused_op_contexts) {
auto op_cost = PredictCosts(fused_op);
+
fused_cost.compute_time += op_cost.compute_time;
fused_cost.inaccurate |= op_cost.inaccurate;
+ fused_cost.intermediate_memory_time += op_cost.intermediate_memory_time;
}
CombineCostsAndUpdateExecutionTime(&fused_cost);
@@ -1415,8 +1442,8 @@ Costs OpLevelCostEstimator::PredictMaxPool(const OpContext& op_context) const {
const double total_output_size =
CalculateOutputSize(op_info, &found_unknown_shapes);
- Costs costs = PredictOpCountBasedCost(
- ops, total_input_size + total_output_size, op_info);
+ Costs costs = PredictOpCountBasedCost(ops, total_input_size,
+ total_output_size, op_info);
costs.inaccurate = found_unknown_shapes;
costs.num_ops_with_unknown_shapes = found_unknown_shapes;
costs.max_memory = total_output_size;
@@ -1458,8 +1485,8 @@ Costs OpLevelCostEstimator::PredictMaxPoolGrad(
const double total_output_size =
CalculateTensorSize(op_info.inputs(0), &found_unknown_shapes);
- Costs costs = PredictOpCountBasedCost(
- ops, total_input_size + total_output_size, op_info);
+ Costs costs = PredictOpCountBasedCost(ops, total_input_size,
+ total_output_size, op_info);
costs.inaccurate = found_unknown_shapes;
costs.num_ops_with_unknown_shapes = found_unknown_shapes;
costs.max_memory = total_output_size;
@@ -1491,8 +1518,8 @@ Costs OpLevelCostEstimator::PredictAvgPool(const OpContext& op_context) const {
const double total_output_size =
CalculateOutputSize(op_info, &found_unknown_shapes);
- Costs costs = PredictOpCountBasedCost(
- ops, total_input_size + total_output_size, op_info);
+ Costs costs = PredictOpCountBasedCost(ops, total_input_size,
+ total_output_size, op_info);
costs.inaccurate = found_unknown_shapes;
costs.num_ops_with_unknown_shapes = found_unknown_shapes;
costs.max_memory = total_output_size;
@@ -1544,8 +1571,8 @@ Costs OpLevelCostEstimator::PredictAvgPoolGrad(
const double total_output_size =
CalculateOutputSize(op_info, &found_unknown_shapes);
- Costs costs = PredictOpCountBasedCost(
- ops, total_input_size + total_output_size, op_info);
+ Costs costs = PredictOpCountBasedCost(ops, total_input_size,
+ total_output_size, op_info);
costs.inaccurate = found_unknown_shapes;
costs.num_ops_with_unknown_shapes = found_unknown_shapes;
costs.max_memory = total_output_size;
@@ -1590,9 +1617,9 @@ Costs OpLevelCostEstimator::PredictFusedBatchNorm(
total_output_size = size_nhwc;
}
- Costs costs = PredictOpCountBasedCost(
- ops, total_input_size + total_output_size + total_internal_read_size,
- op_info);
+ Costs costs =
+ PredictOpCountBasedCost(ops, total_input_size + total_internal_read_size,
+ total_output_size, op_info);
costs.inaccurate = found_unknown_shapes;
costs.num_ops_with_unknown_shapes = found_unknown_shapes;
costs.max_memory = total_output_size;
@@ -1624,9 +1651,9 @@ Costs OpLevelCostEstimator::PredictFusedBatchNormGrad(
double total_internal_read_size = size_nhwc;
double total_output_size = size_nhwc * 1 + size_c * 2;
- Costs costs = PredictOpCountBasedCost(
- ops, total_input_size + total_output_size + total_internal_read_size,
- op_info);
+ Costs costs =
+ PredictOpCountBasedCost(ops, total_input_size + total_internal_read_size,
+ total_output_size, op_info);
costs.inaccurate = found_unknown_shapes;
costs.num_ops_with_unknown_shapes = found_unknown_shapes;
costs.max_memory = total_output_size;
@@ -1637,9 +1664,12 @@ Costs OpLevelCostEstimator::PredictFusedBatchNormGrad(
void OpLevelCostEstimator::CombineCostsAndUpdateExecutionTime(
Costs* costs) const {
if (compute_memory_overlap_) {
- costs->execution_time = std::max(costs->compute_time, costs->memory_time);
+ costs->execution_time =
+ std::max(costs->intermediate_memory_time,
+ std::max(costs->compute_time, costs->memory_time));
} else {
- costs->execution_time = costs->compute_time + costs->memory_time;
+ costs->execution_time = costs->compute_time + costs->memory_time +
+ costs->intermediate_memory_time;
}
}
} // end namespace grappler
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.h b/tensorflow/core/grappler/costs/op_level_cost_estimator.h
index dd1ee39cb2..84dd9213f7 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.h
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.h
@@ -54,7 +54,8 @@ class OpLevelCostEstimator {
// Naive cost estimate based on the given operations count and the given total
// io size in bytes. Sizes of op_info inputs and outputs are not taken into
// consideration.
- Costs PredictOpCountBasedCost(double operations, double total_io_bytes,
+ Costs PredictOpCountBasedCost(double operations, double input_io_bytes,
+ double output_io_bytes,
const OpInfo& op_info) const;
// This family of routines counts the number of operations to perform the
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc
index 5b93fb128f..5c5bdad1cb 100644
--- a/tensorflow/core/grappler/costs/virtual_scheduler.cc
+++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc
@@ -47,6 +47,7 @@ Costs CombineCosts(const Costs& left, const Costs& right) {
result.execution_time += right.execution_time;
result.compute_time += right.compute_time;
result.memory_time += right.memory_time;
+ result.intermediate_memory_time += right.intermediate_memory_time;
result.num_ops_total += right.num_ops_total;
if (right.inaccurate) result.inaccurate = true;
@@ -825,23 +826,29 @@ Costs VirtualScheduler::Summary() const {
VLOG(1) << "Expected execution time: " << graph_costs_.execution_time.count();
VLOG(1) << "Expected compute time: " << graph_costs_.compute_time.count();
VLOG(1) << "Expected memory time: " << graph_costs_.memory_time.count();
+ VLOG(1) << "Expected intermediate memory time: "
+ << graph_costs_.intermediate_memory_time.count();
VLOG(1) << "Expected max memory: " << graph_costs_.max_memory;
VLOG(1) << "Expected max per-op buffers: " << graph_costs_.max_per_op_buffers;
VLOG(1) << "Expected max per-op streaming buffers: "
<< graph_costs_.max_per_op_streaming;
- VLOG(1) << "Per-op execution time / compute time / memory time:";
+ VLOG(1) << "Per-op execution time / compute time / memory time"
+ << " / intermediate memory time:";
for (const auto& op_cost_pair : op_to_cost_) {
const auto& op = op_cost_pair.first;
const auto& cost = op_cost_pair.second.execution_time.count();
const auto& compute_cost = op_cost_pair.second.compute_time.count();
const auto& memory_cost = op_cost_pair.second.memory_time.count();
+ const auto& intermediate_memory_cost =
+ op_cost_pair.second.intermediate_memory_time.count();
const bool is_op_cost_accurate = !op_cost_pair.second.inaccurate;
if (cost) { // Skip printing out zero-cost ops.
VLOG(1) << strings::Printf(
- " + %30s : %c %10lld / %10lld / %10lld", op.c_str(),
+ " + %30s : %c %10lld / %10lld / %10lld / %10lld", op.c_str(),
(is_op_cost_accurate ? ' ' : '~'), static_cast<int64>(cost),
- static_cast<int64>(compute_cost), static_cast<int64>(memory_cost));
+ static_cast<int64>(compute_cost), static_cast<int64>(memory_cost),
+ static_cast<int64>(intermediate_memory_cost));
}
}
@@ -894,7 +901,8 @@ Costs VirtualScheduler::Summary() const {
<< " having unknown shapes";
VLOG(1) << "Per-op execution time / compute time / memory time "
- "(and memory usage at peak memory usage):";
+ << " / intermediate memory time"
+ << " (and memory usage at peak memory usage):";
// Profile non-persistent op memory usage.
for (const auto& node_port : state.mem_usage_snapshot_at_peak) {
@@ -910,6 +918,8 @@ Costs VirtualScheduler::Summary() const {
const auto& cost = op_cost_pair.second.execution_time.count();
const auto& compute_cost = op_cost_pair.second.compute_time.count();
const auto& memory_cost = op_cost_pair.second.memory_time.count();
+ const auto& intermediate_memory_cost =
+ op_cost_pair.second.intermediate_memory_time.count();
total_compute_time_ns += op_cost_pair.second.execution_time;
const bool is_op_cost_accurate = !op_cost_pair.second.inaccurate;
if (!is_op_cost_accurate) {
@@ -927,12 +937,13 @@ Costs VirtualScheduler::Summary() const {
: 0.0;
if (cost || mem_usage_percent > 1.0) {
// Print out only non-zero cost ops or ops with > 1% memory usage.
- VLOG(1) << strings::Printf(" + %30s : %c %10lld / %10lld / %10lld",
- op.c_str(),
- (is_op_cost_accurate ? ' ' : '~'),
- static_cast<int64>(cost),
- static_cast<int64>(compute_cost),
- static_cast<int64>(memory_cost))
+ VLOG(1) << strings::Printf(
+ " + %30s : %c %10lld / %10lld / %10lld / %10lld",
+ op.c_str(), (is_op_cost_accurate ? ' ' : '~'),
+ static_cast<int64>(cost),
+ static_cast<int64>(compute_cost),
+ static_cast<int64>(memory_cost),
+ static_cast<int64>(intermediate_memory_cost))
<< " (" << strings::HumanReadableNumBytes(op_mem_usage) << " ["
<< mem_usage_percent << "%] "
<< (persisent_ops.count(op) > 0 ? ": persistent op)" : ")");