about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/core/grappler/costs/virtual_scheduler.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/grappler/costs/virtual_scheduler.cc')
-rw-r--r-- tensorflow/core/grappler/costs/virtual_scheduler.cc | 48
1 files changed, 8 insertions, 40 deletions
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc
index 037a823096..5b93fb128f 100644
--- a/tensorflow/core/grappler/costs/virtual_scheduler.cc
+++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc
@@ -473,6 +473,7 @@ Status VirtualScheduler::Init() {
VLOG(1) << "Some feed nodes were not consumed by the fetch fanin: "
<< str_util::Join(feed_nodes, ",");
}
+
initialized_ = true;
return Status::OK();
}
@@ -695,38 +696,6 @@ NodeState& VirtualScheduler::GetNodeStateOrCreateIt(const NodeDef* node) {
return it->second;
}
-int64 VirtualScheduler::CalculateOutputSize(
- const std::vector<OpInfo::TensorProperties>& output_properties,
- const int port_num) const {
- if (port_num < 0) {
- return 4; // 4B for control dependency.
- }
-
- if (port_num >= output_properties.size()) {
- VLOG(3) << "VirtualScheduler::CalculateOutputSize() -- "
- << "port_num: " << port_num
- << " >= output_properties.size(): " << output_properties.size();
- return 0;
- }
-
- const auto& output = output_properties[port_num];
- int64 output_size = DataTypeSize(BaseType(output.dtype()));
-
- for (const auto& dim : output.shape().dim()) {
- auto dim_size = dim.size();
- if (dim_size < 0) {
- // Zero output size if there's any unknown dim.
- output_size = 0;
- VLOG(3) << "VirtualScheduler::CalculateOutputSize() -- "
- << "unknown dim: " << output_size;
- break;
- }
- output_size *= dim_size;
- }
-
- return output_size;
-}
-
Costs& VirtualScheduler::FindOrCreateZero(const string& op_name,
std::map<string, Costs>* op_cost) {
auto it = op_cost->find(op_name);
@@ -744,7 +713,10 @@ bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
const NodeDef* node = ready_nodes_->GetCurrNode();
const string& op_name = node->op();
- // Also keep track of op counts and times per op (with their shapes).
+ auto& op_cost = FindOrCreateZero(op_name, &op_to_cost_);
+ op_cost = CombineCosts(op_cost, node_costs);
+
+ // Also keep track of op counts and costs per op (with their shapes).
OpContext op_context = GetCurrNode();
string node_description = GetOpDescription(op_context.op_info);
op_counts_[node_description] += 1;
@@ -752,9 +724,6 @@ bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
std::make_pair(node_costs.execution_time.asMicroSeconds().count(),
!node_costs.inaccurate);
- auto& op_cost = FindOrCreateZero(op_name, &op_to_cost_);
- op_cost = CombineCosts(op_cost, node_costs);
-
// Update node and device states.
auto& node_state = node_map_[node];
auto& device = device_[node_state.device_name];
@@ -795,7 +764,7 @@ bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
<< ", scheduled: " << node_state.time_scheduled.count()
<< ", finished: " << node_state.time_finished.count();
- // Increment num_inputs_ready of the output nodes
+ // Increment num_inputs_ready of the output nodes and maybe add to ready nodes
for (const auto& port_num_output_pair : node_state.outputs) {
for (auto* output_node : port_num_output_pair.second) {
auto& output_state = node_map_[output_node];
@@ -812,7 +781,7 @@ bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
}
}
- // Increment num_outputs_executed of the input nodes.
+ // Increment num_outputs_executed of the input nodes and maybe update memory.
for (const auto& input_port : node_state.inputs) {
auto* input = input_port.first;
auto port = input_port.second;
@@ -841,7 +810,6 @@ bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
}
}
- // Remove the current node; assume FIFO.
ready_nodes_->RemoveCurrNode();
return !ready_nodes_->Empty();
@@ -1007,7 +975,7 @@ Costs VirtualScheduler::Summary(RunMetadata* metadata) {
return Summary();
}
- // Fill RunMetadata.
+ // Fill RunMetadata's step_stats and partition_graphs fields.
StepStats* stepstats = metadata->mutable_step_stats();
for (const auto& device : device_) {
GraphDef* device_partition_graph = metadata->add_partition_graphs();