diff options
Diffstat (limited to 'tensorflow/core/grappler/costs/virtual_scheduler.cc')
-rw-r--r-- | tensorflow/core/grappler/costs/virtual_scheduler.cc | 48 |
1 files changed, 8 insertions, 40 deletions
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc index 037a823096..5b93fb128f 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler.cc +++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc @@ -473,6 +473,7 @@ Status VirtualScheduler::Init() { VLOG(1) << "Some feed nodes were not consumed by the fetch fanin: " << str_util::Join(feed_nodes, ","); } + initialized_ = true; return Status::OK(); } @@ -695,38 +696,6 @@ NodeState& VirtualScheduler::GetNodeStateOrCreateIt(const NodeDef* node) { return it->second; } -int64 VirtualScheduler::CalculateOutputSize( - const std::vector<OpInfo::TensorProperties>& output_properties, - const int port_num) const { - if (port_num < 0) { - return 4; // 4B for control dependency. - } - - if (port_num >= output_properties.size()) { - VLOG(3) << "VirtualScheduler::CalculateOutputSize() -- " - << "port_num: " << port_num - << " >= output_properties.size(): " << output_properties.size(); - return 0; - } - - const auto& output = output_properties[port_num]; - int64 output_size = DataTypeSize(BaseType(output.dtype())); - - for (const auto& dim : output.shape().dim()) { - auto dim_size = dim.size(); - if (dim_size < 0) { - // Zero output size if there's any unknown dim. - output_size = 0; - VLOG(3) << "VirtualScheduler::CalculateOutputSize() -- " - << "unknown dim: " << output_size; - break; - } - output_size *= dim_size; - } - - return output_size; -} - Costs& VirtualScheduler::FindOrCreateZero(const string& op_name, std::map<string, Costs>* op_cost) { auto it = op_cost->find(op_name); @@ -744,7 +713,10 @@ bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) { const NodeDef* node = ready_nodes_->GetCurrNode(); const string& op_name = node->op(); - // Also keep track of op counts and times per op (with their shapes). + auto& op_cost = FindOrCreateZero(op_name, &op_to_cost_); + op_cost = CombineCosts(op_cost, node_costs); + + // Also keep track of op counts and costs per op (with their shapes). OpContext op_context = GetCurrNode(); string node_description = GetOpDescription(op_context.op_info); op_counts_[node_description] += 1; @@ -752,9 +724,6 @@ bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) { std::make_pair(node_costs.execution_time.asMicroSeconds().count(), !node_costs.inaccurate); - auto& op_cost = FindOrCreateZero(op_name, &op_to_cost_); - op_cost = CombineCosts(op_cost, node_costs); - // Update node and device states. auto& node_state = node_map_[node]; auto& device = device_[node_state.device_name]; @@ -795,7 +764,7 @@ bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) { << ", scheduled: " << node_state.time_scheduled.count() << ", finished: " << node_state.time_finished.count(); - // Increment num_inputs_ready of the output nodes + // Increment num_inputs_ready of the output nodes and maybe add to ready nodes for (const auto& port_num_output_pair : node_state.outputs) { for (auto* output_node : port_num_output_pair.second) { auto& output_state = node_map_[output_node]; @@ -812,7 +781,7 @@ bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) { } } - // Increment num_outputs_executed of the input nodes. + // Increment num_outputs_executed of the input nodes and maybe update memory. for (const auto& input_port : node_state.inputs) { auto* input = input_port.first; auto port = input_port.second; @@ -841,7 +810,6 @@ bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) { } } - // Remove the current node; assume FIFO. ready_nodes_->RemoveCurrNode(); return !ready_nodes_->Empty(); @@ -1007,7 +975,7 @@ Costs VirtualScheduler::Summary(RunMetadata* metadata) { return Summary(); } - // Fill RunMetadata. + // Fill RunMetadata's step_stats and partition_graphs fields. StepStats* stepstats = metadata->mutable_step_stats(); for (const auto& device : device_) { GraphDef* device_partition_graph = metadata->add_partition_graphs(); |