diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-10-09 17:35:41 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-09 17:40:04 -0700 |
commit | f0784e69761ef5b78480e9e8b1fd1aa558186646 (patch) | |
tree | f81c110d2c15b2643a24ace54b78ff00009e7691 | |
parent | eaebeb1d4d939fb9fd0b75e32a76151cb517bfb6 (diff) |
Add support for modeling fast memory close to the processor/gpu
PiperOrigin-RevId: 216453979
4 files changed, 112 insertions, 36 deletions
diff --git a/tensorflow/core/grappler/costs/cost_estimator.h b/tensorflow/core/grappler/costs/cost_estimator.h index 569d9da683..811e923b87 100644 --- a/tensorflow/core/grappler/costs/cost_estimator.h +++ b/tensorflow/core/grappler/costs/cost_estimator.h @@ -31,8 +31,37 @@ constexpr int64 kMemoryUnknown = -1ll; constexpr int64 kZeroMemory = 0ll; struct DeviceInfo { - double gigaops; // Billions of operations executed per second. - double gb_per_sec; // Bandwidth to main memory in GB per second. + // Billions of operations executed per second. + double gigaops; + + // Bandwidth to main memory in GB per second. + double gb_per_sec; + + // Read bandwidth to intermediate memory in GB per second. + double intermediate_read_gb_per_sec; + + // Read bandwidth to intermediate memory in GB per second. + double intermediate_write_gb_per_sec; + + DeviceInfo() + : gigaops(INFINITY), + gb_per_sec(INFINITY), + intermediate_read_gb_per_sec(INFINITY), + intermediate_write_gb_per_sec(INFINITY) {} + + DeviceInfo(const DeviceInfo& input) + : gigaops(input.gigaops), + gb_per_sec(input.gb_per_sec), + intermediate_read_gb_per_sec(input.intermediate_read_gb_per_sec), + intermediate_write_gb_per_sec(input.intermediate_write_gb_per_sec) {} + + DeviceInfo(double gigaops, double gb_per_sec, + double intermediate_read_gb_per_sec = INFINITY, + double intermediate_write_gb_per_sec = INFINITY) + : gigaops(gigaops), + gb_per_sec(gb_per_sec), + intermediate_read_gb_per_sec(intermediate_read_gb_per_sec), + intermediate_write_gb_per_sec(intermediate_write_gb_per_sec) {} }; // Holds the set of things we might want to estimate or measure in Grappler. @@ -101,6 +130,9 @@ struct Costs { // Memory access cost of running the graph. Duration memory_time; + // Intermediate memory access cost of running the graph + Duration intermediate_memory_time; + // This field can be a very pessimistic estimate of the main memory // requirements of a graph. For example, it might assume that all activations // are live for all of a graph's execution. @@ -146,6 +178,7 @@ Costs::Costs() { execution_time = Duration::zero(); compute_time = Duration::zero(); memory_time = Duration::zero(); + intermediate_memory_time = Duration::zero(); max_memory = kMemoryUnknown; persistent_memory = kMemoryUnknown; temporary_memory = kMemoryUnknown; @@ -158,6 +191,7 @@ Costs Costs::ZeroCosts() { costs.execution_time = Duration::zero(); costs.compute_time = Duration::zero(); costs.memory_time = Duration::zero(); + costs.intermediate_memory_time = Duration::zero(); costs.max_memory = kZeroMemory; costs.persistent_memory = kZeroMemory; costs.temporary_memory = kZeroMemory; diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc index f363f2915f..76e5c989fc 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc @@ -420,7 +420,7 @@ DeviceInfo OpLevelCostEstimator::GetDeviceInfo( DCHECK_LT(0, gflops) << device.DebugString(); DCHECK_LT(0, gb_per_sec) << device.DebugString(); - return {gflops, gb_per_sec}; + return DeviceInfo(gflops, gb_per_sec); } Costs OpLevelCostEstimator::PredictCwiseOp(const OpContext& op_context) const { @@ -478,8 +478,8 @@ Costs OpLevelCostEstimator::PredictOpCountBasedCost( bool unknown_shapes = false; const double input_size = CalculateInputSize(op_info, &unknown_shapes); const double output_size = CalculateOutputSize(op_info, &unknown_shapes); - const double total_io_bytes = input_size + output_size; - Costs costs = PredictOpCountBasedCost(operations, total_io_bytes, op_info); + Costs costs = + PredictOpCountBasedCost(operations, input_size, output_size, op_info); costs.inaccurate = unknown_shapes; costs.num_ops_with_unknown_shapes = unknown_shapes; costs.max_memory = output_size; @@ -487,9 +487,13 @@ Costs OpLevelCostEstimator::PredictOpCountBasedCost( } Costs OpLevelCostEstimator::PredictOpCountBasedCost( - double operations, double total_io_bytes, const OpInfo& op_info) const { + double operations, double input_io_bytes, double output_io_bytes, + const OpInfo& op_info) const { + double total_io_bytes = input_io_bytes + output_io_bytes; const DeviceInfo device_info = GetDeviceInfo(op_info.device()); - if (device_info.gigaops <= 0 || device_info.gb_per_sec <= 0) { + if (device_info.gigaops <= 0 || device_info.gb_per_sec <= 0 || + device_info.intermediate_read_gb_per_sec <= 0 || + device_info.intermediate_write_gb_per_sec <= 0) { VLOG(1) << "BAD DEVICE. Op:" << op_info.op() << " device type:" << op_info.device().type() << " device model:" << op_info.device().model(); @@ -504,9 +508,29 @@ Costs OpLevelCostEstimator::PredictOpCountBasedCost( VLOG(1) << "Op:" << op_info.op() << " Size (KB):" << (total_io_bytes) / 1e3 << " Memory Time (ns):" << memory_cost.count(); + // Check if bytes > 0. If it's not and the bandwidth is set to infinity + // then the result would be undefined. + double intermediate_read_time = + (input_io_bytes > 0) + ? std::ceil(input_io_bytes / device_info.intermediate_read_gb_per_sec) + : 0; + + double intermediate_write_time = + (output_io_bytes > 0) + ? std::ceil(output_io_bytes / + device_info.intermediate_write_gb_per_sec) + : 0; + + Costs::NanoSeconds intermediate_memory_cost(intermediate_read_time + + intermediate_write_time); + VLOG(1) << "Op:" << op_info.op() << " Size (KB):" << (total_io_bytes) / 1e3 + << " Intermediate Memory Time (ns):" + << intermediate_memory_cost.count(); + Costs costs; costs.compute_time = compute_cost; costs.memory_time = memory_cost; + costs.intermediate_memory_time = intermediate_memory_cost; CombineCostsAndUpdateExecutionTime(&costs); return costs; } @@ -1273,8 +1297,8 @@ Costs OpLevelCostEstimator::PredictGatherOrSlice( CalculateTensorElementCount(op_info.inputs(1), &unknown_shapes); } - const double total_io = input_size + output_size; - Costs costs = PredictOpCountBasedCost(op_count, total_io, op_info); + Costs costs = + PredictOpCountBasedCost(op_count, input_size, output_size, op_info); costs.inaccurate = unknown_shapes; costs.num_ops_with_unknown_shapes = unknown_shapes; costs.max_memory = output_size; @@ -1291,12 +1315,15 @@ Costs OpLevelCostEstimator::PredictFusedOp( // operations here; so we simply add the compute times of each component // operation, then update the execution time. Costs fused_cost = PredictOpCountBasedCost(0, op_context.op_info); + fused_cost.compute_time = 0; fused_cost.inaccurate = false; for (auto& fused_op : fused_op_contexts) { auto op_cost = PredictCosts(fused_op); + fused_cost.compute_time += op_cost.compute_time; fused_cost.inaccurate |= op_cost.inaccurate; + fused_cost.intermediate_memory_time += op_cost.intermediate_memory_time; } CombineCostsAndUpdateExecutionTime(&fused_cost); @@ -1415,8 +1442,8 @@ Costs OpLevelCostEstimator::PredictMaxPool(const OpContext& op_context) const { const double total_output_size = CalculateOutputSize(op_info, &found_unknown_shapes); - Costs costs = PredictOpCountBasedCost( - ops, total_input_size + total_output_size, op_info); + Costs costs = PredictOpCountBasedCost(ops, total_input_size, + total_output_size, op_info); costs.inaccurate = found_unknown_shapes; costs.num_ops_with_unknown_shapes = found_unknown_shapes; costs.max_memory = total_output_size; @@ -1458,8 +1485,8 @@ Costs OpLevelCostEstimator::PredictMaxPoolGrad( const double total_output_size = CalculateTensorSize(op_info.inputs(0), &found_unknown_shapes); - Costs costs = PredictOpCountBasedCost( - ops, total_input_size + total_output_size, op_info); + Costs costs = PredictOpCountBasedCost(ops, total_input_size, + total_output_size, op_info); costs.inaccurate = found_unknown_shapes; costs.num_ops_with_unknown_shapes = found_unknown_shapes; costs.max_memory = total_output_size; @@ -1491,8 +1518,8 @@ Costs OpLevelCostEstimator::PredictAvgPool(const OpContext& op_context) const { const double total_output_size = CalculateOutputSize(op_info, &found_unknown_shapes); - Costs costs = PredictOpCountBasedCost( - ops, total_input_size + total_output_size, op_info); + Costs costs = PredictOpCountBasedCost(ops, total_input_size, + total_output_size, op_info); costs.inaccurate = found_unknown_shapes; costs.num_ops_with_unknown_shapes = found_unknown_shapes; costs.max_memory = total_output_size; @@ -1544,8 +1571,8 @@ Costs OpLevelCostEstimator::PredictAvgPoolGrad( const double total_output_size = CalculateOutputSize(op_info, &found_unknown_shapes); - Costs costs = PredictOpCountBasedCost( - ops, total_input_size + total_output_size, op_info); + Costs costs = PredictOpCountBasedCost(ops, total_input_size, + total_output_size, op_info); costs.inaccurate = found_unknown_shapes; costs.num_ops_with_unknown_shapes = found_unknown_shapes; costs.max_memory = total_output_size; @@ -1590,9 +1617,9 @@ Costs OpLevelCostEstimator::PredictFusedBatchNorm( total_output_size = size_nhwc; } - Costs costs = PredictOpCountBasedCost( - ops, total_input_size + total_output_size + total_internal_read_size, - op_info); + Costs costs = + PredictOpCountBasedCost(ops, total_input_size + total_internal_read_size, + total_output_size, op_info); costs.inaccurate = found_unknown_shapes; costs.num_ops_with_unknown_shapes = found_unknown_shapes; costs.max_memory = total_output_size; @@ -1624,9 +1651,9 @@ Costs OpLevelCostEstimator::PredictFusedBatchNormGrad( double total_internal_read_size = size_nhwc; double total_output_size = size_nhwc * 1 + size_c * 2; - Costs costs = PredictOpCountBasedCost( - ops, total_input_size + total_output_size + total_internal_read_size, - op_info); + Costs costs = + PredictOpCountBasedCost(ops, total_input_size + total_internal_read_size, + total_output_size, op_info); costs.inaccurate = found_unknown_shapes; costs.num_ops_with_unknown_shapes = found_unknown_shapes; costs.max_memory = total_output_size; @@ -1637,9 +1664,12 @@ Costs OpLevelCostEstimator::PredictFusedBatchNormGrad( void OpLevelCostEstimator::CombineCostsAndUpdateExecutionTime( Costs* costs) const { if (compute_memory_overlap_) { - costs->execution_time = std::max(costs->compute_time, costs->memory_time); + costs->execution_time = + std::max(costs->intermediate_memory_time, + std::max(costs->compute_time, costs->memory_time)); } else { - costs->execution_time = costs->compute_time + costs->memory_time; + costs->execution_time = costs->compute_time + costs->memory_time + + costs->intermediate_memory_time; } } } // end namespace grappler diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.h b/tensorflow/core/grappler/costs/op_level_cost_estimator.h index dd1ee39cb2..84dd9213f7 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator.h +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.h @@ -54,7 +54,8 @@ class OpLevelCostEstimator { // Naive cost estimate based on the given operations count and the given total // io size in bytes. Sizes of op_info inputs and outputs are not taken into // consideration. - Costs PredictOpCountBasedCost(double operations, double total_io_bytes, + Costs PredictOpCountBasedCost(double operations, double input_io_bytes, + double output_io_bytes, const OpInfo& op_info) const; // This family of routines counts the number of operations to perform the diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc index 5b93fb128f..5c5bdad1cb 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler.cc +++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc @@ -47,6 +47,7 @@ Costs CombineCosts(const Costs& left, const Costs& right) { result.execution_time += right.execution_time; result.compute_time += right.compute_time; result.memory_time += right.memory_time; + result.intermediate_memory_time += right.intermediate_memory_time; result.num_ops_total += right.num_ops_total; if (right.inaccurate) result.inaccurate = true; @@ -825,23 +826,29 @@ Costs VirtualScheduler::Summary() const { VLOG(1) << "Expected execution time: " << graph_costs_.execution_time.count(); VLOG(1) << "Expected compute time: " << graph_costs_.compute_time.count(); VLOG(1) << "Expected memory time: " << graph_costs_.memory_time.count(); + VLOG(1) << "Expected intermediate memory time: " + << graph_costs_.intermediate_memory_time.count(); VLOG(1) << "Expected max memory: " << graph_costs_.max_memory; VLOG(1) << "Expected max per-op buffers: " << graph_costs_.max_per_op_buffers; VLOG(1) << "Expected max per-op streaming buffers: " << graph_costs_.max_per_op_streaming; - VLOG(1) << "Per-op execution time / compute time / memory time:"; + VLOG(1) << "Per-op execution time / compute time / memory time" + << " / intermediate memory time:"; for (const auto& op_cost_pair : op_to_cost_) { const auto& op = op_cost_pair.first; const auto& cost = op_cost_pair.second.execution_time.count(); const auto& compute_cost = op_cost_pair.second.compute_time.count(); const auto& memory_cost = op_cost_pair.second.memory_time.count(); + const auto& intermediate_memory_cost = + op_cost_pair.second.intermediate_memory_time.count(); const bool is_op_cost_accurate = !op_cost_pair.second.inaccurate; if (cost) { // Skip printing out zero-cost ops. VLOG(1) << strings::Printf( - " + %30s : %c %10lld / %10lld / %10lld", op.c_str(), + " + %30s : %c %10lld / %10lld / %10lld / %10lld", op.c_str(), (is_op_cost_accurate ? ' ' : '~'), static_cast<int64>(cost), - static_cast<int64>(compute_cost), static_cast<int64>(memory_cost)); + static_cast<int64>(compute_cost), static_cast<int64>(memory_cost), + static_cast<int64>(intermediate_memory_cost)); } } @@ -894,7 +901,8 @@ Costs VirtualScheduler::Summary() const { << " having unknown shapes"; VLOG(1) << "Per-op execution time / compute time / memory time " - "(and memory usage at peak memory usage):"; + << " / intermediate memory time" + << " (and memory usage at peak memory usage):"; // Profile non-persistent op memory usage. for (const auto& node_port : state.mem_usage_snapshot_at_peak) { @@ -910,6 +918,8 @@ Costs VirtualScheduler::Summary() const { const auto& cost = op_cost_pair.second.execution_time.count(); const auto& compute_cost = op_cost_pair.second.compute_time.count(); const auto& memory_cost = op_cost_pair.second.memory_time.count(); + const auto& intermediate_memory_cost = + op_cost_pair.second.intermediate_memory_time.count(); total_compute_time_ns += op_cost_pair.second.execution_time; const bool is_op_cost_accurate = !op_cost_pair.second.inaccurate; if (!is_op_cost_accurate) { @@ -927,12 +937,13 @@ Costs VirtualScheduler::Summary() const { : 0.0; if (cost || mem_usage_percent > 1.0) { // Print out only non-zero cost ops or ops with > 1% memory usage. - VLOG(1) << strings::Printf(" + %30s : %c %10lld / %10lld / %10lld", - op.c_str(), - (is_op_cost_accurate ? ' ' : '~'), - static_cast<int64>(cost), - static_cast<int64>(compute_cost), - static_cast<int64>(memory_cost)) + VLOG(1) << strings::Printf( + " + %30s : %c %10lld / %10lld / %10lld / %10lld", + op.c_str(), (is_op_cost_accurate ? ' ' : '~'), + static_cast<int64>(cost), + static_cast<int64>(compute_cost), + static_cast<int64>(memory_cost), + static_cast<int64>(intermediate_memory_cost)) << " (" << strings::HumanReadableNumBytes(op_mem_usage) << " [" << mem_usage_percent << "%] " << (persisent_ops.count(op) > 0 ? ": persistent op)" : ")"); |