diff options
Diffstat (limited to 'tensorflow/core/grappler/costs/op_level_cost_estimator.cc')
-rw-r--r-- | tensorflow/core/grappler/costs/op_level_cost_estimator.cc | 76 |
1 files changed, 53 insertions, 23 deletions
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc index f363f2915f..76e5c989fc 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc @@ -420,7 +420,7 @@ DeviceInfo OpLevelCostEstimator::GetDeviceInfo( DCHECK_LT(0, gflops) << device.DebugString(); DCHECK_LT(0, gb_per_sec) << device.DebugString(); - return {gflops, gb_per_sec}; + return DeviceInfo(gflops, gb_per_sec); } Costs OpLevelCostEstimator::PredictCwiseOp(const OpContext& op_context) const { @@ -478,8 +478,8 @@ Costs OpLevelCostEstimator::PredictOpCountBasedCost( bool unknown_shapes = false; const double input_size = CalculateInputSize(op_info, &unknown_shapes); const double output_size = CalculateOutputSize(op_info, &unknown_shapes); - const double total_io_bytes = input_size + output_size; - Costs costs = PredictOpCountBasedCost(operations, total_io_bytes, op_info); + Costs costs = + PredictOpCountBasedCost(operations, input_size, output_size, op_info); costs.inaccurate = unknown_shapes; costs.num_ops_with_unknown_shapes = unknown_shapes; costs.max_memory = output_size; @@ -487,9 +487,13 @@ Costs OpLevelCostEstimator::PredictOpCountBasedCost( } Costs OpLevelCostEstimator::PredictOpCountBasedCost( - double operations, double total_io_bytes, const OpInfo& op_info) const { + double operations, double input_io_bytes, double output_io_bytes, + const OpInfo& op_info) const { + double total_io_bytes = input_io_bytes + output_io_bytes; const DeviceInfo device_info = GetDeviceInfo(op_info.device()); - if (device_info.gigaops <= 0 || device_info.gb_per_sec <= 0) { + if (device_info.gigaops <= 0 || device_info.gb_per_sec <= 0 || + device_info.intermediate_read_gb_per_sec <= 0 || + device_info.intermediate_write_gb_per_sec <= 0) { VLOG(1) << "BAD DEVICE. Op:" << op_info.op() << " device type:" << op_info.device().type() << " device model:" << op_info.device().model(); @@ -504,9 +508,29 @@ Costs OpLevelCostEstimator::PredictOpCountBasedCost( VLOG(1) << "Op:" << op_info.op() << " Size (KB):" << (total_io_bytes) / 1e3 << " Memory Time (ns):" << memory_cost.count(); + // Check if bytes > 0. If it's not and the bandwidth is set to infinity + // then the result would be undefined. + double intermediate_read_time = + (input_io_bytes > 0) + ? std::ceil(input_io_bytes / device_info.intermediate_read_gb_per_sec) + : 0; + + double intermediate_write_time = + (output_io_bytes > 0) + ? std::ceil(output_io_bytes / + device_info.intermediate_write_gb_per_sec) + : 0; + + Costs::NanoSeconds intermediate_memory_cost(intermediate_read_time + + intermediate_write_time); + VLOG(1) << "Op:" << op_info.op() << " Size (KB):" << (total_io_bytes) / 1e3 + << " Intermediate Memory Time (ns):" + << intermediate_memory_cost.count(); + Costs costs; costs.compute_time = compute_cost; costs.memory_time = memory_cost; + costs.intermediate_memory_time = intermediate_memory_cost; CombineCostsAndUpdateExecutionTime(&costs); return costs; } @@ -1273,8 +1297,8 @@ Costs OpLevelCostEstimator::PredictGatherOrSlice( CalculateTensorElementCount(op_info.inputs(1), &unknown_shapes); } - const double total_io = input_size + output_size; - Costs costs = PredictOpCountBasedCost(op_count, total_io, op_info); + Costs costs = + PredictOpCountBasedCost(op_count, input_size, output_size, op_info); costs.inaccurate = unknown_shapes; costs.num_ops_with_unknown_shapes = unknown_shapes; costs.max_memory = output_size; @@ -1291,12 +1315,15 @@ Costs OpLevelCostEstimator::PredictFusedOp( // operations here; so we simply add the compute times of each component // operation, then update the execution time. Costs fused_cost = PredictOpCountBasedCost(0, op_context.op_info); + fused_cost.compute_time = 0; fused_cost.inaccurate = false; for (auto& fused_op : fused_op_contexts) { auto op_cost = PredictCosts(fused_op); + fused_cost.compute_time += op_cost.compute_time; fused_cost.inaccurate |= op_cost.inaccurate; + fused_cost.intermediate_memory_time += op_cost.intermediate_memory_time; } CombineCostsAndUpdateExecutionTime(&fused_cost); @@ -1415,8 +1442,8 @@ Costs OpLevelCostEstimator::PredictMaxPool(const OpContext& op_context) const { const double total_output_size = CalculateOutputSize(op_info, &found_unknown_shapes); - Costs costs = PredictOpCountBasedCost( - ops, total_input_size + total_output_size, op_info); + Costs costs = PredictOpCountBasedCost(ops, total_input_size, + total_output_size, op_info); costs.inaccurate = found_unknown_shapes; costs.num_ops_with_unknown_shapes = found_unknown_shapes; costs.max_memory = total_output_size; @@ -1458,8 +1485,8 @@ Costs OpLevelCostEstimator::PredictMaxPoolGrad( const double total_output_size = CalculateTensorSize(op_info.inputs(0), &found_unknown_shapes); - Costs costs = PredictOpCountBasedCost( - ops, total_input_size + total_output_size, op_info); + Costs costs = PredictOpCountBasedCost(ops, total_input_size, + total_output_size, op_info); costs.inaccurate = found_unknown_shapes; costs.num_ops_with_unknown_shapes = found_unknown_shapes; costs.max_memory = total_output_size; @@ -1491,8 +1518,8 @@ Costs OpLevelCostEstimator::PredictAvgPool(const OpContext& op_context) const { const double total_output_size = CalculateOutputSize(op_info, &found_unknown_shapes); - Costs costs = PredictOpCountBasedCost( - ops, total_input_size + total_output_size, op_info); + Costs costs = PredictOpCountBasedCost(ops, total_input_size, + total_output_size, op_info); costs.inaccurate = found_unknown_shapes; costs.num_ops_with_unknown_shapes = found_unknown_shapes; costs.max_memory = total_output_size; @@ -1544,8 +1571,8 @@ Costs OpLevelCostEstimator::PredictAvgPoolGrad( const double total_output_size = CalculateOutputSize(op_info, &found_unknown_shapes); - Costs costs = PredictOpCountBasedCost( - ops, total_input_size + total_output_size, op_info); + Costs costs = PredictOpCountBasedCost(ops, total_input_size, + total_output_size, op_info); costs.inaccurate = found_unknown_shapes; costs.num_ops_with_unknown_shapes = found_unknown_shapes; costs.max_memory = total_output_size; @@ -1590,9 +1617,9 @@ Costs OpLevelCostEstimator::PredictFusedBatchNorm( total_output_size = size_nhwc; } - Costs costs = PredictOpCountBasedCost( - ops, total_input_size + total_output_size + total_internal_read_size, - op_info); + Costs costs = + PredictOpCountBasedCost(ops, total_input_size + total_internal_read_size, + total_output_size, op_info); costs.inaccurate = found_unknown_shapes; costs.num_ops_with_unknown_shapes = found_unknown_shapes; costs.max_memory = total_output_size; @@ -1624,9 +1651,9 @@ Costs OpLevelCostEstimator::PredictFusedBatchNormGrad( double total_internal_read_size = size_nhwc; double total_output_size = size_nhwc * 1 + size_c * 2; - Costs costs = PredictOpCountBasedCost( - ops, total_input_size + total_output_size + total_internal_read_size, - op_info); + Costs costs = + PredictOpCountBasedCost(ops, total_input_size + total_internal_read_size, + total_output_size, op_info); costs.inaccurate = found_unknown_shapes; costs.num_ops_with_unknown_shapes = found_unknown_shapes; costs.max_memory = total_output_size; @@ -1637,9 +1664,12 @@ Costs OpLevelCostEstimator::PredictFusedBatchNormGrad( void OpLevelCostEstimator::CombineCostsAndUpdateExecutionTime( Costs* costs) const { if (compute_memory_overlap_) { - costs->execution_time = std::max(costs->compute_time, costs->memory_time); + costs->execution_time = + std::max(costs->intermediate_memory_time, + std::max(costs->compute_time, costs->memory_time)); } else { - costs->execution_time = costs->compute_time + costs->memory_time; + costs->execution_time = costs->compute_time + costs->memory_time + + costs->intermediate_memory_time; } } } // end namespace grappler |