aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/grappler/costs/op_level_cost_estimator.cc')
-rw-r--r--tensorflow/core/grappler/costs/op_level_cost_estimator.cc76
1 files changed, 53 insertions, 23 deletions
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
index f363f2915f..76e5c989fc 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
@@ -420,7 +420,7 @@ DeviceInfo OpLevelCostEstimator::GetDeviceInfo(
DCHECK_LT(0, gflops) << device.DebugString();
DCHECK_LT(0, gb_per_sec) << device.DebugString();
- return {gflops, gb_per_sec};
+ return DeviceInfo(gflops, gb_per_sec);
}
Costs OpLevelCostEstimator::PredictCwiseOp(const OpContext& op_context) const {
@@ -478,8 +478,8 @@ Costs OpLevelCostEstimator::PredictOpCountBasedCost(
bool unknown_shapes = false;
const double input_size = CalculateInputSize(op_info, &unknown_shapes);
const double output_size = CalculateOutputSize(op_info, &unknown_shapes);
- const double total_io_bytes = input_size + output_size;
- Costs costs = PredictOpCountBasedCost(operations, total_io_bytes, op_info);
+ Costs costs =
+ PredictOpCountBasedCost(operations, input_size, output_size, op_info);
costs.inaccurate = unknown_shapes;
costs.num_ops_with_unknown_shapes = unknown_shapes;
costs.max_memory = output_size;
@@ -487,9 +487,13 @@ Costs OpLevelCostEstimator::PredictOpCountBasedCost(
}
Costs OpLevelCostEstimator::PredictOpCountBasedCost(
- double operations, double total_io_bytes, const OpInfo& op_info) const {
+ double operations, double input_io_bytes, double output_io_bytes,
+ const OpInfo& op_info) const {
+ double total_io_bytes = input_io_bytes + output_io_bytes;
const DeviceInfo device_info = GetDeviceInfo(op_info.device());
- if (device_info.gigaops <= 0 || device_info.gb_per_sec <= 0) {
+ if (device_info.gigaops <= 0 || device_info.gb_per_sec <= 0 ||
+ device_info.intermediate_read_gb_per_sec <= 0 ||
+ device_info.intermediate_write_gb_per_sec <= 0) {
VLOG(1) << "BAD DEVICE. Op:" << op_info.op()
<< " device type:" << op_info.device().type()
<< " device model:" << op_info.device().model();
@@ -504,9 +508,29 @@ Costs OpLevelCostEstimator::PredictOpCountBasedCost(
VLOG(1) << "Op:" << op_info.op() << " Size (KB):" << (total_io_bytes) / 1e3
<< " Memory Time (ns):" << memory_cost.count();
+ // Check if bytes > 0. If it's not and the bandwidth is set to infinity
+ // then the result would be undefined.
+ double intermediate_read_time =
+ (input_io_bytes > 0)
+ ? std::ceil(input_io_bytes / device_info.intermediate_read_gb_per_sec)
+ : 0;
+
+ double intermediate_write_time =
+ (output_io_bytes > 0)
+ ? std::ceil(output_io_bytes /
+ device_info.intermediate_write_gb_per_sec)
+ : 0;
+
+ Costs::NanoSeconds intermediate_memory_cost(intermediate_read_time +
+ intermediate_write_time);
+ VLOG(1) << "Op:" << op_info.op() << " Size (KB):" << (total_io_bytes) / 1e3
+ << " Intermediate Memory Time (ns):"
+ << intermediate_memory_cost.count();
+
Costs costs;
costs.compute_time = compute_cost;
costs.memory_time = memory_cost;
+ costs.intermediate_memory_time = intermediate_memory_cost;
CombineCostsAndUpdateExecutionTime(&costs);
return costs;
}
@@ -1273,8 +1297,8 @@ Costs OpLevelCostEstimator::PredictGatherOrSlice(
CalculateTensorElementCount(op_info.inputs(1), &unknown_shapes);
}
- const double total_io = input_size + output_size;
- Costs costs = PredictOpCountBasedCost(op_count, total_io, op_info);
+ Costs costs =
+ PredictOpCountBasedCost(op_count, input_size, output_size, op_info);
costs.inaccurate = unknown_shapes;
costs.num_ops_with_unknown_shapes = unknown_shapes;
costs.max_memory = output_size;
@@ -1291,12 +1315,15 @@ Costs OpLevelCostEstimator::PredictFusedOp(
// operations here; so we simply add the compute times of each component
// operation, then update the execution time.
Costs fused_cost = PredictOpCountBasedCost(0, op_context.op_info);
+
fused_cost.compute_time = 0;
fused_cost.inaccurate = false;
for (auto& fused_op : fused_op_contexts) {
auto op_cost = PredictCosts(fused_op);
+
fused_cost.compute_time += op_cost.compute_time;
fused_cost.inaccurate |= op_cost.inaccurate;
+ fused_cost.intermediate_memory_time += op_cost.intermediate_memory_time;
}
CombineCostsAndUpdateExecutionTime(&fused_cost);
@@ -1415,8 +1442,8 @@ Costs OpLevelCostEstimator::PredictMaxPool(const OpContext& op_context) const {
const double total_output_size =
CalculateOutputSize(op_info, &found_unknown_shapes);
- Costs costs = PredictOpCountBasedCost(
- ops, total_input_size + total_output_size, op_info);
+ Costs costs = PredictOpCountBasedCost(ops, total_input_size,
+ total_output_size, op_info);
costs.inaccurate = found_unknown_shapes;
costs.num_ops_with_unknown_shapes = found_unknown_shapes;
costs.max_memory = total_output_size;
@@ -1458,8 +1485,8 @@ Costs OpLevelCostEstimator::PredictMaxPoolGrad(
const double total_output_size =
CalculateTensorSize(op_info.inputs(0), &found_unknown_shapes);
- Costs costs = PredictOpCountBasedCost(
- ops, total_input_size + total_output_size, op_info);
+ Costs costs = PredictOpCountBasedCost(ops, total_input_size,
+ total_output_size, op_info);
costs.inaccurate = found_unknown_shapes;
costs.num_ops_with_unknown_shapes = found_unknown_shapes;
costs.max_memory = total_output_size;
@@ -1491,8 +1518,8 @@ Costs OpLevelCostEstimator::PredictAvgPool(const OpContext& op_context) const {
const double total_output_size =
CalculateOutputSize(op_info, &found_unknown_shapes);
- Costs costs = PredictOpCountBasedCost(
- ops, total_input_size + total_output_size, op_info);
+ Costs costs = PredictOpCountBasedCost(ops, total_input_size,
+ total_output_size, op_info);
costs.inaccurate = found_unknown_shapes;
costs.num_ops_with_unknown_shapes = found_unknown_shapes;
costs.max_memory = total_output_size;
@@ -1544,8 +1571,8 @@ Costs OpLevelCostEstimator::PredictAvgPoolGrad(
const double total_output_size =
CalculateOutputSize(op_info, &found_unknown_shapes);
- Costs costs = PredictOpCountBasedCost(
- ops, total_input_size + total_output_size, op_info);
+ Costs costs = PredictOpCountBasedCost(ops, total_input_size,
+ total_output_size, op_info);
costs.inaccurate = found_unknown_shapes;
costs.num_ops_with_unknown_shapes = found_unknown_shapes;
costs.max_memory = total_output_size;
@@ -1590,9 +1617,9 @@ Costs OpLevelCostEstimator::PredictFusedBatchNorm(
total_output_size = size_nhwc;
}
- Costs costs = PredictOpCountBasedCost(
- ops, total_input_size + total_output_size + total_internal_read_size,
- op_info);
+ Costs costs =
+ PredictOpCountBasedCost(ops, total_input_size + total_internal_read_size,
+ total_output_size, op_info);
costs.inaccurate = found_unknown_shapes;
costs.num_ops_with_unknown_shapes = found_unknown_shapes;
costs.max_memory = total_output_size;
@@ -1624,9 +1651,9 @@ Costs OpLevelCostEstimator::PredictFusedBatchNormGrad(
double total_internal_read_size = size_nhwc;
double total_output_size = size_nhwc * 1 + size_c * 2;
- Costs costs = PredictOpCountBasedCost(
- ops, total_input_size + total_output_size + total_internal_read_size,
- op_info);
+ Costs costs =
+ PredictOpCountBasedCost(ops, total_input_size + total_internal_read_size,
+ total_output_size, op_info);
costs.inaccurate = found_unknown_shapes;
costs.num_ops_with_unknown_shapes = found_unknown_shapes;
costs.max_memory = total_output_size;
@@ -1637,9 +1664,12 @@ Costs OpLevelCostEstimator::PredictFusedBatchNormGrad(
void OpLevelCostEstimator::CombineCostsAndUpdateExecutionTime(
Costs* costs) const {
if (compute_memory_overlap_) {
- costs->execution_time = std::max(costs->compute_time, costs->memory_time);
+ costs->execution_time =
+ std::max(costs->intermediate_memory_time,
+ std::max(costs->compute_time, costs->memory_time));
} else {
- costs->execution_time = costs->compute_time + costs->memory_time;
+ costs->execution_time = costs->compute_time + costs->memory_time +
+ costs->intermediate_memory_time;
}
}
} // end namespace grappler