diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-06-06 16:45:02 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-06-06 16:49:04 -0700 |
commit | 8f89b654f4d49a1b5d4462303ef27f7f7a2958b3 (patch) | |
tree | 4be61a6867376b086c1cb0b770f38c408df91b69 | |
parent | 0ea0bf5aae2961be4edbe00c205bed01d293dce3 (diff) |
Profile memory usage in VirtualScheduler and report peak memory usage.
To do so, NodeState now handles different output ports of a node (in case
a node has multiple outputs).
Also, VirtualScheduler code is cleaned up with more comments.
PiperOrigin-RevId: 158209068
-rw-r--r-- | tensorflow/core/grappler/costs/BUILD | 5 | ||||
-rw-r--r-- | tensorflow/core/grappler/costs/op_level_cost_estimator.cc | 8 | ||||
-rw-r--r-- | tensorflow/core/grappler/costs/virtual_scheduler.cc | 425 | ||||
-rw-r--r-- | tensorflow/core/grappler/costs/virtual_scheduler.h | 123 | ||||
-rw-r--r-- | tensorflow/core/grappler/costs/virtual_scheduler_test.cc | 450 | ||||
-rw-r--r-- | tensorflow/core/grappler/op_types.cc | 15 | ||||
-rw-r--r-- | tensorflow/core/grappler/op_types.h | 3 |
7 files changed, 827 insertions, 202 deletions
diff --git a/tensorflow/core/grappler/costs/BUILD b/tensorflow/core/grappler/costs/BUILD index 8455c465df..1f90694331 100644 --- a/tensorflow/core/grappler/costs/BUILD +++ b/tensorflow/core/grappler/costs/BUILD @@ -176,6 +176,7 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:op_types", "//tensorflow/core/grappler:utils", "//tensorflow/core/grappler/clusters:utils", "//tensorflow/core/grappler/costs:cost_estimator", @@ -192,6 +193,10 @@ cc_test( "//tensorflow/core:tensorflow", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:op_types", + "//tensorflow/core/grappler:utils", + "//tensorflow/core/grappler/clusters:utils", "//tensorflow/core/grappler/clusters:virtual_cluster", ], ) diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc index 11a57921e5..75ff75123e 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc @@ -31,6 +31,8 @@ constexpr char kNoOp[] = "NoOp"; constexpr char kReshape[] = "Reshape"; constexpr char kRecv[] = "_Recv"; constexpr char kBatchMatMul[] = "BatchMatMul"; +constexpr char kVariable[] = "Variable"; +constexpr char kVariableV2[] = "VariableV2"; OpLevelCostEstimator::OpLevelCostEstimator() { // Syntactic sugar to build and return a lambda that takes an OpInfo and @@ -53,6 +55,8 @@ OpLevelCostEstimator::OpLevelCostEstimator() { {kNoOp, wrap(&OpLevelCostEstimator::PredictNoOp)}, {kReshape, wrap(&OpLevelCostEstimator::PredictNoOp)}, {kRecv, wrap(&OpLevelCostEstimator::PredictNoOp)}, + {kVariable, wrap(&OpLevelCostEstimator::PredictNoOp)}, + {kVariableV2, wrap(&OpLevelCostEstimator::PredictNoOp)}, {kBatchMatMul, wrap(&OpLevelCostEstimator::PredictBatchMatMul)}}; } @@ -567,7 +571,7 @@ int64 OpLevelCostEstimator::CalculateSingleInputSize( for (const auto& dim : input_shape.dim()) { input_size *= dim.size(); } - return input_size * DataTypeSize(input.dtype()); + return input_size * DataTypeSize(BaseType(input.dtype())); } int64 OpLevelCostEstimator::CalculateInputSize( @@ -589,7 +593,7 @@ int64 OpLevelCostEstimator::CalculateOutputSize( for (const auto& output : op_features.outputs()) { DataType dt = output.dtype(); const auto& original_output_shape = output.shape(); - int64 output_size = DataTypeSize(dt); + int64 output_size = DataTypeSize(BaseType(dt)); int num_dims = std::max(1, original_output_shape.dim_size()); auto output_shape = MaybeGetMinimumShape(original_output_shape, num_dims, found_unknown_shapes); diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc index 8d8d246078..dfa8c768e7 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler.cc +++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/grappler/clusters/utils.h" #include "tensorflow/core/grappler/costs/utils.h" +#include "tensorflow/core/grappler/op_types.h" #include "tensorflow/core/grappler/utils.h" #include "tensorflow/core/util/device_name_utils.h" @@ -55,10 +56,10 @@ VirtualScheduler::VirtualScheduler(const GrapplerItem* grappler_item, const bool use_static_shapes, const string& default_device_type, Cluster* cluster, VirtualPlacer* placer) - : graph_properties_(*grappler_item), - graph_costs_(Costs::ZeroCosts()), - // TODO(dyoon): Use a better way than FIFO. + : // TODO(dyoon): Use a better way than FIFO. ready_nodes_(new FIFOManager()), + graph_costs_(Costs::ZeroCosts()), + graph_properties_(*grappler_item), cluster_(cluster), grappler_item_(grappler_item), use_static_shapes_(use_static_shapes), @@ -68,6 +69,11 @@ VirtualScheduler::VirtualScheduler(const GrapplerItem* grappler_item, } Status VirtualScheduler::Init() { + // Init() preprocesses the input grappler_item and graph_properties to extract + // necessary information for emulating tensorflow op scheduling and + // construct internal data structures (NodeState and DeviceState) for virtual + // scheduling. + // Construct graph properties. Status status; if (use_static_shapes_) { @@ -82,13 +88,12 @@ Status VirtualScheduler::Init() { const auto& graph = grappler_item_->graph; const auto& fetch_nodes = grappler_item_->fetch; - // First, get the nodes that would run to output fetch_nodes. + // Get the nodes that would run to output fetch_nodes. std::vector<const NodeDef*> nodes = ComputeTransitiveFanin(graph, fetch_nodes); // TODO(dyoon): this is a bit inefficient as name_to_node is already built in // ComputeTransitiveFanin(). - // // Once ComputeTransitiveFanin is complete, only the nodes that can be reached // from the fetch nodes are scheduled. So the scheduled nodes should be // exactly the same as those executed for real. One possible discrepancy could @@ -98,61 +103,72 @@ Status VirtualScheduler::Init() { name_to_node[node->name()] = node; } - // Build node_map. + // Build node_map; for each node, create its NodeState and connect its inputs + // and outputs. for (const auto* curr_node : nodes) { auto& curr_node_state = GetNodeStateOrCreateIt(curr_node); const string curr_node_device = DeviceName(curr_node); for (const string& input_node_name : curr_node->input()) { - // Note that input_node_name may be in <node_name>:<output_number> format, - // where ":<output_number>" may be omitted. NodeName() extracts only the - // node_name (prefeix "^", if there was for control input, is also - // deleted). + // Note that input_node_name may be in <prefix><node_name>:<port_num> + // format, where <prefix> (e.g., "^" for control dependency) and + // ":<port_num>" may be omitted. NodeName() extracts only the node_name. const NodeDef* input_node = name_to_node[NodeName(input_node_name)]; + CHECK(input_node); - // Add input_to_curr_node to curr_node's input, and - // add output_to_input_node to input_source_node's output. - // Default values for when input_node and curr_node on the same device. - const NodeDef* input_to_curr_node = input_node; - const NodeDef* input_source_node = input_node; - const NodeDef* output_to_input_node = curr_node; const string in_device = DeviceName(input_node); - if (curr_node_device != in_device) { - if (cached_ops_.count(input_node) > 0 && - cached_ops_[input_node].count(curr_node_device) > 0) { - // Different device, but found an already-transferred copy; connect - // the cached node to curr_node. - input_to_curr_node = cached_ops_[input_node][curr_node_device]; - input_source_node = input_to_curr_node; - output_to_input_node = curr_node; + const auto input_node_port_num = NodePosition(input_node_name); + + if (curr_node_device == in_device) { + // Same device: connect input_node and curr_node directly. + curr_node_state.inputs.push_back( + std::make_pair(input_node, input_node_port_num)); + auto& input_node_state = GetNodeStateOrCreateIt(input_node); + input_node_state.outputs[input_node_port_num].push_back(curr_node); + } else { + if (cached_recv_nodes_.count(input_node) > 0 && + cached_recv_nodes_[input_node].count(curr_node_device) > 0) { + // Different device, but found an already-cached copy (a _Recv op); + // connect the _Recv to curr_node. + const auto* recv_op = + cached_recv_nodes_[input_node][curr_node_device]; + // recv_op's output port is hard-coded to zero. + curr_node_state.inputs.push_back(std::make_pair(recv_op, 0)); + auto& input_node_state = node_map_.at(recv_op); + input_node_state.outputs[0].push_back(curr_node); } else { // Different device, no cached copy; transfer input_node to the // curr_node's device. - auto sendrecv_and_identity = - TransferNode(input_node, curr_node, input_node_name); - const auto* sendrecv = sendrecv_and_identity.first; - const auto* identity = sendrecv_and_identity.second; - input_to_curr_node = identity; - input_source_node = input_node; - output_to_input_node = sendrecv; - - // Cache the identity op for future use. - cached_ops_[input_node][curr_node_device] = identity; + auto send_and_recv = + CreateSendRecv(input_node, curr_node, input_node_name); + // Note that CreateSendRecv() already connected input/output between + // _Send and _Recv ops. + const auto* send = send_and_recv.first; + const auto* recv = send_and_recv.second; + // recv_op's output port is hard-coded to zero. + curr_node_state.inputs.push_back(std::make_pair(recv, 0)); + auto& input_node_state = GetNodeStateOrCreateIt(input_node); + input_node_state.outputs[input_node_port_num].push_back(send); + + // Cache the _Recv op for future use. + cached_recv_nodes_[input_node][curr_node_device] = recv; } } - curr_node_state.inputs.push_back(input_to_curr_node); - - // Note that we do not care output number (in case a tf op has multiple - // outputs), as VirtualScheduler only cares which nodes become ready as - // a node is executed. - auto& input_node_state = GetNodeStateOrCreateIt(input_source_node); - input_node_state.outputs.push_back(output_to_input_node); } if (curr_node->input().empty()) { - curr_node_state.time_ready = - Costs::Duration(); // Node without input: ready at time 0. + // Node without input: ready at time 0. + curr_node_state.time_ready = Costs::Duration(); ready_nodes_->AddNode(curr_node); } + + if (IsPersistentNode(curr_node)) { + auto& device_state = device_[curr_node_device]; + for (int port_num = 0; + port_num < curr_node_state.output_properties.size(); ++port_num) { + device_state.persistent_nodes.insert( + std::make_pair(curr_node, port_num)); + } + } } if (ready_nodes_->Empty()) { @@ -163,18 +179,26 @@ Status VirtualScheduler::Init() { return Status::OK(); } -void VirtualScheduler::MaybeUpdateInputProperties( - const NodeDef* node, std::vector<OpInfo::TensorProperties>* inputs) const { - if (IsSendOp(node) || IsRecvOp(node)) { - // _Send and _Recv ops are inserted from VirtualScheduler, so +void VirtualScheduler::MaybeUpdateInputOutput(const NodeDef* node) { + CHECK(!initialized_) << "MaybeUpdateInputOutput is called after Init()."; + // This method is called when NodeState is created and adds input and output + // properties for a few exceptional cases that GraphProperties cannot provide + // input/output properties. + if (IsSend(*node) || IsRecv(*node)) { + auto& node_state = node_map_[node]; + auto& inputs = node_state.input_properties; + auto& outputs = node_state.output_properties; + + // _Send and _Recv ops are created from VirtualScheduler, so // there should be no inputs TensorProperties. - CHECK_EQ(inputs->size(), 0); + CHECK(inputs.empty()); + CHECK(outputs.empty()); const auto& attr = node->attr(); // This is the original input source to the _Send and _Recv, and this // string includes "^" if it was control dependency, and output port /// (e.g., ":2") if the input source had multiple outputs. const auto& input_source_name = attr.at(kAttrInputSrc).s(); - if (input_source_name[0] == '^') { + if (IsControlInput(input_source_name)) { // Control dependency; regardless of the input source tensor size, // send 4B. OpInfo::TensorProperties control_message; @@ -182,51 +206,53 @@ void VirtualScheduler::MaybeUpdateInputProperties( control_message.mutable_shape()->add_dim()->set_size(1); auto* value = control_message.mutable_value(); value->add_float_val(1); - inputs->push_back(control_message); + inputs.push_back(control_message); + outputs.push_back(control_message); } else { + auto output_properties = + graph_properties_.GetOutputProperties(NodeName(input_source_name)); // Like with HasInputProperties, if a node does not have output // properties, it's likely it was pruned during the shape inference run. - if (graph_properties_.HasOutputProperties(NodeName(input_source_name))) { - const auto input_position = NodePosition(input_source_name); + if (!output_properties.empty()) { + const auto input_node_port_num = NodePosition(input_source_name); // Use the input source's output property as _Send and _Recv's input // property. - auto outputs = - graph_properties_.GetOutputProperties(NodeName(input_source_name)); - CHECK_GT(outputs.size(), input_position); - inputs->push_back(outputs[input_position]); + CHECK_GT(output_properties.size(), input_node_port_num); + inputs.push_back(output_properties[input_node_port_num]); + outputs.push_back(output_properties[input_node_port_num]); } } } } -bool VirtualScheduler::IsSendOp(const NodeDef* node) const { - return node->op() == kSend; +float VirtualScheduler::Round2(const float x) const { + return std::round(100.0 * x) / 100.0; } -bool VirtualScheduler::IsRecvOp(const NodeDef* node) const { - return node->op() == kRecv; +bool VirtualScheduler::IsPersistentNode(const NodeDef* node) const { + // Variables are persistent nodes. + return IsVariable(*node); } string VirtualScheduler::DeviceName(const NodeDef* node) const { + CHECK(!initialized_) << "DeviceName is called after Init()."; + // TODO(dyoon): integrate this part with VirtualPlacer. - if (IsSendOp(node)) { - const auto& node_state = node_map_.at(node); - const auto* from = node_state.inputs[0]; - const auto* to = node_state.outputs[0]; - return ChannelDeviceName(from, to); - } else { - return node->device().empty() ? "/" + default_device_type_ + ":0" - : node->device(); - } + return node->device().empty() ? "/device:" + default_device_type_ + ":0" + : node->device(); } string VirtualScheduler::ChannelDeviceName(const NodeDef* from, const NodeDef* to) const { + CHECK(!initialized_) << "ChannelDeviceName is called after Init()."; + return kChannelDevice + ": " + DeviceName(from) + " to " + DeviceName(to); } -std::pair<const NodeDef*, const NodeDef*> VirtualScheduler::TransferNode( +std::pair<const NodeDef*, const NodeDef*> VirtualScheduler::CreateSendRecv( const NodeDef* from, const NodeDef* to, const string& input_name) { + CHECK(!initialized_) << "CreateSendRecv is called after Init()."; + // Connect "from" node to "to" node with _Send and _Recv such that // from -> _Send -> _Recv -> to. // _Send is placed on "Channel" device, and _Recv is on the same device @@ -238,11 +264,13 @@ std::pair<const NodeDef*, const NodeDef*> VirtualScheduler::TransferNode( // NodeDefs created here need not be correct: in terms of name, // input names, attrs, etc. + auto input_node_port_num = NodePosition(input_name); + // _Send op. auto* send = new NodeDef(); send->set_name("Send " + from->name() + " from " + DeviceName(from) + " to " + DeviceName(to)); - send->set_op(kSend); + send->set_op("_Send"); send->add_input(from->name()); send->set_device(ChannelDeviceName(from, to)); auto& send_attr = *(send->mutable_attr()); @@ -253,19 +281,22 @@ std::pair<const NodeDef*, const NodeDef*> VirtualScheduler::TransferNode( // _Recv op. auto* recv = new NodeDef(); recv->set_name("Recv " + from->name() + " on " + DeviceName(to)); - recv->set_op(kRecv); + recv->set_op("_Recv"); recv->add_input(send->name()); recv->set_device(DeviceName(to)); auto& recv_attr = *(recv->mutable_attr()); recv_attr[kAttrInputSrc].set_s(input_name); - // Update NodeState for _Send and _Recv ops. + // NodeState for _Send op. auto& send_node_state = GetNodeStateOrCreateIt(send); - send_node_state.inputs.push_back(from); - send_node_state.outputs.push_back(recv); + send_node_state.device_name = send->device(); // Set Channel device. + send_node_state.inputs.push_back(std::make_pair(from, input_node_port_num)); + send_node_state.outputs[0].push_back(recv); + + // NodeState for _Recv op. auto& recv_node_state = GetNodeStateOrCreateIt(recv); - recv_node_state.inputs.push_back(send); - recv_node_state.outputs.push_back(to); + recv_node_state.inputs.push_back(std::make_pair(send, 0)); + recv_node_state.outputs[0].push_back(to); // Keep the created nodes. additional_nodes_.emplace_back(std::unique_ptr<NodeDef>(send)); @@ -277,13 +308,8 @@ std::pair<const NodeDef*, const NodeDef*> VirtualScheduler::TransferNode( NodeInfo VirtualScheduler::GetCurrNodeInfo() const { const NodeDef* node = ready_nodes_->GetCurrNode(); - std::vector<OpInfo::TensorProperties> inputs = - graph_properties_.GetInputProperties(node->name()); - // Some ops created within VirtualScheduler may need further processing to - // the input properties. - MaybeUpdateInputProperties(node, &inputs); - // This is for compatibility; we can just use palcer_->get_device() for all + // This is for compatibility; we can just use placer_->get_device() for all // cases, once VirtualCluster is properly set up. DeviceProperties device; if (placer_) { @@ -294,7 +320,8 @@ NodeInfo VirtualScheduler::GetCurrNodeInfo() const { int device_id; DeviceNameUtils::ParsedName parsed; if (!node->device().empty() && - DeviceNameUtils::ParseFullName(DeviceName(node), &parsed)) { + DeviceNameUtils::ParseFullName(node_map_.at(node).device_name, + &parsed)) { device_type = parsed.type; device_id = parsed.id; } else { @@ -309,79 +336,109 @@ NodeInfo VirtualScheduler::GetCurrNodeInfo() const { } // Special case for _Send op. - if (IsSendOp(node)) { + if (IsSend(*node)) { device.set_type(kChannelDevice); } + // Construct NodeInfo. + const auto& node_state = node_map_.at(node); NodeInfo node_info; node_info.name = node->name(); - node_info.device_name = graph_properties_.GetDeviceName(node->name()); - std::vector<OpInfo::TensorProperties> outputs = - graph_properties_.GetOutputProperties(node->name()); + node_info.device_name = node_state.device_name; auto& op_info = node_info.op_info; op_info.set_op(node->op()); *op_info.mutable_attr() = node->attr(); - for (auto& input : inputs) { - op_info.add_inputs()->Swap(&input); + for (auto& input : node_state.input_properties) { + *op_info.add_inputs() = input; } - for (auto& output : outputs) { - op_info.add_outputs()->Swap(&output); + for (auto& output : node_state.output_properties) { + *op_info.add_outputs() = output; } op_info.mutable_device()->Swap(&device); - // add some more to the node_info. return node_info; } NodeState& VirtualScheduler::GetNodeStateOrCreateIt(const NodeDef* node) { + CHECK(!initialized_) << "GetNodeStateOrCreateIt is called after Init()."; + auto it = node_map_.find(node); if (it == node_map_.end()) { + // Not found; create a NodeState for this node. it = node_map_.emplace(node, NodeState()).first; - } - return it->second; -} + auto& node_state = it->second; + node_state.input_properties = + graph_properties_.GetInputProperties(node->name()); + node_state.output_properties = + graph_properties_.GetOutputProperties(node->name()); + + // Some ops may need further processing to the input / output properties: + // _Send and _Recv. + MaybeUpdateInputOutput(node); + + if (!IsSend(*node)) { + node_state.device_name = DeviceName(node); + // For _Send op, device_name will be set to Channel in CreateSendRecv(). + } -Costs& VirtualScheduler::FindOrCreateZero(const string& op_name, - std::map<string, Costs>* op_cost) { - auto it = op_cost->find(op_name); - if (it == op_cost->end()) { - it = op_cost->emplace(op_name, Costs::ZeroCosts()).first; + // Initialize output port related data: + // Assume the size of OutputProperties represents the number of output ports + // of this node. + for (int i = 0; i < node_state.output_properties.size(); ++i) { + node_state.time_no_references[i] = Costs::Duration::max(); + node_state.num_outputs_executed[i] = 0; + // Populate an empty vector for each port. The caller will add nodes + // that use this port as input. + node_state.outputs[i] = {}; + } + // Port_num -1 is for control dependency. + node_state.time_no_references[-1] = Costs::Duration::max(); + node_state.num_outputs_executed[-1] = 0; + node_state.outputs[-1] = {}; } return it->second; } -bool VirtualScheduler::PopCurrNode() { - const auto* node = ready_nodes_->GetCurrNode(); - auto& node_state = node_map_[node]; - auto& device = device_[DeviceName(node)]; - auto curr_time = device.GetCurrTime(); +int64 VirtualScheduler::CalculateOutputSize( + const std::vector<OpInfo::TensorProperties>& output_properties, + const int port_num) const { + if (port_num < 0) { + return 4; // 4B for control dependency. + } - // Increment num_inputs_ready of the output nodes. - for (auto* output : node_state.outputs) { - auto& output_state = node_map_[output]; - output_state.num_inputs_ready++; - if (output_state.num_inputs_ready == output_state.inputs.size()) { - // This output node is now ready. - output_state.time_ready = curr_time; - ready_nodes_->AddNode(output); - } + if (port_num >= output_properties.size()) { + VLOG(3) << "VirtualScheduler::CalculateOutputSize() -- " + << "port_num: " << port_num + << " >= output_properties.size(): " << output_properties.size(); + return 0; } - // Increment num_outputs_executed of the input nodes. - for (auto* input : node_state.inputs) { - auto& input_state = node_map_[input]; - input_state.num_outputs_executed++; - if (input_state.num_outputs_executed == input_state.outputs.size()) { - // All the outputs are executed; no reference to this input nodel - input_state.time_no_reference = curr_time; - // TODO(dyoon): collect device memory usage; note that this input node - // use device memory between time_scheduled and time_no_reference. + const auto& output = output_properties[port_num]; + int64 output_size = DataTypeSize(BaseType(output.dtype())); + + for (const auto& dim : output.shape().dim()) { + auto dim_size = dim.size(); + if (dim_size < 0) { + // Zero output size if there's any unknown dim. + output_size = 0; + VLOG(3) << "VirtualScheduler::CalculateOutputSize() -- " + << "unknown dim: " << output_size; + break; } + output_size *= dim_size; } - // Remove the current node; assume FIFO. - ready_nodes_->RemoveCurrNode(); + return output_size; +} - return !ready_nodes_->Empty(); +Costs& VirtualScheduler::FindOrCreateZero(const string& op_name, + std::map<string, Costs>* op_cost) { + auto it = op_cost->find(op_name); + if (it == op_cost->end()) { + // Note that default constructor of Costs sets some memory related fields + // to unknown values so we should explicitly initialize it with ZeroCosts. + it = op_cost->emplace(op_name, Costs::ZeroCosts()).first; + } + return it->second; } bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) { @@ -402,7 +459,7 @@ bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) { // Update node and device states. auto& node_state = node_map_[node]; - auto& device = device_[DeviceName(node)]; + auto& device = device_[node_state.device_name]; device.nodes_executed.push_back(node); // Node is scheduled when the device is available AND all the inputs are // ready; hence, time_scheduled is time_ready if time_ready > device curr @@ -415,6 +472,21 @@ bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) { auto curr_time = device.GetCurrTime(); node_state.time_finished = curr_time; + // Update device memory usage. + if (!IsPersistentNode(node)) { + for (const auto& port_num_output_pair : node_state.outputs) { + int port_num = port_num_output_pair.first; + // There's a chance that a specific output is not used at all. + if (node_state.outputs[port_num].empty()) { + node_state.time_no_references[port_num] = curr_time; + } else { + device.memory_usage += + CalculateOutputSize(node_state.output_properties, port_num); + device.nodes_in_memory.insert(std::make_pair(node, port_num)); + } + } + } + // Update device's per-op cost. auto& device_op_cost = FindOrCreateZero(op_name, &device.op_to_cost); device_op_cost = CombineCosts(device_op_cost, node_costs); @@ -425,7 +497,52 @@ bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) { << ", scheduled: " << node_state.time_scheduled.count() << ", finished: " << node_state.time_finished.count(); - return PopCurrNode(); + // Increment num_inputs_ready of the output nodes + for (const auto& port_num_output_pair : node_state.outputs) { + for (auto* output_node : port_num_output_pair.second) { + auto& output_state = node_map_[output_node]; + output_state.num_inputs_ready++; + if (output_state.num_inputs_ready == output_state.inputs.size()) { + // This output node is now ready. + output_state.time_ready = curr_time; + ready_nodes_->AddNode(output_node); + } + } + } + + // Increment num_outputs_executed of the input nodes. + for (const auto& input_port : node_state.inputs) { + auto* input = input_port.first; + auto port = input_port.second; + auto& input_state = node_map_[input]; + input_state.num_outputs_executed[port]++; + if (input_state.num_outputs_executed[port] == + input_state.outputs[port].size() && + !IsPersistentNode(input)) { + // All the outputs are executed; no reference to this output port of + // input node. + input_state.time_no_references[port] = curr_time; + auto& input_device = device_[input_state.device_name]; + input_device.memory_usage -= + CalculateOutputSize(input_state.output_properties, port); + + input_device.nodes_in_memory.erase(std::make_pair(input, port)); + } + } + + if (!IsPersistentNode(node)) { + // Now that output memory is added and used up nodes are deallocated, + // check max memory usage. + if (device.memory_usage > device.max_memory_usage) { + device.max_memory_usage = device.memory_usage; + device.mem_usage_snapshot_at_peak = device.nodes_in_memory; + } + } + + // Remove the current node; assume FIFO. + ready_nodes_->RemoveCurrNode(); + + return !ready_nodes_->Empty(); } Costs VirtualScheduler::Summary() const { @@ -452,17 +569,59 @@ Costs VirtualScheduler::Summary() const { for (const auto& device : device_) { const auto& name = device.first; const auto& state = device.second; + + std::map<string, int64> op_to_memory; + // First profile only persistent memory usage. + int64 persistent_memory_usage = 0; + std::set<string> persisent_ops; + for (const auto& node_port : state.persistent_nodes) { + const auto* node = node_port.first; + const auto port = node_port.second; + const auto output_size = + CalculateOutputSize(node_map_.at(node).output_properties, port); + persistent_memory_usage += output_size; + op_to_memory[node->op()] += output_size; + persisent_ops.insert(node->op()); + } + int64 max_memory_usage = persistent_memory_usage + state.max_memory_usage; + VLOG(1) << "Device = " << name << ", num_nodes = " << state.nodes_executed.size() - << ", execution_time = " << state.GetCurrTime().count(); - VLOG(1) << "Per-op execution time:"; + << ", execution_time = " << state.GetCurrTime().count() + << ", memory usage: " + << "persistenst = " + << Round2(persistent_memory_usage / 1024.0 / 1024.0 / 1024.0) + << " GB, peak = " + << Round2(state.max_memory_usage / 1024.0 / 1024.0 / 1024.0) + << " GB, total = " + << Round2(max_memory_usage / 1024.0 / 1024.0 / 1024.0) + << " GB, at the end: " << state.memory_usage << " B"; + + VLOG(1) << "Per-op execution time (and memory usage at peak memory usage):"; + // Profile non-persistent op memory usage. + for (const auto& node_port : state.mem_usage_snapshot_at_peak) { + const auto* node = node_port.first; + const auto port = node_port.second; + op_to_memory[node->op()] += + CalculateOutputSize(node_map_.at(node).output_properties, port); + } for (const auto& op_cost_pair : state.op_to_cost) { const auto& op = op_cost_pair.first; const auto& cost = op_cost_pair.second.execution_time.count(); - if (cost) { // Skip printing out zero-cost ops. - VLOG(1) << " + " << op << " : " << cost; + const float mem_usage_gb = + Round2(op_to_memory[op] / 1024.0 / 1024.0 / 1024.0); + int64 op_mem_usage = op_to_memory.at(op); + const float mem_usage_percent = + max_memory_usage > 0 ? Round2(100.0 * op_mem_usage / max_memory_usage) + : 0.0; + if (cost || mem_usage_percent > 1.0) { + // Print out only non-zero cost ops or ops with > 1% memory usage. + VLOG(1) << " + " << op << " : " << cost << " (" << mem_usage_gb + << " GB [" << mem_usage_percent << "%] " + << (persisent_ops.count(op) > 0 ? ": persistent op)" : ")"); } } + VLOG(1) << "Per-op execution time (and memory usage at peak memory usage):"; if (critical_path_costs.execution_time <= state.GetCurrTime()) { critical_path_costs = state.device_costs; } diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.h b/tensorflow/core/grappler/costs/virtual_scheduler.h index 7764bdc478..af5434efe3 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler.h +++ b/tensorflow/core/grappler/costs/virtual_scheduler.h @@ -29,36 +29,79 @@ namespace tensorflow { namespace grappler { struct NodeState { - std::vector<const NodeDef*> inputs; - std::vector<const NodeDef*> outputs; + // A node (i.e., an op) takes a set of input:port pairs and produces + // a set of output ports. + + // Cross references to input and output nodes from graphdef. + std::vector<std::pair<const NodeDef*, int>> inputs; // Input, port pairs. + // List of output nodes (a list of nodes that takes this output port as input) + // keyed by port_num. Note that port_num -1 is used for control dependency. + std::unordered_map<int, std::vector<const NodeDef*>> outputs; + + // Info from GraphProperties. + std::vector<OpInfo::TensorProperties> input_properties; + std::vector<OpInfo::TensorProperties> output_properties; + + // Canonical device name used within VirtualScheduler. + string device_name; + + // States updated as scheduling nodes. int num_inputs_ready; - int num_outputs_executed; + std::unordered_map<int, int> num_outputs_executed; Costs::Duration time_ready; Costs::Duration time_scheduled; Costs::Duration time_finished; - Costs::Duration time_no_reference; + // Time that all the consumers are executed (hence, no need to keep this + // output in memory), keyed by port_num. + std::unordered_map<int, Costs::Duration> time_no_references; + + // Note that a node may have multiple output ports. The length of outputs, + // num_outputs_executed, and time_no_references should be + // identical when a NodeState is fully initialized. + // They should be 1 + output_properties.size() as we add [-1] for control + // dependency. // Node will be ready to be executed at time_ready, scheduled at // time_scheduled, and finishes execution at time_finished. - // Between time_scheduled and time_no_reference, the node's output tensor - // needs to be on the device, using up device memory. + // Each output port uses up memory space from time_scheduled to its + // time_no_references. NodeState() { num_inputs_ready = 0; - num_outputs_executed = 0; time_ready = Costs::Duration::max(); time_scheduled = Costs::Duration::max(); time_finished = Costs::Duration::max(); - time_no_reference = Costs::Duration::max(); + // Note that num_outputs_executed and time_no_references are not initialized + // here, since we don't know the size (i.e., # outputs for this node). } }; struct DeviceState { + // Nodes executed on this device in execution order. std::vector<const NodeDef*> nodes_executed; - Costs device_costs; - std::map<string, Costs> op_to_cost; // Per-op cost. - DeviceState() { device_costs = Costs::ZeroCosts(); } + // Nodes currently allocated in memory: set of NodeDef* and port_num pairs + // so that we can track which output of the node is in memory. + std::set<std::pair<const NodeDef*, int>> nodes_in_memory; + + // Nodes allocated in memory persistently: e.g., Variables. + std::set<std::pair<const NodeDef*, int>> persistent_nodes; + + // Snapshot of nodes_in_memory, when memory usage is at peak. + // Same to nodes_in_memory, it's a set of NodeDef* and port_num pairs. + std::set<std::pair<const NodeDef*, int>> mem_usage_snapshot_at_peak; + + Costs device_costs; + std::map<string, Costs> op_to_cost; // Per-op cost. + std::map<string, int64> op_to_memory; // Per-op memory usage at peak usage. + int64 memory_usage; + int64 max_memory_usage; + + DeviceState() { + device_costs = Costs::ZeroCosts(); + memory_usage = 0; + max_memory_usage = 0; + } Costs::Duration GetCurrTime() const { return device_costs.execution_time; } }; @@ -106,48 +149,74 @@ class VirtualScheduler { const string& default_device_type, Cluster* cluster, VirtualPlacer* placer); + // Initializes NodeState and DeviceState from grappler_item_ and + // graph_properties_. Status Init(); NodeInfo GetCurrNodeInfo() const; + + // Returns true if there is any node to be scheduled. bool MarkCurrNodeExecuted(const Costs& node_costs); + // Prints out summary of execution (timing, memory usage, etc.) Costs Summary() const; + protected: + // GetDeviceStates and GetNodeStates are currently for testing purpuse only. + // Retrieves detailed scheduling results. + const std::unordered_map<string, DeviceState>& GetDeviceStates() const { + return device_; + } + const std::unordered_map<const NodeDef*, NodeState>& GetNodeStates() const { + return node_map_; + } + + // Returns the size of output at port_num (unit: bytes). A special case is + // port_num -1, which is for control dependency and assumed to be 4 bytes. + int64 CalculateOutputSize( + const std::vector<OpInfo::TensorProperties>& output_properties, + const int port_num) const; + private: - const string kSend = "_Send"; - const string kRecv = "_Recv"; + // Constants. const string kAttrInputSrc = "input_source_"; const string kAttrSrcDevice = "src_device_"; const string kAttrDstDevice = "dst_device_"; const string kChannelDevice = "Channel"; - void MaybeUpdateInputProperties( - const NodeDef* node, std::vector<OpInfo::TensorProperties>* inputs) const; + // Methods called from Init(). Fails if initialize_ is set. + void MaybeUpdateInputOutput(const NodeDef* node); NodeState& GetNodeStateOrCreateIt(const NodeDef* node); - std::pair<const NodeDef*, const NodeDef*> TransferNode( + std::pair<const NodeDef*, const NodeDef*> CreateSendRecv( const NodeDef* from, const NodeDef* to, const string& input_name); string DeviceName(const NodeDef* node) const; string ChannelDeviceName(const NodeDef* from, const NodeDef* to) const; + + // Helper methods. Costs& FindOrCreateZero(const string& op_name, std::map<string, Costs>* op_cost); + float Round2(const float x) const; + bool IsPersistentNode(const NodeDef* node) const; - bool PopCurrNode(); - bool IsSendOp(const NodeDef* node) const; - bool IsRecvOp(const NodeDef* node) const; - - GraphProperties graph_properties_; - std::map<string, int> op_counts_; // Op counts with key with input shape. - std::map<string, int> op_costs_; // Individual op costs (with input shapes). - Costs graph_costs_; // Graph cost. - std::map<string, Costs> op_to_cost_; // Per-op cost. + // Scheduler states: std::unique_ptr<ReadyNodeManager> ready_nodes_; std::unordered_map<const NodeDef*, NodeState> node_map_; std::unordered_map<string, DeviceState> device_; + // Pool of NodeDefs for SendRecv and Identity ops created. std::vector<std::unique_ptr<NodeDef>> additional_nodes_; - // Cache of ops transferred to another device. + // Cache of nodes transferred to another device. std::unordered_map<const NodeDef*, std::unordered_map<string, const NodeDef*>> - cached_ops_; + cached_recv_nodes_; + + // Stats: + std::map<string, int> op_counts_; // Op counts with key with input shape. + std::map<string, int> op_costs_; // Individual op costs (with input shapes). + Costs graph_costs_; // Graph cost. + std::map<string, Costs> op_to_cost_; // Per-op cost. + + // Auxilliary data structures for constructing NodeState and DeviceState. + GraphProperties graph_properties_; Cluster* cluster_; // Not owned. const GrapplerItem* grappler_item_; // Not owned. bool use_static_shapes_; diff --git a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc index dad2104b75..10e181e826 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc +++ b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc @@ -23,42 +23,49 @@ limitations under the License. namespace tensorflow { namespace grappler { +// Class for testing virtual scheduler. +class TestVirtualScheduler : public VirtualScheduler { + public: + TestVirtualScheduler(const GrapplerItem* grappler_item, + const bool use_static_shapes, + const string& default_device_type, Cluster* cluster, + VirtualPlacer* placer) + : VirtualScheduler(grappler_item, use_static_shapes, default_device_type, + cluster, placer) {} + + FRIEND_TEST(VirtualSchedulerTest, CalculateOutputSize); + FRIEND_TEST(VirtualSchedulerTest, MemoryUsage); + FRIEND_TEST(VirtualSchedulerTest, ControlDependency); + FRIEND_TEST(VirtualSchedulerTest, ComplexDependency); + FRIEND_TEST(VirtualSchedulerTest, Variable); +}; class VirtualSchedulerTest : public ::testing::Test { protected: + const string kCPU0 = "/job:localhost/replica:0/task:0/cpu:0"; + void SetUp() override { // Initializes cluster_ and placer_. std::unordered_map<string, DeviceProperties> devices; DeviceProperties cpu_device; cpu_device.set_type("CPU"); - devices["/job:localhost/replica:0/task:0/cpu:0"] = cpu_device; - DeviceProperties gpu_device; - gpu_device.set_type("GPU"); - devices["/job:localhost/replica:0/task:0/gpu:0"] = gpu_device; + devices[kCPU0] = cpu_device; cluster_.reset(new VirtualCluster(devices)); placer_.reset(new VirtualPlacer(cluster_.get())); } - void CreateSchedulerWithConv2Ds() { - // Create a scheduler with a simple graph: 3 Conv2Ds, where only 2 are in - // fetch nodes. - const int bs = 4; - const int width = 10; - const int height = 10; - const int depth_in = 8; - const int kernel = 3; - const int depth_out = 16; - - tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + // Three Conv2Ds with only two in fetch nodes. + void CreateGrapplerItemWithConv2Ds() { + tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice(kCPU0); auto x = tensorflow::ops::RandomUniform( - s.WithOpName("x"), {bs, width, height, depth_in}, DT_FLOAT); + s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT); auto y = tensorflow::ops::RandomUniform( - s.WithOpName("y"), {bs, width, height, depth_in}, DT_FLOAT); + s.WithOpName("y"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT); auto z = tensorflow::ops::RandomUniform( - s.WithOpName("z"), {bs, width, height, depth_in}, DT_FLOAT); + s.WithOpName("z"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT); auto f = tensorflow::ops::RandomUniform( - s.WithOpName("f"), {kernel, kernel, depth_in, depth_out}, DT_FLOAT); + s.WithOpName("f"), {kernel_, kernel_, depth_in_, depth_out_}, DT_FLOAT); std::vector<int> strides = {1, 1, 1, 1}; auto c0 = tensorflow::ops::Conv2D(s.WithOpName("c0"), x, f, strides, "SAME"); @@ -68,47 +75,253 @@ class VirtualSchedulerTest : public ::testing::Test { tensorflow::ops::Conv2D(s.WithOpName("c2"), z, f, strides, "SAME"); GraphDef def; TF_CHECK_OK(s.ToGraphDef(&def)); - LOG(INFO) << def.DebugString(); grappler_item_.reset(new GrapplerItem); grappler_item_->id = "test_conv2d_graph"; grappler_item_->graph = def; grappler_item_->fetch = {"c0", "c1"}; - scheduler_.reset(new VirtualScheduler( + dependency_["c0"] = {"x", "f"}; + dependency_["c1"] = {"y", "f"}; + } + + // A Conv2D with a variable. + void CreateGrapplerItemWithConv2DAndVariable() { + tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice(kCPU0); + auto x = tensorflow::ops::RandomUniform( + s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT); + auto f = tensorflow::ops::Variable( + s.WithOpName("f"), {kernel_, kernel_, depth_in_, depth_out_}, DT_FLOAT); + std::vector<int> strides = {1, 1, 1, 1}; + auto y = tensorflow::ops::Conv2D(s.WithOpName("y"), x, f, strides, "SAME"); + GraphDef def; + TF_CHECK_OK(s.ToGraphDef(&def)); + + grappler_item_.reset(new GrapplerItem); + grappler_item_->id = "test_conv2d_var_graph"; + grappler_item_->graph = def; + grappler_item_->fetch = {"y"}; + + dependency_["y"] = {"x", "f"}; + } + + // AddN that takes 4 tensors with 10x10x10x10. + void CreateGrapplerItemWithAddN() { + tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice(kCPU0); + auto x = tensorflow::ops::RandomUniform(s.WithOpName("x"), {10, 10, 10, 10}, + DT_FLOAT); + auto y = tensorflow::ops::RandomUniform(s.WithOpName("y"), {10, 10, 10, 10}, + DT_FLOAT); + auto z = tensorflow::ops::RandomUniform(s.WithOpName("z"), {10, 10, 10, 10}, + DT_FLOAT); + auto w = tensorflow::ops::RandomUniform(s.WithOpName("w"), {10, 10, 10, 10}, + DT_FLOAT); + tensorflow::OutputList input_tensors = {x, y, z, w}; + auto out = tensorflow::ops::AddN(s.WithOpName("out"), input_tensors); + GraphDef def; + TF_CHECK_OK(s.ToGraphDef(&def)); + + grappler_item_.reset(new GrapplerItem); + grappler_item_->id = "test_addn_graph"; + grappler_item_->graph = def; + grappler_item_->fetch = {"out"}; + + dependency_["out"] = {"x", "y", "z", "w"}; + } + + // NoOp that takes 7 NoOps as control dependency. + void CreateGrapplerItemWithControlDependency() { + tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice(kCPU0); + std::vector<string> input_noop_names = {"x", "y", "z", "w", "u", "v", "t"}; + std::vector<tensorflow::Operation> input_tensors; + for (const auto& input : input_noop_names) { + auto x = tensorflow::ops::NoOp(s.WithOpName(input)); + input_tensors.push_back(x.operation); + } + auto out = tensorflow::ops::NoOp( + s.WithControlDependencies(input_tensors).WithOpName("out")); + GraphDef def; + TF_CHECK_OK(s.ToGraphDef(&def)); + + grappler_item_.reset(new GrapplerItem); + grappler_item_->id = "test_control_dependency_graph"; + grappler_item_->graph = def; + grappler_item_->fetch = {"out"}; + + dependency_["out"] = input_noop_names; + } + + // FusedBN [an op with multiple outputs] with multiple consumers (including + // control dependency). + void CreateGrapplerItemWithBatchNorm() { + tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice(kCPU0); + auto x = tensorflow::ops::RandomUniform( + s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT); + auto scale = tensorflow::ops::RandomUniform(s.WithOpName("scale"), + {depth_in_}, DT_FLOAT); + auto offset = tensorflow::ops::RandomUniform(s.WithOpName("offset"), + {depth_in_}, DT_FLOAT); + auto mean = + tensorflow::ops::RandomUniform(s.WithOpName("mean"), {0}, DT_FLOAT); + auto var = + tensorflow::ops::RandomUniform(s.WithOpName("var"), {0}, DT_FLOAT); + + auto batch_norm = tensorflow::ops::FusedBatchNorm( + s.WithOpName("bn"), x, scale, offset, mean, var, + ops::FusedBatchNorm::IsTraining(true).Epsilon(0.1f)); + auto y = batch_norm.y; + auto batch_mean = batch_norm.batch_mean; + auto batch_var = batch_norm.batch_variance; + + auto z1 = tensorflow::ops::Add(s.WithOpName("z1"), x, y); + auto z2 = tensorflow::ops::Add(s.WithOpName("z2"), batch_var, batch_var); + auto z3 = tensorflow::ops::Add(s.WithOpName("z3"), batch_var, batch_var); + std::vector<tensorflow::Operation> input_tensors = { + batch_mean.op(), z1.z.op(), z2.z.op(), z3.z.op(), + }; + auto z4 = tensorflow::ops::NoOp( + s.WithControlDependencies(batch_var).WithOpName("z4")); + + GraphDef def; + TF_CHECK_OK(s.ToGraphDef(&def)); + + grappler_item_.reset(new GrapplerItem); + grappler_item_->id = "test_complex_dependency_graph"; + grappler_item_->graph = def; + grappler_item_->fetch = {"z1", "z2", "z3", "z4"}; + + dependency_["bn"] = {"x", "scale", "offset", "mean", "var"}; + dependency_["z1"] = {"x", "bn"}; + dependency_["z2"] = {"bn"}; + dependency_["z3"] = {"bn"}; + dependency_["z4"] = {"bn"}; + } + + // Call this after creating grappler_item_ and setting up dependency_. + void InitScheduler() { + scheduler_.reset(new TestVirtualScheduler( grappler_item_.get(), true /* use_static_shapes */, "CPU" /* default_device_type */, cluster_.get(), placer_.get())); TF_CHECK_OK(scheduler_->Init()); } + // Call this after init scheduler_. Scheduler stops after executing + // target_node. + std::unordered_map<string, NodeInfo> RunScheduler(const string& target_node) { + Costs zero_costs = Costs::ZeroCosts(); + std::unordered_map<string, NodeInfo> ops_executed; + bool more_nodes = true; + do { + NodeInfo node_info = scheduler_->GetCurrNodeInfo(); + ops_executed[node_info.name] = node_info; + + // Check scheduling order. + auto it = dependency_.find(node_info.name); + if (it != dependency_.end()) { + for (const auto& preceding_node : it->second) { + EXPECT_GT(ops_executed.count(preceding_node), 0); + } + } + more_nodes = scheduler_->MarkCurrNodeExecuted(zero_costs); + + if (node_info.name == target_node) { + // Scheduler has the state after executing the target node. + break; + } + } while (more_nodes); + return ops_executed; + } + + // Helper method for validating a vector. + template <typename T> + void ExpectVectorEq(const std::vector<T>& expected, + const std::vector<T>& test_elements) { + // Set of expected elements for an easy comparison. + std::set<T> expected_set(expected.begin(), expected.end()); + for (const auto& element : test_elements) { + EXPECT_GT(expected_set.count(element), 0); + } + EXPECT_EQ(expected.size(), test_elements.size()); + } + + // Helper method that checks the name of nodes. + void ValidateNodeDefs(const std::vector<string>& expected, + const std::vector<const NodeDef*>& node_defs) { + std::vector<string> node_names; + std::transform(node_defs.begin(), node_defs.end(), + std::back_inserter(node_names), + [](const NodeDef* node) { return node->name(); }); + ExpectVectorEq(expected, node_names); + } + + // Helper method for validating a set. + template <typename T> + void ExpectSetEq(const std::set<T>& expected, + const std::set<T>& test_elements) { + for (const auto& element : test_elements) { + EXPECT_GT(expected.count(element), 0); + } + EXPECT_EQ(expected.size(), test_elements.size()); + } + + // Helper method tthat checks name - port pairs. + void ValidateMemoryUsageSnapshot( + const std::vector<string>& expected_names, const int port_num_expected, + const std::set<std::pair<const NodeDef*, int>>& mem_usage_snapshot) { + std::set<std::pair<string, int>> nodes_at_peak_mem_usage; + std::transform( + mem_usage_snapshot.begin(), mem_usage_snapshot.end(), + std::inserter(nodes_at_peak_mem_usage, nodes_at_peak_mem_usage.begin()), + [](const std::pair<const NodeDef*, int>& node_port) { + return std::make_pair(node_port.first->name(), node_port.second); + }); + std::set<std::pair<string, int>> expected; + std::transform(expected_names.begin(), expected_names.end(), + std::inserter(expected, expected.begin()), + [port_num_expected](const string& name) { + return std::make_pair(name, port_num_expected); + }); + ExpectSetEq(expected, nodes_at_peak_mem_usage); + } + + // Helper method for converting shape vector to TensorProperty. + OpInfo::TensorProperties ShapeToTensorProperty( + const std::vector<int> shape, const DataType& data_type) const { + OpInfo::TensorProperties tensor_property; + tensor_property.set_dtype(data_type); + for (const auto& x : shape) { + tensor_property.mutable_shape()->add_dim()->set_size(x); + } + return tensor_property; + } + // SetUp() inits cluster_ and placer_. std::unique_ptr<VirtualCluster> cluster_; std::unique_ptr<VirtualPlacer> placer_; // grappler_item_ and scheduler_ will be initialized differently for each test - // case + // case. std::unique_ptr<GrapplerItem> grappler_item_; - std::unique_ptr<VirtualScheduler> scheduler_; + std::unique_ptr<TestVirtualScheduler> scheduler_; + // Node name -> its preceding nodes map for testing scheduling order. + std::unordered_map<string, std::vector<string>> dependency_; + + // Shared params for Conv2D related graphs: + const int batch_size_ = 4; + const int width_ = 10; + const int height_ = 10; + const int depth_in_ = 8; + const int kernel_ = 3; + const int depth_out_ = 16; }; TEST_F(VirtualSchedulerTest, InitAndBasicScheduling) { - CreateSchedulerWithConv2Ds(); // init scheduler_. - - Costs zero_costs = Costs::ZeroCosts(); - std::unordered_map<string, NodeInfo> ops_executed; - do { - NodeInfo node_info = scheduler_->GetCurrNodeInfo(); - ops_executed[node_info.name] = node_info; - - // Check scheduling order: x and f before c0, and y and f before c1. - if (node_info.name == "c0") { - EXPECT_GT(ops_executed.count("x"), 0); - EXPECT_GT(ops_executed.count("f"), 0); - } else if (node_info.name == "c1") { - EXPECT_GT(ops_executed.count("y"), 0); - EXPECT_GT(ops_executed.count("f"), 0); - } - } while (scheduler_->MarkCurrNodeExecuted(zero_costs)); + // Init. + CreateGrapplerItemWithConv2Ds(); + InitScheduler(); + + // Run the scheduler. + auto ops_executed = RunScheduler(""); // Run all the nodes. // [const and rand] * (x, y, f), and c0 and c1. c2 and z shouldn't be // executed. @@ -132,5 +345,162 @@ TEST_F(VirtualSchedulerTest, InitAndBasicScheduling) { EXPECT_EQ(2, ops_executed["c0"].op_info.inputs_size()); EXPECT_EQ(2, ops_executed["c1"].op_info.inputs_size()); } + +TEST_F(VirtualSchedulerTest, CalculateOutputSize) { + // Init. + CreateGrapplerItemWithAddN(); + InitScheduler(); + + // Create a set of tensor properties. + std::vector<OpInfo::TensorProperties> output; + output.push_back(ShapeToTensorProperty({4, 4}, DT_FLOAT)); // 0 + output.push_back(ShapeToTensorProperty({1}, DT_FLOAT)); // 1 + output.push_back(ShapeToTensorProperty({10, 10, 10}, DT_HALF)); // 2 + output.push_back(ShapeToTensorProperty({100, 7, 8, 99}, DT_FLOAT)); // 3 + output.push_back(ShapeToTensorProperty({-1, 7, 8, 99}, DT_FLOAT)); // 4 + output.push_back(ShapeToTensorProperty({-1, 7, -1, 99}, DT_FLOAT)); // 4 + + // port_num -1 is for control dependency: hard coded 4B. + EXPECT_EQ(4, scheduler_->CalculateOutputSize(output, -1)); + + // Test valid outputs. + EXPECT_EQ(4 * 4 * 4, scheduler_->CalculateOutputSize(output, 0)); + EXPECT_EQ(4 * 1, scheduler_->CalculateOutputSize(output, 1)); + EXPECT_EQ(2 * 10 * 10 * 10, scheduler_->CalculateOutputSize(output, 2)); + EXPECT_EQ(4 * 100 * 7 * 8 * 99, scheduler_->CalculateOutputSize(output, 3)); + + // Any uknown shape (-1) shall yield zero output size. + EXPECT_EQ(0, scheduler_->CalculateOutputSize(output, 4)); + EXPECT_EQ(0, scheduler_->CalculateOutputSize(output, 5)); + + // Invalid port_num (though it may be an error) shall yield zero + // output size. + EXPECT_EQ(0, scheduler_->CalculateOutputSize(output, 6)); +} + +TEST_F(VirtualSchedulerTest, MemoryUsage) { + // Init. + CreateGrapplerItemWithAddN(); + InitScheduler(); + + // Run the scheduler. + RunScheduler(""); + + const auto& device_states = scheduler_->GetDeviceStates(); + const auto& cpu_state = device_states.at(kCPU0); + + // out node adds 4 tensors, each with 10x10x10x10, so the peak memory usage + // is 4 x the input tensor size while executing the out node. + int64 one_input_node_size = 4 * 10 * 10 * 10 * 10; + const std::vector<string> expected_names = {"x", "y", "z", "w"}; + EXPECT_EQ(expected_names.size() * one_input_node_size, + cpu_state.max_memory_usage); + ValidateMemoryUsageSnapshot(expected_names, 0 /* port_num_expected */, + cpu_state.mem_usage_snapshot_at_peak); +} + +TEST_F(VirtualSchedulerTest, ControlDependency) { + // Init. + CreateGrapplerItemWithControlDependency(); + InitScheduler(); + + // Run the scheduler. + RunScheduler(""); + + const auto& device_states = scheduler_->GetDeviceStates(); + const auto& cpu_state = device_states.at(kCPU0); + + // The graph has a NoOp that takes control dependency from 7 NoOps. The peak + // memory usage is when executing the final NoOp. + int64 one_input_node_size = 4; // control dependency + const std::vector<string> expected_names = {"x", "y", "z", "w", + "u", "v", "t"}; + EXPECT_EQ(expected_names.size() * one_input_node_size, + cpu_state.max_memory_usage); + ValidateMemoryUsageSnapshot(expected_names, -1 /* port_num_expected */, + cpu_state.mem_usage_snapshot_at_peak); +} + +TEST_F(VirtualSchedulerTest, ComplexDependency) { + // Init. + CreateGrapplerItemWithBatchNorm(); + InitScheduler(); + + // Run the scheduler. + RunScheduler("bn"); + + const auto& device_states = scheduler_->GetDeviceStates(); + const auto& cpu_state = device_states.at(kCPU0); + + // The graph is + // bn = FusedBatchNorm(x, scale, offset, mean, var) + // z1 = bn.y + x + // z2 = bn.var + bn.var + // z3 = bn.var + bn.var + // z4 = control dependency from bn. + // Note that bn.mean doesn't have any consumer. + const int x_size = batch_size_ * width_ * height_ * depth_in_; + int64 expected_size = + 4 * (2 * x_size /* x and bn.y */ + depth_in_ /* bn.var */ + + 1 /* control dependency */); + EXPECT_EQ(expected_size, cpu_state.memory_usage); + + // Nodes currrently in memory: bn's port -1, 0, and 2, and x's port 0. + std::set<std::pair<string, int>> nodes_in_memory; + std::transform( + cpu_state.nodes_in_memory.begin(), cpu_state.nodes_in_memory.end(), + std::inserter(nodes_in_memory, nodes_in_memory.begin()), + [](const std::pair<const NodeDef*, int>& node_port) { + return std::make_pair(node_port.first->name(), node_port.second); + }); + std::set<std::pair<string, int>> expected = { + std::make_pair("bn", -1), std::make_pair("bn", 0), + std::make_pair("bn", 2), std::make_pair("x", 0), + }; + ExpectSetEq(expected, nodes_in_memory); + + const auto& node_states = scheduler_->GetNodeStates(); + const NodeState* bn_node = nullptr; + const NodeState* x_node = nullptr; + for (const auto& nodedef_node_state : node_states) { + const NodeDef* node = nodedef_node_state.first; + const NodeState& node_state = nodedef_node_state.second; + if (node->name() == "bn") { + bn_node = &node_state; + } + if (node->name() == "x") { + x_node = &node_state; + } + } + CHECK_NOTNULL(bn_node); + CHECK_NOTNULL(x_node); + + ValidateNodeDefs({"bn", "z1"}, x_node->outputs.at(0)); + ValidateNodeDefs({"z4"}, bn_node->outputs.at(-1)); + ValidateNodeDefs({"z1"}, bn_node->outputs.at(0)); + // z2 and z3 are bn.var + bn.var, so they appear twice in bn's output port 2. + ValidateNodeDefs({"z2", "z3", "z2", "z3"}, bn_node->outputs.at(2)); +} + +TEST_F(VirtualSchedulerTest, Variable) { + // Init. + CreateGrapplerItemWithConv2DAndVariable(); + InitScheduler(); + + // Run the scheduler. + RunScheduler(""); + + const auto& device_states = scheduler_->GetDeviceStates(); + const auto& cpu_state = device_states.at(kCPU0); + + // There is one Conv2D that takes x and f, but f is variable, so it should be + // in persistent nodes. + // f is variable. + ValidateMemoryUsageSnapshot({"f"}, 0 /* port_num_expected */, + cpu_state.persistent_nodes); + // Only x in peak memory usage snapshot. + ValidateMemoryUsageSnapshot({"x"}, 0 /* port_num_expected */, + cpu_state.mem_usage_snapshot_at_peak); +} } // end namespace grappler } // end namespace tensorflow diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc index 6f5310f501..51146011b0 100644 --- a/tensorflow/core/grappler/op_types.cc +++ b/tensorflow/core/grappler/op_types.cc @@ -45,18 +45,33 @@ bool IsMerge(const NodeDef& node) { return op == "Merge"; } +bool IsNoOp(const NodeDef& node) { + const auto op = node.op(); + return op == "NoOp"; +} + bool IsPlaceholder(const NodeDef& node) { const auto op = node.op(); return op == "Placeholder" || op == "PlaceholderV2" || op == "PlaceholderWithDefault"; } +bool IsRecv(const NodeDef& node) { + const auto op = node.op(); + return op == "_Recv"; +} + bool IsReduction(const NodeDef& node) { const auto& op = node.op(); return op == "Sum" || op == "Prod" || op == "Min" || op == "Max" || op == "Mean" || op == "Any" || op == "All"; } +bool IsSend(const NodeDef& node) { + const auto op = node.op(); + return op == "_Send"; +} + bool IsSwitch(const NodeDef& node) { const auto& op = node.op(); return op == "Switch"; diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h index 483ef5c577..b2102c688d 100644 --- a/tensorflow/core/grappler/op_types.h +++ b/tensorflow/core/grappler/op_types.h @@ -26,8 +26,11 @@ bool IsConstant(const NodeDef& node); bool IsDequeueOp(const NodeDef& node); bool IsIdentity(const NodeDef& node); bool IsMerge(const NodeDef& node); +bool IsNoOp(const NodeDef& node); bool IsPlaceholder(const NodeDef& node); +bool IsRecv(const NodeDef& node); bool IsReduction(const NodeDef& node); +bool IsSend(const NodeDef& node); bool IsSwitch(const NodeDef& node); bool IsTranspose(const NodeDef& node); bool IsVariable(const NodeDef& node); |