aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-06-06 16:45:02 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-06-06 16:49:04 -0700
commit8f89b654f4d49a1b5d4462303ef27f7f7a2958b3 (patch)
tree4be61a6867376b086c1cb0b770f38c408df91b69
parent0ea0bf5aae2961be4edbe00c205bed01d293dce3 (diff)
Profile memory usage in VirtualScheduler and report peak memory usage.
To do so, NodeState now handles different output ports of a node (in case a node has multiple outputs). Also, VirtualScheduler code is cleaned up with more comments. PiperOrigin-RevId: 158209068
-rw-r--r--tensorflow/core/grappler/costs/BUILD5
-rw-r--r--tensorflow/core/grappler/costs/op_level_cost_estimator.cc8
-rw-r--r--tensorflow/core/grappler/costs/virtual_scheduler.cc425
-rw-r--r--tensorflow/core/grappler/costs/virtual_scheduler.h123
-rw-r--r--tensorflow/core/grappler/costs/virtual_scheduler_test.cc450
-rw-r--r--tensorflow/core/grappler/op_types.cc15
-rw-r--r--tensorflow/core/grappler/op_types.h3
7 files changed, 827 insertions, 202 deletions
diff --git a/tensorflow/core/grappler/costs/BUILD b/tensorflow/core/grappler/costs/BUILD
index 8455c465df..1f90694331 100644
--- a/tensorflow/core/grappler/costs/BUILD
+++ b/tensorflow/core/grappler/costs/BUILD
@@ -176,6 +176,7 @@ cc_library(
"//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler:op_types",
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/clusters:utils",
"//tensorflow/core/grappler/costs:cost_estimator",
@@ -192,6 +193,10 @@ cc_test(
"//tensorflow/core:tensorflow",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler:op_types",
+ "//tensorflow/core/grappler:utils",
+ "//tensorflow/core/grappler/clusters:utils",
"//tensorflow/core/grappler/clusters:virtual_cluster",
],
)
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
index 11a57921e5..75ff75123e 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
@@ -31,6 +31,8 @@ constexpr char kNoOp[] = "NoOp";
constexpr char kReshape[] = "Reshape";
constexpr char kRecv[] = "_Recv";
constexpr char kBatchMatMul[] = "BatchMatMul";
+constexpr char kVariable[] = "Variable";
+constexpr char kVariableV2[] = "VariableV2";
OpLevelCostEstimator::OpLevelCostEstimator() {
// Syntactic sugar to build and return a lambda that takes an OpInfo and
@@ -53,6 +55,8 @@ OpLevelCostEstimator::OpLevelCostEstimator() {
{kNoOp, wrap(&OpLevelCostEstimator::PredictNoOp)},
{kReshape, wrap(&OpLevelCostEstimator::PredictNoOp)},
{kRecv, wrap(&OpLevelCostEstimator::PredictNoOp)},
+ {kVariable, wrap(&OpLevelCostEstimator::PredictNoOp)},
+ {kVariableV2, wrap(&OpLevelCostEstimator::PredictNoOp)},
{kBatchMatMul, wrap(&OpLevelCostEstimator::PredictBatchMatMul)}};
}
@@ -567,7 +571,7 @@ int64 OpLevelCostEstimator::CalculateSingleInputSize(
for (const auto& dim : input_shape.dim()) {
input_size *= dim.size();
}
- return input_size * DataTypeSize(input.dtype());
+ return input_size * DataTypeSize(BaseType(input.dtype()));
}
int64 OpLevelCostEstimator::CalculateInputSize(
@@ -589,7 +593,7 @@ int64 OpLevelCostEstimator::CalculateOutputSize(
for (const auto& output : op_features.outputs()) {
DataType dt = output.dtype();
const auto& original_output_shape = output.shape();
- int64 output_size = DataTypeSize(dt);
+ int64 output_size = DataTypeSize(BaseType(dt));
int num_dims = std::max(1, original_output_shape.dim_size());
auto output_shape = MaybeGetMinimumShape(original_output_shape, num_dims,
found_unknown_shapes);
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc
index 8d8d246078..dfa8c768e7 100644
--- a/tensorflow/core/grappler/costs/virtual_scheduler.cc
+++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/grappler/clusters/utils.h"
#include "tensorflow/core/grappler/costs/utils.h"
+#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/util/device_name_utils.h"
@@ -55,10 +56,10 @@ VirtualScheduler::VirtualScheduler(const GrapplerItem* grappler_item,
const bool use_static_shapes,
const string& default_device_type,
Cluster* cluster, VirtualPlacer* placer)
- : graph_properties_(*grappler_item),
- graph_costs_(Costs::ZeroCosts()),
- // TODO(dyoon): Use a better way than FIFO.
+ : // TODO(dyoon): Use a better way than FIFO.
ready_nodes_(new FIFOManager()),
+ graph_costs_(Costs::ZeroCosts()),
+ graph_properties_(*grappler_item),
cluster_(cluster),
grappler_item_(grappler_item),
use_static_shapes_(use_static_shapes),
@@ -68,6 +69,11 @@ VirtualScheduler::VirtualScheduler(const GrapplerItem* grappler_item,
}
Status VirtualScheduler::Init() {
+ // Init() preprocesses the input grappler_item and graph_properties to extract
+ // necessary information for emulating tensorflow op scheduling and
+ // construct internal data structures (NodeState and DeviceState) for virtual
+ // scheduling.
+
// Construct graph properties.
Status status;
if (use_static_shapes_) {
@@ -82,13 +88,12 @@ Status VirtualScheduler::Init() {
const auto& graph = grappler_item_->graph;
const auto& fetch_nodes = grappler_item_->fetch;
- // First, get the nodes that would run to output fetch_nodes.
+ // Get the nodes that would run to output fetch_nodes.
std::vector<const NodeDef*> nodes =
ComputeTransitiveFanin(graph, fetch_nodes);
// TODO(dyoon): this is a bit inefficient as name_to_node is already built in
// ComputeTransitiveFanin().
- //
// Once ComputeTransitiveFanin is complete, only the nodes that can be reached
// from the fetch nodes are scheduled. So the scheduled nodes should be
// exactly the same as those executed for real. One possible discrepancy could
@@ -98,61 +103,72 @@ Status VirtualScheduler::Init() {
name_to_node[node->name()] = node;
}
- // Build node_map.
+ // Build node_map; for each node, create its NodeState and connect its inputs
+ // and outputs.
for (const auto* curr_node : nodes) {
auto& curr_node_state = GetNodeStateOrCreateIt(curr_node);
const string curr_node_device = DeviceName(curr_node);
for (const string& input_node_name : curr_node->input()) {
- // Note that input_node_name may be in <node_name>:<output_number> format,
- // where ":<output_number>" may be omitted. NodeName() extracts only the
- // node_name (prefeix "^", if there was for control input, is also
- // deleted).
+ // Note that input_node_name may be in <prefix><node_name>:<port_num>
+ // format, where <prefix> (e.g., "^" for control dependency) and
+ // ":<port_num>" may be omitted. NodeName() extracts only the node_name.
const NodeDef* input_node = name_to_node[NodeName(input_node_name)];
+
CHECK(input_node);
- // Add input_to_curr_node to curr_node's input, and
- // add output_to_input_node to input_source_node's output.
- // Default values for when input_node and curr_node on the same device.
- const NodeDef* input_to_curr_node = input_node;
- const NodeDef* input_source_node = input_node;
- const NodeDef* output_to_input_node = curr_node;
const string in_device = DeviceName(input_node);
- if (curr_node_device != in_device) {
- if (cached_ops_.count(input_node) > 0 &&
- cached_ops_[input_node].count(curr_node_device) > 0) {
- // Different device, but found an already-transferred copy; connect
- // the cached node to curr_node.
- input_to_curr_node = cached_ops_[input_node][curr_node_device];
- input_source_node = input_to_curr_node;
- output_to_input_node = curr_node;
+ const auto input_node_port_num = NodePosition(input_node_name);
+
+ if (curr_node_device == in_device) {
+ // Same device: connect input_node and curr_node directly.
+ curr_node_state.inputs.push_back(
+ std::make_pair(input_node, input_node_port_num));
+ auto& input_node_state = GetNodeStateOrCreateIt(input_node);
+ input_node_state.outputs[input_node_port_num].push_back(curr_node);
+ } else {
+ if (cached_recv_nodes_.count(input_node) > 0 &&
+ cached_recv_nodes_[input_node].count(curr_node_device) > 0) {
+ // Different device, but found an already-cached copy (a _Recv op);
+ // connect the _Recv to curr_node.
+ const auto* recv_op =
+ cached_recv_nodes_[input_node][curr_node_device];
+ // recv_op's output port is hard-coded to zero.
+ curr_node_state.inputs.push_back(std::make_pair(recv_op, 0));
+ auto& input_node_state = node_map_.at(recv_op);
+ input_node_state.outputs[0].push_back(curr_node);
} else {
// Different device, no cached copy; transfer input_node to the
// curr_node's device.
- auto sendrecv_and_identity =
- TransferNode(input_node, curr_node, input_node_name);
- const auto* sendrecv = sendrecv_and_identity.first;
- const auto* identity = sendrecv_and_identity.second;
- input_to_curr_node = identity;
- input_source_node = input_node;
- output_to_input_node = sendrecv;
-
- // Cache the identity op for future use.
- cached_ops_[input_node][curr_node_device] = identity;
+ auto send_and_recv =
+ CreateSendRecv(input_node, curr_node, input_node_name);
+ // Note that CreateSendRecv() already connected input/output between
+ // _Send and _Recv ops.
+ const auto* send = send_and_recv.first;
+ const auto* recv = send_and_recv.second;
+ // recv_op's output port is hard-coded to zero.
+ curr_node_state.inputs.push_back(std::make_pair(recv, 0));
+ auto& input_node_state = GetNodeStateOrCreateIt(input_node);
+ input_node_state.outputs[input_node_port_num].push_back(send);
+
+ // Cache the _Recv op for future use.
+ cached_recv_nodes_[input_node][curr_node_device] = recv;
}
}
- curr_node_state.inputs.push_back(input_to_curr_node);
-
- // Note that we do not care output number (in case a tf op has multiple
- // outputs), as VirtualScheduler only cares which nodes become ready as
- // a node is executed.
- auto& input_node_state = GetNodeStateOrCreateIt(input_source_node);
- input_node_state.outputs.push_back(output_to_input_node);
}
if (curr_node->input().empty()) {
- curr_node_state.time_ready =
- Costs::Duration(); // Node without input: ready at time 0.
+ // Node without input: ready at time 0.
+ curr_node_state.time_ready = Costs::Duration();
ready_nodes_->AddNode(curr_node);
}
+
+ if (IsPersistentNode(curr_node)) {
+ auto& device_state = device_[curr_node_device];
+ for (int port_num = 0;
+ port_num < curr_node_state.output_properties.size(); ++port_num) {
+ device_state.persistent_nodes.insert(
+ std::make_pair(curr_node, port_num));
+ }
+ }
}
if (ready_nodes_->Empty()) {
@@ -163,18 +179,26 @@ Status VirtualScheduler::Init() {
return Status::OK();
}
-void VirtualScheduler::MaybeUpdateInputProperties(
- const NodeDef* node, std::vector<OpInfo::TensorProperties>* inputs) const {
- if (IsSendOp(node) || IsRecvOp(node)) {
- // _Send and _Recv ops are inserted from VirtualScheduler, so
+void VirtualScheduler::MaybeUpdateInputOutput(const NodeDef* node) {
+ CHECK(!initialized_) << "MaybeUpdateInputOutput is called after Init().";
+ // This method is called when NodeState is created and adds input and output
+ // properties for a few exceptional cases that GraphProperties cannot provide
+ // input/output properties.
+ if (IsSend(*node) || IsRecv(*node)) {
+ auto& node_state = node_map_[node];
+ auto& inputs = node_state.input_properties;
+ auto& outputs = node_state.output_properties;
+
+ // _Send and _Recv ops are created from VirtualScheduler, so
// there should be no inputs TensorProperties.
- CHECK_EQ(inputs->size(), 0);
+ CHECK(inputs.empty());
+ CHECK(outputs.empty());
const auto& attr = node->attr();
// This is the original input source to the _Send and _Recv, and this
// string includes "^" if it was control dependency, and output port
/// (e.g., ":2") if the input source had multiple outputs.
const auto& input_source_name = attr.at(kAttrInputSrc).s();
- if (input_source_name[0] == '^') {
+ if (IsControlInput(input_source_name)) {
// Control dependency; regardless of the input source tensor size,
// send 4B.
OpInfo::TensorProperties control_message;
@@ -182,51 +206,53 @@ void VirtualScheduler::MaybeUpdateInputProperties(
control_message.mutable_shape()->add_dim()->set_size(1);
auto* value = control_message.mutable_value();
value->add_float_val(1);
- inputs->push_back(control_message);
+ inputs.push_back(control_message);
+ outputs.push_back(control_message);
} else {
+ auto output_properties =
+ graph_properties_.GetOutputProperties(NodeName(input_source_name));
// Like with HasInputProperties, if a node does not have output
// properties, it's likely it was pruned during the shape inference run.
- if (graph_properties_.HasOutputProperties(NodeName(input_source_name))) {
- const auto input_position = NodePosition(input_source_name);
+ if (!output_properties.empty()) {
+ const auto input_node_port_num = NodePosition(input_source_name);
// Use the input source's output property as _Send and _Recv's input
// property.
- auto outputs =
- graph_properties_.GetOutputProperties(NodeName(input_source_name));
- CHECK_GT(outputs.size(), input_position);
- inputs->push_back(outputs[input_position]);
+ CHECK_GT(output_properties.size(), input_node_port_num);
+ inputs.push_back(output_properties[input_node_port_num]);
+ outputs.push_back(output_properties[input_node_port_num]);
}
}
}
}
-bool VirtualScheduler::IsSendOp(const NodeDef* node) const {
- return node->op() == kSend;
+float VirtualScheduler::Round2(const float x) const {
+ return std::round(100.0 * x) / 100.0;
}
-bool VirtualScheduler::IsRecvOp(const NodeDef* node) const {
- return node->op() == kRecv;
+bool VirtualScheduler::IsPersistentNode(const NodeDef* node) const {
+ // Variables are persistent nodes.
+ return IsVariable(*node);
}
string VirtualScheduler::DeviceName(const NodeDef* node) const {
+ CHECK(!initialized_) << "DeviceName is called after Init().";
+
// TODO(dyoon): integrate this part with VirtualPlacer.
- if (IsSendOp(node)) {
- const auto& node_state = node_map_.at(node);
- const auto* from = node_state.inputs[0];
- const auto* to = node_state.outputs[0];
- return ChannelDeviceName(from, to);
- } else {
- return node->device().empty() ? "/" + default_device_type_ + ":0"
- : node->device();
- }
+ return node->device().empty() ? "/device:" + default_device_type_ + ":0"
+ : node->device();
}
string VirtualScheduler::ChannelDeviceName(const NodeDef* from,
const NodeDef* to) const {
+ CHECK(!initialized_) << "ChannelDeviceName is called after Init().";
+
return kChannelDevice + ": " + DeviceName(from) + " to " + DeviceName(to);
}
-std::pair<const NodeDef*, const NodeDef*> VirtualScheduler::TransferNode(
+std::pair<const NodeDef*, const NodeDef*> VirtualScheduler::CreateSendRecv(
const NodeDef* from, const NodeDef* to, const string& input_name) {
+ CHECK(!initialized_) << "CreateSendRecv is called after Init().";
+
// Connect "from" node to "to" node with _Send and _Recv such that
// from -> _Send -> _Recv -> to.
// _Send is placed on "Channel" device, and _Recv is on the same device
@@ -238,11 +264,13 @@ std::pair<const NodeDef*, const NodeDef*> VirtualScheduler::TransferNode(
// NodeDefs created here need not be correct: in terms of name,
// input names, attrs, etc.
+ auto input_node_port_num = NodePosition(input_name);
+
// _Send op.
auto* send = new NodeDef();
send->set_name("Send " + from->name() + " from " + DeviceName(from) + " to " +
DeviceName(to));
- send->set_op(kSend);
+ send->set_op("_Send");
send->add_input(from->name());
send->set_device(ChannelDeviceName(from, to));
auto& send_attr = *(send->mutable_attr());
@@ -253,19 +281,22 @@ std::pair<const NodeDef*, const NodeDef*> VirtualScheduler::TransferNode(
// _Recv op.
auto* recv = new NodeDef();
recv->set_name("Recv " + from->name() + " on " + DeviceName(to));
- recv->set_op(kRecv);
+ recv->set_op("_Recv");
recv->add_input(send->name());
recv->set_device(DeviceName(to));
auto& recv_attr = *(recv->mutable_attr());
recv_attr[kAttrInputSrc].set_s(input_name);
- // Update NodeState for _Send and _Recv ops.
+ // NodeState for _Send op.
auto& send_node_state = GetNodeStateOrCreateIt(send);
- send_node_state.inputs.push_back(from);
- send_node_state.outputs.push_back(recv);
+ send_node_state.device_name = send->device(); // Set Channel device.
+ send_node_state.inputs.push_back(std::make_pair(from, input_node_port_num));
+ send_node_state.outputs[0].push_back(recv);
+
+ // NodeState for _Recv op.
auto& recv_node_state = GetNodeStateOrCreateIt(recv);
- recv_node_state.inputs.push_back(send);
- recv_node_state.outputs.push_back(to);
+ recv_node_state.inputs.push_back(std::make_pair(send, 0));
+ recv_node_state.outputs[0].push_back(to);
// Keep the created nodes.
additional_nodes_.emplace_back(std::unique_ptr<NodeDef>(send));
@@ -277,13 +308,8 @@ std::pair<const NodeDef*, const NodeDef*> VirtualScheduler::TransferNode(
NodeInfo VirtualScheduler::GetCurrNodeInfo() const {
const NodeDef* node = ready_nodes_->GetCurrNode();
- std::vector<OpInfo::TensorProperties> inputs =
- graph_properties_.GetInputProperties(node->name());
- // Some ops created within VirtualScheduler may need further processing to
- // the input properties.
- MaybeUpdateInputProperties(node, &inputs);
- // This is for compatibility; we can just use palcer_->get_device() for all
+ // This is for compatibility; we can just use placer_->get_device() for all
// cases, once VirtualCluster is properly set up.
DeviceProperties device;
if (placer_) {
@@ -294,7 +320,8 @@ NodeInfo VirtualScheduler::GetCurrNodeInfo() const {
int device_id;
DeviceNameUtils::ParsedName parsed;
if (!node->device().empty() &&
- DeviceNameUtils::ParseFullName(DeviceName(node), &parsed)) {
+ DeviceNameUtils::ParseFullName(node_map_.at(node).device_name,
+ &parsed)) {
device_type = parsed.type;
device_id = parsed.id;
} else {
@@ -309,79 +336,109 @@ NodeInfo VirtualScheduler::GetCurrNodeInfo() const {
}
// Special case for _Send op.
- if (IsSendOp(node)) {
+ if (IsSend(*node)) {
device.set_type(kChannelDevice);
}
+ // Construct NodeInfo.
+ const auto& node_state = node_map_.at(node);
NodeInfo node_info;
node_info.name = node->name();
- node_info.device_name = graph_properties_.GetDeviceName(node->name());
- std::vector<OpInfo::TensorProperties> outputs =
- graph_properties_.GetOutputProperties(node->name());
+ node_info.device_name = node_state.device_name;
auto& op_info = node_info.op_info;
op_info.set_op(node->op());
*op_info.mutable_attr() = node->attr();
- for (auto& input : inputs) {
- op_info.add_inputs()->Swap(&input);
+ for (auto& input : node_state.input_properties) {
+ *op_info.add_inputs() = input;
}
- for (auto& output : outputs) {
- op_info.add_outputs()->Swap(&output);
+ for (auto& output : node_state.output_properties) {
+ *op_info.add_outputs() = output;
}
op_info.mutable_device()->Swap(&device);
- // add some more to the node_info.
return node_info;
}
NodeState& VirtualScheduler::GetNodeStateOrCreateIt(const NodeDef* node) {
+ CHECK(!initialized_) << "GetNodeStateOrCreateIt is called after Init().";
+
auto it = node_map_.find(node);
if (it == node_map_.end()) {
+ // Not found; create a NodeState for this node.
it = node_map_.emplace(node, NodeState()).first;
- }
- return it->second;
-}
+ auto& node_state = it->second;
+ node_state.input_properties =
+ graph_properties_.GetInputProperties(node->name());
+ node_state.output_properties =
+ graph_properties_.GetOutputProperties(node->name());
+
+ // Some ops may need further processing to the input / output properties:
+ // _Send and _Recv.
+ MaybeUpdateInputOutput(node);
+
+ if (!IsSend(*node)) {
+ node_state.device_name = DeviceName(node);
+ // For _Send op, device_name will be set to Channel in CreateSendRecv().
+ }
-Costs& VirtualScheduler::FindOrCreateZero(const string& op_name,
- std::map<string, Costs>* op_cost) {
- auto it = op_cost->find(op_name);
- if (it == op_cost->end()) {
- it = op_cost->emplace(op_name, Costs::ZeroCosts()).first;
+ // Initialize output port related data:
+ // Assume the size of OutputProperties represents the number of output ports
+ // of this node.
+ for (int i = 0; i < node_state.output_properties.size(); ++i) {
+ node_state.time_no_references[i] = Costs::Duration::max();
+ node_state.num_outputs_executed[i] = 0;
+ // Populate an empty vector for each port. The caller will add nodes
+ // that use this port as input.
+ node_state.outputs[i] = {};
+ }
+ // Port_num -1 is for control dependency.
+ node_state.time_no_references[-1] = Costs::Duration::max();
+ node_state.num_outputs_executed[-1] = 0;
+ node_state.outputs[-1] = {};
}
return it->second;
}
-bool VirtualScheduler::PopCurrNode() {
- const auto* node = ready_nodes_->GetCurrNode();
- auto& node_state = node_map_[node];
- auto& device = device_[DeviceName(node)];
- auto curr_time = device.GetCurrTime();
+int64 VirtualScheduler::CalculateOutputSize(
+ const std::vector<OpInfo::TensorProperties>& output_properties,
+ const int port_num) const {
+ if (port_num < 0) {
+ return 4; // 4B for control dependency.
+ }
- // Increment num_inputs_ready of the output nodes.
- for (auto* output : node_state.outputs) {
- auto& output_state = node_map_[output];
- output_state.num_inputs_ready++;
- if (output_state.num_inputs_ready == output_state.inputs.size()) {
- // This output node is now ready.
- output_state.time_ready = curr_time;
- ready_nodes_->AddNode(output);
- }
+ if (port_num >= output_properties.size()) {
+ VLOG(3) << "VirtualScheduler::CalculateOutputSize() -- "
+ << "port_num: " << port_num
+ << " >= output_properties.size(): " << output_properties.size();
+ return 0;
}
- // Increment num_outputs_executed of the input nodes.
- for (auto* input : node_state.inputs) {
- auto& input_state = node_map_[input];
- input_state.num_outputs_executed++;
- if (input_state.num_outputs_executed == input_state.outputs.size()) {
- // All the outputs are executed; no reference to this input nodel
- input_state.time_no_reference = curr_time;
- // TODO(dyoon): collect device memory usage; note that this input node
- // use device memory between time_scheduled and time_no_reference.
+ const auto& output = output_properties[port_num];
+ int64 output_size = DataTypeSize(BaseType(output.dtype()));
+
+ for (const auto& dim : output.shape().dim()) {
+ auto dim_size = dim.size();
+ if (dim_size < 0) {
+ // Zero output size if there's any unknown dim.
+ output_size = 0;
+ VLOG(3) << "VirtualScheduler::CalculateOutputSize() -- "
+ << "unknown dim: " << output_size;
+ break;
}
+ output_size *= dim_size;
}
- // Remove the current node; assume FIFO.
- ready_nodes_->RemoveCurrNode();
+ return output_size;
+}
- return !ready_nodes_->Empty();
+Costs& VirtualScheduler::FindOrCreateZero(const string& op_name,
+ std::map<string, Costs>* op_cost) {
+ auto it = op_cost->find(op_name);
+ if (it == op_cost->end()) {
+ // Note that default constructor of Costs sets some memory related fields
+ // to unknown values so we should explicitly initialize it with ZeroCosts.
+ it = op_cost->emplace(op_name, Costs::ZeroCosts()).first;
+ }
+ return it->second;
}
bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
@@ -402,7 +459,7 @@ bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
// Update node and device states.
auto& node_state = node_map_[node];
- auto& device = device_[DeviceName(node)];
+ auto& device = device_[node_state.device_name];
device.nodes_executed.push_back(node);
// Node is scheduled when the device is available AND all the inputs are
// ready; hence, time_scheduled is time_ready if time_ready > device curr
@@ -415,6 +472,21 @@ bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
auto curr_time = device.GetCurrTime();
node_state.time_finished = curr_time;
+ // Update device memory usage.
+ if (!IsPersistentNode(node)) {
+ for (const auto& port_num_output_pair : node_state.outputs) {
+ int port_num = port_num_output_pair.first;
+ // There's a chance that a specific output is not used at all.
+ if (node_state.outputs[port_num].empty()) {
+ node_state.time_no_references[port_num] = curr_time;
+ } else {
+ device.memory_usage +=
+ CalculateOutputSize(node_state.output_properties, port_num);
+ device.nodes_in_memory.insert(std::make_pair(node, port_num));
+ }
+ }
+ }
+
// Update device's per-op cost.
auto& device_op_cost = FindOrCreateZero(op_name, &device.op_to_cost);
device_op_cost = CombineCosts(device_op_cost, node_costs);
@@ -425,7 +497,52 @@ bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
<< ", scheduled: " << node_state.time_scheduled.count()
<< ", finished: " << node_state.time_finished.count();
- return PopCurrNode();
+ // Increment num_inputs_ready of the output nodes
+ for (const auto& port_num_output_pair : node_state.outputs) {
+ for (auto* output_node : port_num_output_pair.second) {
+ auto& output_state = node_map_[output_node];
+ output_state.num_inputs_ready++;
+ if (output_state.num_inputs_ready == output_state.inputs.size()) {
+ // This output node is now ready.
+ output_state.time_ready = curr_time;
+ ready_nodes_->AddNode(output_node);
+ }
+ }
+ }
+
+ // Increment num_outputs_executed of the input nodes.
+ for (const auto& input_port : node_state.inputs) {
+ auto* input = input_port.first;
+ auto port = input_port.second;
+ auto& input_state = node_map_[input];
+ input_state.num_outputs_executed[port]++;
+ if (input_state.num_outputs_executed[port] ==
+ input_state.outputs[port].size() &&
+ !IsPersistentNode(input)) {
+ // All the outputs are executed; no reference to this output port of
+ // input node.
+ input_state.time_no_references[port] = curr_time;
+ auto& input_device = device_[input_state.device_name];
+ input_device.memory_usage -=
+ CalculateOutputSize(input_state.output_properties, port);
+
+ input_device.nodes_in_memory.erase(std::make_pair(input, port));
+ }
+ }
+
+ if (!IsPersistentNode(node)) {
+ // Now that output memory is added and used up nodes are deallocated,
+ // check max memory usage.
+ if (device.memory_usage > device.max_memory_usage) {
+ device.max_memory_usage = device.memory_usage;
+ device.mem_usage_snapshot_at_peak = device.nodes_in_memory;
+ }
+ }
+
+ // Remove the current node; assume FIFO.
+ ready_nodes_->RemoveCurrNode();
+
+ return !ready_nodes_->Empty();
}
Costs VirtualScheduler::Summary() const {
@@ -452,17 +569,59 @@ Costs VirtualScheduler::Summary() const {
for (const auto& device : device_) {
const auto& name = device.first;
const auto& state = device.second;
+
+ std::map<string, int64> op_to_memory;
+ // First profile only persistent memory usage.
+ int64 persistent_memory_usage = 0;
+ std::set<string> persisent_ops;
+ for (const auto& node_port : state.persistent_nodes) {
+ const auto* node = node_port.first;
+ const auto port = node_port.second;
+ const auto output_size =
+ CalculateOutputSize(node_map_.at(node).output_properties, port);
+ persistent_memory_usage += output_size;
+ op_to_memory[node->op()] += output_size;
+ persisent_ops.insert(node->op());
+ }
+ int64 max_memory_usage = persistent_memory_usage + state.max_memory_usage;
+
VLOG(1) << "Device = " << name
<< ", num_nodes = " << state.nodes_executed.size()
- << ", execution_time = " << state.GetCurrTime().count();
- VLOG(1) << "Per-op execution time:";
+ << ", execution_time = " << state.GetCurrTime().count()
+ << ", memory usage: "
+ << "persistenst = "
+ << Round2(persistent_memory_usage / 1024.0 / 1024.0 / 1024.0)
+ << " GB, peak = "
+ << Round2(state.max_memory_usage / 1024.0 / 1024.0 / 1024.0)
+ << " GB, total = "
+ << Round2(max_memory_usage / 1024.0 / 1024.0 / 1024.0)
+ << " GB, at the end: " << state.memory_usage << " B";
+
+ VLOG(1) << "Per-op execution time (and memory usage at peak memory usage):";
+ // Profile non-persistent op memory usage.
+ for (const auto& node_port : state.mem_usage_snapshot_at_peak) {
+ const auto* node = node_port.first;
+ const auto port = node_port.second;
+ op_to_memory[node->op()] +=
+ CalculateOutputSize(node_map_.at(node).output_properties, port);
+ }
for (const auto& op_cost_pair : state.op_to_cost) {
const auto& op = op_cost_pair.first;
const auto& cost = op_cost_pair.second.execution_time.count();
- if (cost) { // Skip printing out zero-cost ops.
- VLOG(1) << " + " << op << " : " << cost;
+ const float mem_usage_gb =
+ Round2(op_to_memory[op] / 1024.0 / 1024.0 / 1024.0);
+ int64 op_mem_usage = op_to_memory.at(op);
+ const float mem_usage_percent =
+ max_memory_usage > 0 ? Round2(100.0 * op_mem_usage / max_memory_usage)
+ : 0.0;
+ if (cost || mem_usage_percent > 1.0) {
+ // Print out only non-zero cost ops or ops with > 1% memory usage.
+ VLOG(1) << " + " << op << " : " << cost << " (" << mem_usage_gb
+ << " GB [" << mem_usage_percent << "%] "
+ << (persisent_ops.count(op) > 0 ? ": persistent op)" : ")");
}
}
+ VLOG(1) << "Per-op execution time (and memory usage at peak memory usage):";
if (critical_path_costs.execution_time <= state.GetCurrTime()) {
critical_path_costs = state.device_costs;
}
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.h b/tensorflow/core/grappler/costs/virtual_scheduler.h
index 7764bdc478..af5434efe3 100644
--- a/tensorflow/core/grappler/costs/virtual_scheduler.h
+++ b/tensorflow/core/grappler/costs/virtual_scheduler.h
@@ -29,36 +29,79 @@ namespace tensorflow {
namespace grappler {
struct NodeState {
- std::vector<const NodeDef*> inputs;
- std::vector<const NodeDef*> outputs;
+ // A node (i.e., an op) takes a set of input:port pairs and produces
+ // a set of output ports.
+
+ // Cross references to input and output nodes from graphdef.
+ std::vector<std::pair<const NodeDef*, int>> inputs; // Input, port pairs.
+ // List of output nodes (a list of nodes that takes this output port as input)
+ // keyed by port_num. Note that port_num -1 is used for control dependency.
+ std::unordered_map<int, std::vector<const NodeDef*>> outputs;
+
+ // Info from GraphProperties.
+ std::vector<OpInfo::TensorProperties> input_properties;
+ std::vector<OpInfo::TensorProperties> output_properties;
+
+ // Canonical device name used within VirtualScheduler.
+ string device_name;
+
+ // States updated as scheduling nodes.
int num_inputs_ready;
- int num_outputs_executed;
+ std::unordered_map<int, int> num_outputs_executed;
Costs::Duration time_ready;
Costs::Duration time_scheduled;
Costs::Duration time_finished;
- Costs::Duration time_no_reference;
+ // Time that all the consumers are executed (hence, no need to keep this
+ // output in memory), keyed by port_num.
+ std::unordered_map<int, Costs::Duration> time_no_references;
+
+ // Note that a node may have multiple output ports. The length of outputs,
+ // num_outputs_executed, and time_no_references should be
+ // identical when a NodeState is fully initialized.
+ // They should be 1 + output_properties.size() as we add [-1] for control
+ // dependency.
// Node will be ready to be executed at time_ready, scheduled at
// time_scheduled, and finishes execution at time_finished.
- // Between time_scheduled and time_no_reference, the node's output tensor
- // needs to be on the device, using up device memory.
+ // Each output port uses up memory space from time_scheduled to its
+ // time_no_references.
NodeState() {
num_inputs_ready = 0;
- num_outputs_executed = 0;
time_ready = Costs::Duration::max();
time_scheduled = Costs::Duration::max();
time_finished = Costs::Duration::max();
- time_no_reference = Costs::Duration::max();
+ // Note that num_outputs_executed and time_no_references are not initialized
+ // here, since we don't know the size (i.e., # outputs for this node).
}
};
struct DeviceState {
+ // Nodes executed on this device in execution order.
std::vector<const NodeDef*> nodes_executed;
- Costs device_costs;
- std::map<string, Costs> op_to_cost; // Per-op cost.
- DeviceState() { device_costs = Costs::ZeroCosts(); }
+ // Nodes currently allocated in memory: set of NodeDef* and port_num pairs
+ // so that we can track which output of the node is in memory.
+ std::set<std::pair<const NodeDef*, int>> nodes_in_memory;
+
+ // Nodes allocated in memory persistently: e.g., Variables.
+ std::set<std::pair<const NodeDef*, int>> persistent_nodes;
+
+ // Snapshot of nodes_in_memory, when memory usage is at peak.
+ // Same to nodes_in_memory, it's a set of NodeDef* and port_num pairs.
+ std::set<std::pair<const NodeDef*, int>> mem_usage_snapshot_at_peak;
+
+ Costs device_costs;
+ std::map<string, Costs> op_to_cost; // Per-op cost.
+ std::map<string, int64> op_to_memory; // Per-op memory usage at peak usage.
+ int64 memory_usage;
+ int64 max_memory_usage;
+
+ DeviceState() {
+ device_costs = Costs::ZeroCosts();
+ memory_usage = 0;
+ max_memory_usage = 0;
+ }
Costs::Duration GetCurrTime() const { return device_costs.execution_time; }
};
@@ -106,48 +149,74 @@ class VirtualScheduler {
const string& default_device_type, Cluster* cluster,
VirtualPlacer* placer);
+ // Initializes NodeState and DeviceState from grappler_item_ and
+ // graph_properties_.
Status Init();
NodeInfo GetCurrNodeInfo() const;
+
+ // Returns true if there is any node to be scheduled.
bool MarkCurrNodeExecuted(const Costs& node_costs);
+ // Prints out summary of execution (timing, memory usage, etc.)
Costs Summary() const;
+ protected:
+ // GetDeviceStates and GetNodeStates are currently for testing purpuse only.
+ // Retrieves detailed scheduling results.
+ const std::unordered_map<string, DeviceState>& GetDeviceStates() const {
+ return device_;
+ }
+ const std::unordered_map<const NodeDef*, NodeState>& GetNodeStates() const {
+ return node_map_;
+ }
+
+ // Returns the size of output at port_num (unit: bytes). A special case is
+ // port_num -1, which is for control dependency and assumed to be 4 bytes.
+ int64 CalculateOutputSize(
+ const std::vector<OpInfo::TensorProperties>& output_properties,
+ const int port_num) const;
+
private:
- const string kSend = "_Send";
- const string kRecv = "_Recv";
+ // Constants.
const string kAttrInputSrc = "input_source_";
const string kAttrSrcDevice = "src_device_";
const string kAttrDstDevice = "dst_device_";
const string kChannelDevice = "Channel";
- void MaybeUpdateInputProperties(
- const NodeDef* node, std::vector<OpInfo::TensorProperties>* inputs) const;
+ // Methods called from Init(). Fails if initialize_ is set.
+ void MaybeUpdateInputOutput(const NodeDef* node);
NodeState& GetNodeStateOrCreateIt(const NodeDef* node);
- std::pair<const NodeDef*, const NodeDef*> TransferNode(
+ std::pair<const NodeDef*, const NodeDef*> CreateSendRecv(
const NodeDef* from, const NodeDef* to, const string& input_name);
string DeviceName(const NodeDef* node) const;
string ChannelDeviceName(const NodeDef* from, const NodeDef* to) const;
+
+ // Helper methods.
Costs& FindOrCreateZero(const string& op_name,
std::map<string, Costs>* op_cost);
+ float Round2(const float x) const;
+ bool IsPersistentNode(const NodeDef* node) const;
- bool PopCurrNode();
- bool IsSendOp(const NodeDef* node) const;
- bool IsRecvOp(const NodeDef* node) const;
-
- GraphProperties graph_properties_;
- std::map<string, int> op_counts_; // Op counts with key with input shape.
- std::map<string, int> op_costs_; // Individual op costs (with input shapes).
- Costs graph_costs_; // Graph cost.
- std::map<string, Costs> op_to_cost_; // Per-op cost.
+ // Scheduler states:
std::unique_ptr<ReadyNodeManager> ready_nodes_;
std::unordered_map<const NodeDef*, NodeState> node_map_;
std::unordered_map<string, DeviceState> device_;
+
// Pool of NodeDefs for SendRecv and Identity ops created.
std::vector<std::unique_ptr<NodeDef>> additional_nodes_;
- // Cache of ops transferred to another device.
+ // Cache of nodes transferred to another device.
std::unordered_map<const NodeDef*, std::unordered_map<string, const NodeDef*>>
- cached_ops_;
+ cached_recv_nodes_;
+
+ // Stats:
+ std::map<string, int> op_counts_; // Op counts with key with input shape.
+ std::map<string, int> op_costs_; // Individual op costs (with input shapes).
+ Costs graph_costs_; // Graph cost.
+ std::map<string, Costs> op_to_cost_; // Per-op cost.
+
+ // Auxilliary data structures for constructing NodeState and DeviceState.
+ GraphProperties graph_properties_;
Cluster* cluster_; // Not owned.
const GrapplerItem* grappler_item_; // Not owned.
bool use_static_shapes_;
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc
index dad2104b75..10e181e826 100644
--- a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc
+++ b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc
@@ -23,42 +23,49 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
+// Class for testing virtual scheduler.
+class TestVirtualScheduler : public VirtualScheduler {
+ public:
+ TestVirtualScheduler(const GrapplerItem* grappler_item,
+ const bool use_static_shapes,
+ const string& default_device_type, Cluster* cluster,
+ VirtualPlacer* placer)
+ : VirtualScheduler(grappler_item, use_static_shapes, default_device_type,
+ cluster, placer) {}
+
+ FRIEND_TEST(VirtualSchedulerTest, CalculateOutputSize);
+ FRIEND_TEST(VirtualSchedulerTest, MemoryUsage);
+ FRIEND_TEST(VirtualSchedulerTest, ControlDependency);
+ FRIEND_TEST(VirtualSchedulerTest, ComplexDependency);
+ FRIEND_TEST(VirtualSchedulerTest, Variable);
+};
class VirtualSchedulerTest : public ::testing::Test {
protected:
+ const string kCPU0 = "/job:localhost/replica:0/task:0/cpu:0";
+
void SetUp() override {
// Initializes cluster_ and placer_.
std::unordered_map<string, DeviceProperties> devices;
DeviceProperties cpu_device;
cpu_device.set_type("CPU");
- devices["/job:localhost/replica:0/task:0/cpu:0"] = cpu_device;
- DeviceProperties gpu_device;
- gpu_device.set_type("GPU");
- devices["/job:localhost/replica:0/task:0/gpu:0"] = gpu_device;
+ devices[kCPU0] = cpu_device;
cluster_.reset(new VirtualCluster(devices));
placer_.reset(new VirtualPlacer(cluster_.get()));
}
- void CreateSchedulerWithConv2Ds() {
- // Create a scheduler with a simple graph: 3 Conv2Ds, where only 2 are in
- // fetch nodes.
- const int bs = 4;
- const int width = 10;
- const int height = 10;
- const int depth_in = 8;
- const int kernel = 3;
- const int depth_out = 16;
-
- tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ // Three Conv2Ds with only two in fetch nodes.
+ void CreateGrapplerItemWithConv2Ds() {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice(kCPU0);
auto x = tensorflow::ops::RandomUniform(
- s.WithOpName("x"), {bs, width, height, depth_in}, DT_FLOAT);
+ s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
auto y = tensorflow::ops::RandomUniform(
- s.WithOpName("y"), {bs, width, height, depth_in}, DT_FLOAT);
+ s.WithOpName("y"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
auto z = tensorflow::ops::RandomUniform(
- s.WithOpName("z"), {bs, width, height, depth_in}, DT_FLOAT);
+ s.WithOpName("z"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
auto f = tensorflow::ops::RandomUniform(
- s.WithOpName("f"), {kernel, kernel, depth_in, depth_out}, DT_FLOAT);
+ s.WithOpName("f"), {kernel_, kernel_, depth_in_, depth_out_}, DT_FLOAT);
std::vector<int> strides = {1, 1, 1, 1};
auto c0 =
tensorflow::ops::Conv2D(s.WithOpName("c0"), x, f, strides, "SAME");
@@ -68,47 +75,253 @@ class VirtualSchedulerTest : public ::testing::Test {
tensorflow::ops::Conv2D(s.WithOpName("c2"), z, f, strides, "SAME");
GraphDef def;
TF_CHECK_OK(s.ToGraphDef(&def));
- LOG(INFO) << def.DebugString();
grappler_item_.reset(new GrapplerItem);
grappler_item_->id = "test_conv2d_graph";
grappler_item_->graph = def;
grappler_item_->fetch = {"c0", "c1"};
- scheduler_.reset(new VirtualScheduler(
+ dependency_["c0"] = {"x", "f"};
+ dependency_["c1"] = {"y", "f"};
+ }
+
+ // A Conv2D with a variable.
+ void CreateGrapplerItemWithConv2DAndVariable() {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice(kCPU0);
+ auto x = tensorflow::ops::RandomUniform(
+ s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
+ auto f = tensorflow::ops::Variable(
+ s.WithOpName("f"), {kernel_, kernel_, depth_in_, depth_out_}, DT_FLOAT);
+ std::vector<int> strides = {1, 1, 1, 1};
+ auto y = tensorflow::ops::Conv2D(s.WithOpName("y"), x, f, strides, "SAME");
+ GraphDef def;
+ TF_CHECK_OK(s.ToGraphDef(&def));
+
+ grappler_item_.reset(new GrapplerItem);
+ grappler_item_->id = "test_conv2d_var_graph";
+ grappler_item_->graph = def;
+ grappler_item_->fetch = {"y"};
+
+ dependency_["y"] = {"x", "f"};
+ }
+
+ // AddN that takes 4 tensors with 10x10x10x10.
+ void CreateGrapplerItemWithAddN() {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice(kCPU0);
+ auto x = tensorflow::ops::RandomUniform(s.WithOpName("x"), {10, 10, 10, 10},
+ DT_FLOAT);
+ auto y = tensorflow::ops::RandomUniform(s.WithOpName("y"), {10, 10, 10, 10},
+ DT_FLOAT);
+ auto z = tensorflow::ops::RandomUniform(s.WithOpName("z"), {10, 10, 10, 10},
+ DT_FLOAT);
+ auto w = tensorflow::ops::RandomUniform(s.WithOpName("w"), {10, 10, 10, 10},
+ DT_FLOAT);
+ tensorflow::OutputList input_tensors = {x, y, z, w};
+ auto out = tensorflow::ops::AddN(s.WithOpName("out"), input_tensors);
+ GraphDef def;
+ TF_CHECK_OK(s.ToGraphDef(&def));
+
+ grappler_item_.reset(new GrapplerItem);
+ grappler_item_->id = "test_addn_graph";
+ grappler_item_->graph = def;
+ grappler_item_->fetch = {"out"};
+
+ dependency_["out"] = {"x", "y", "z", "w"};
+ }
+
+ // NoOp that takes 7 NoOps as control dependency.
+ void CreateGrapplerItemWithControlDependency() {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice(kCPU0);
+ std::vector<string> input_noop_names = {"x", "y", "z", "w", "u", "v", "t"};
+ std::vector<tensorflow::Operation> input_tensors;
+ for (const auto& input : input_noop_names) {
+ auto x = tensorflow::ops::NoOp(s.WithOpName(input));
+ input_tensors.push_back(x.operation);
+ }
+ auto out = tensorflow::ops::NoOp(
+ s.WithControlDependencies(input_tensors).WithOpName("out"));
+ GraphDef def;
+ TF_CHECK_OK(s.ToGraphDef(&def));
+
+ grappler_item_.reset(new GrapplerItem);
+ grappler_item_->id = "test_control_dependency_graph";
+ grappler_item_->graph = def;
+ grappler_item_->fetch = {"out"};
+
+ dependency_["out"] = input_noop_names;
+ }
+
+ // FusedBN [an op with multiple outputs] with multiple consumers (including
+ // control dependency).
+ void CreateGrapplerItemWithBatchNorm() {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice(kCPU0);
+ auto x = tensorflow::ops::RandomUniform(
+ s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
+ auto scale = tensorflow::ops::RandomUniform(s.WithOpName("scale"),
+ {depth_in_}, DT_FLOAT);
+ auto offset = tensorflow::ops::RandomUniform(s.WithOpName("offset"),
+ {depth_in_}, DT_FLOAT);
+ auto mean =
+ tensorflow::ops::RandomUniform(s.WithOpName("mean"), {0}, DT_FLOAT);
+ auto var =
+ tensorflow::ops::RandomUniform(s.WithOpName("var"), {0}, DT_FLOAT);
+
+ auto batch_norm = tensorflow::ops::FusedBatchNorm(
+ s.WithOpName("bn"), x, scale, offset, mean, var,
+ ops::FusedBatchNorm::IsTraining(true).Epsilon(0.1f));
+ auto y = batch_norm.y;
+ auto batch_mean = batch_norm.batch_mean;
+ auto batch_var = batch_norm.batch_variance;
+
+ auto z1 = tensorflow::ops::Add(s.WithOpName("z1"), x, y);
+ auto z2 = tensorflow::ops::Add(s.WithOpName("z2"), batch_var, batch_var);
+ auto z3 = tensorflow::ops::Add(s.WithOpName("z3"), batch_var, batch_var);
+ std::vector<tensorflow::Operation> input_tensors = {
+ batch_mean.op(), z1.z.op(), z2.z.op(), z3.z.op(),
+ };
+ auto z4 = tensorflow::ops::NoOp(
+ s.WithControlDependencies(batch_var).WithOpName("z4"));
+
+ GraphDef def;
+ TF_CHECK_OK(s.ToGraphDef(&def));
+
+ grappler_item_.reset(new GrapplerItem);
+ grappler_item_->id = "test_complex_dependency_graph";
+ grappler_item_->graph = def;
+ grappler_item_->fetch = {"z1", "z2", "z3", "z4"};
+
+ dependency_["bn"] = {"x", "scale", "offset", "mean", "var"};
+ dependency_["z1"] = {"x", "bn"};
+ dependency_["z2"] = {"bn"};
+ dependency_["z3"] = {"bn"};
+ dependency_["z4"] = {"bn"};
+ }
+
+ // Call this after creating grappler_item_ and setting up dependency_.
+ void InitScheduler() {
+ scheduler_.reset(new TestVirtualScheduler(
grappler_item_.get(), true /* use_static_shapes */,
"CPU" /* default_device_type */, cluster_.get(), placer_.get()));
TF_CHECK_OK(scheduler_->Init());
}
+ // Call this after init scheduler_. Scheduler stops after executing
+ // target_node.
+ std::unordered_map<string, NodeInfo> RunScheduler(const string& target_node) {
+ Costs zero_costs = Costs::ZeroCosts();
+ std::unordered_map<string, NodeInfo> ops_executed;
+ bool more_nodes = true;
+ do {
+ NodeInfo node_info = scheduler_->GetCurrNodeInfo();
+ ops_executed[node_info.name] = node_info;
+
+ // Check scheduling order.
+ auto it = dependency_.find(node_info.name);
+ if (it != dependency_.end()) {
+ for (const auto& preceding_node : it->second) {
+ EXPECT_GT(ops_executed.count(preceding_node), 0);
+ }
+ }
+ more_nodes = scheduler_->MarkCurrNodeExecuted(zero_costs);
+
+ if (node_info.name == target_node) {
+ // Scheduler has the state after executing the target node.
+ break;
+ }
+ } while (more_nodes);
+ return ops_executed;
+ }
+
+ // Helper method for validating a vector.
+ template <typename T>
+ void ExpectVectorEq(const std::vector<T>& expected,
+ const std::vector<T>& test_elements) {
+ // Set of expected elements for an easy comparison.
+ std::set<T> expected_set(expected.begin(), expected.end());
+ for (const auto& element : test_elements) {
+ EXPECT_GT(expected_set.count(element), 0);
+ }
+ EXPECT_EQ(expected.size(), test_elements.size());
+ }
+
+ // Helper method that checks the name of nodes.
+ void ValidateNodeDefs(const std::vector<string>& expected,
+ const std::vector<const NodeDef*>& node_defs) {
+ std::vector<string> node_names;
+ std::transform(node_defs.begin(), node_defs.end(),
+ std::back_inserter(node_names),
+ [](const NodeDef* node) { return node->name(); });
+ ExpectVectorEq(expected, node_names);
+ }
+
+ // Helper method for validating a set.
+ template <typename T>
+ void ExpectSetEq(const std::set<T>& expected,
+ const std::set<T>& test_elements) {
+ for (const auto& element : test_elements) {
+ EXPECT_GT(expected.count(element), 0);
+ }
+ EXPECT_EQ(expected.size(), test_elements.size());
+ }
+
+ // Helper method tthat checks name - port pairs.
+ void ValidateMemoryUsageSnapshot(
+ const std::vector<string>& expected_names, const int port_num_expected,
+ const std::set<std::pair<const NodeDef*, int>>& mem_usage_snapshot) {
+ std::set<std::pair<string, int>> nodes_at_peak_mem_usage;
+ std::transform(
+ mem_usage_snapshot.begin(), mem_usage_snapshot.end(),
+ std::inserter(nodes_at_peak_mem_usage, nodes_at_peak_mem_usage.begin()),
+ [](const std::pair<const NodeDef*, int>& node_port) {
+ return std::make_pair(node_port.first->name(), node_port.second);
+ });
+ std::set<std::pair<string, int>> expected;
+ std::transform(expected_names.begin(), expected_names.end(),
+ std::inserter(expected, expected.begin()),
+ [port_num_expected](const string& name) {
+ return std::make_pair(name, port_num_expected);
+ });
+ ExpectSetEq(expected, nodes_at_peak_mem_usage);
+ }
+
+ // Helper method for converting shape vector to TensorProperty.
+ OpInfo::TensorProperties ShapeToTensorProperty(
+ const std::vector<int> shape, const DataType& data_type) const {
+ OpInfo::TensorProperties tensor_property;
+ tensor_property.set_dtype(data_type);
+ for (const auto& x : shape) {
+ tensor_property.mutable_shape()->add_dim()->set_size(x);
+ }
+ return tensor_property;
+ }
+
// SetUp() inits cluster_ and placer_.
std::unique_ptr<VirtualCluster> cluster_;
std::unique_ptr<VirtualPlacer> placer_;
// grappler_item_ and scheduler_ will be initialized differently for each test
- // case
+ // case.
std::unique_ptr<GrapplerItem> grappler_item_;
- std::unique_ptr<VirtualScheduler> scheduler_;
+ std::unique_ptr<TestVirtualScheduler> scheduler_;
+ // Node name -> its preceding nodes map for testing scheduling order.
+ std::unordered_map<string, std::vector<string>> dependency_;
+
+ // Shared params for Conv2D related graphs:
+ const int batch_size_ = 4;
+ const int width_ = 10;
+ const int height_ = 10;
+ const int depth_in_ = 8;
+ const int kernel_ = 3;
+ const int depth_out_ = 16;
};
TEST_F(VirtualSchedulerTest, InitAndBasicScheduling) {
- CreateSchedulerWithConv2Ds(); // init scheduler_.
-
- Costs zero_costs = Costs::ZeroCosts();
- std::unordered_map<string, NodeInfo> ops_executed;
- do {
- NodeInfo node_info = scheduler_->GetCurrNodeInfo();
- ops_executed[node_info.name] = node_info;
-
- // Check scheduling order: x and f before c0, and y and f before c1.
- if (node_info.name == "c0") {
- EXPECT_GT(ops_executed.count("x"), 0);
- EXPECT_GT(ops_executed.count("f"), 0);
- } else if (node_info.name == "c1") {
- EXPECT_GT(ops_executed.count("y"), 0);
- EXPECT_GT(ops_executed.count("f"), 0);
- }
- } while (scheduler_->MarkCurrNodeExecuted(zero_costs));
+ // Init.
+ CreateGrapplerItemWithConv2Ds();
+ InitScheduler();
+
+ // Run the scheduler.
+ auto ops_executed = RunScheduler(""); // Run all the nodes.
// [const and rand] * (x, y, f), and c0 and c1. c2 and z shouldn't be
// executed.
@@ -132,5 +345,162 @@ TEST_F(VirtualSchedulerTest, InitAndBasicScheduling) {
EXPECT_EQ(2, ops_executed["c0"].op_info.inputs_size());
EXPECT_EQ(2, ops_executed["c1"].op_info.inputs_size());
}
+
+TEST_F(VirtualSchedulerTest, CalculateOutputSize) {
+ // Init.
+ CreateGrapplerItemWithAddN();
+ InitScheduler();
+
+ // Create a set of tensor properties.
+ std::vector<OpInfo::TensorProperties> output;
+ output.push_back(ShapeToTensorProperty({4, 4}, DT_FLOAT)); // 0
+ output.push_back(ShapeToTensorProperty({1}, DT_FLOAT)); // 1
+ output.push_back(ShapeToTensorProperty({10, 10, 10}, DT_HALF)); // 2
+ output.push_back(ShapeToTensorProperty({100, 7, 8, 99}, DT_FLOAT)); // 3
+ output.push_back(ShapeToTensorProperty({-1, 7, 8, 99}, DT_FLOAT)); // 4
+ output.push_back(ShapeToTensorProperty({-1, 7, -1, 99}, DT_FLOAT)); // 4
+
+ // port_num -1 is for control dependency: hard coded 4B.
+ EXPECT_EQ(4, scheduler_->CalculateOutputSize(output, -1));
+
+ // Test valid outputs.
+ EXPECT_EQ(4 * 4 * 4, scheduler_->CalculateOutputSize(output, 0));
+ EXPECT_EQ(4 * 1, scheduler_->CalculateOutputSize(output, 1));
+ EXPECT_EQ(2 * 10 * 10 * 10, scheduler_->CalculateOutputSize(output, 2));
+ EXPECT_EQ(4 * 100 * 7 * 8 * 99, scheduler_->CalculateOutputSize(output, 3));
+
+ // Any uknown shape (-1) shall yield zero output size.
+ EXPECT_EQ(0, scheduler_->CalculateOutputSize(output, 4));
+ EXPECT_EQ(0, scheduler_->CalculateOutputSize(output, 5));
+
+ // Invalid port_num (though it may be an error) shall yield zero
+ // output size.
+ EXPECT_EQ(0, scheduler_->CalculateOutputSize(output, 6));
+}
+
+TEST_F(VirtualSchedulerTest, MemoryUsage) {
+ // Init.
+ CreateGrapplerItemWithAddN();
+ InitScheduler();
+
+ // Run the scheduler.
+ RunScheduler("");
+
+ const auto& device_states = scheduler_->GetDeviceStates();
+ const auto& cpu_state = device_states.at(kCPU0);
+
+ // out node adds 4 tensors, each with 10x10x10x10, so the peak memory usage
+ // is 4 x the input tensor size while executing the out node.
+ int64 one_input_node_size = 4 * 10 * 10 * 10 * 10;
+ const std::vector<string> expected_names = {"x", "y", "z", "w"};
+ EXPECT_EQ(expected_names.size() * one_input_node_size,
+ cpu_state.max_memory_usage);
+ ValidateMemoryUsageSnapshot(expected_names, 0 /* port_num_expected */,
+ cpu_state.mem_usage_snapshot_at_peak);
+}
+
+TEST_F(VirtualSchedulerTest, ControlDependency) {
+ // Init.
+ CreateGrapplerItemWithControlDependency();
+ InitScheduler();
+
+ // Run the scheduler.
+ RunScheduler("");
+
+ const auto& device_states = scheduler_->GetDeviceStates();
+ const auto& cpu_state = device_states.at(kCPU0);
+
+ // The graph has a NoOp that takes control dependency from 7 NoOps. The peak
+ // memory usage is when executing the final NoOp.
+ int64 one_input_node_size = 4; // control dependency
+ const std::vector<string> expected_names = {"x", "y", "z", "w",
+ "u", "v", "t"};
+ EXPECT_EQ(expected_names.size() * one_input_node_size,
+ cpu_state.max_memory_usage);
+ ValidateMemoryUsageSnapshot(expected_names, -1 /* port_num_expected */,
+ cpu_state.mem_usage_snapshot_at_peak);
+}
+
+TEST_F(VirtualSchedulerTest, ComplexDependency) {
+ // Init.
+ CreateGrapplerItemWithBatchNorm();
+ InitScheduler();
+
+ // Run the scheduler.
+ RunScheduler("bn");
+
+ const auto& device_states = scheduler_->GetDeviceStates();
+ const auto& cpu_state = device_states.at(kCPU0);
+
+ // The graph is
+ // bn = FusedBatchNorm(x, scale, offset, mean, var)
+ // z1 = bn.y + x
+ // z2 = bn.var + bn.var
+ // z3 = bn.var + bn.var
+ // z4 = control dependency from bn.
+ // Note that bn.mean doesn't have any consumer.
+ const int x_size = batch_size_ * width_ * height_ * depth_in_;
+ int64 expected_size =
+ 4 * (2 * x_size /* x and bn.y */ + depth_in_ /* bn.var */ +
+ 1 /* control dependency */);
+ EXPECT_EQ(expected_size, cpu_state.memory_usage);
+
+ // Nodes currrently in memory: bn's port -1, 0, and 2, and x's port 0.
+ std::set<std::pair<string, int>> nodes_in_memory;
+ std::transform(
+ cpu_state.nodes_in_memory.begin(), cpu_state.nodes_in_memory.end(),
+ std::inserter(nodes_in_memory, nodes_in_memory.begin()),
+ [](const std::pair<const NodeDef*, int>& node_port) {
+ return std::make_pair(node_port.first->name(), node_port.second);
+ });
+ std::set<std::pair<string, int>> expected = {
+ std::make_pair("bn", -1), std::make_pair("bn", 0),
+ std::make_pair("bn", 2), std::make_pair("x", 0),
+ };
+ ExpectSetEq(expected, nodes_in_memory);
+
+ const auto& node_states = scheduler_->GetNodeStates();
+ const NodeState* bn_node = nullptr;
+ const NodeState* x_node = nullptr;
+ for (const auto& nodedef_node_state : node_states) {
+ const NodeDef* node = nodedef_node_state.first;
+ const NodeState& node_state = nodedef_node_state.second;
+ if (node->name() == "bn") {
+ bn_node = &node_state;
+ }
+ if (node->name() == "x") {
+ x_node = &node_state;
+ }
+ }
+ CHECK_NOTNULL(bn_node);
+ CHECK_NOTNULL(x_node);
+
+ ValidateNodeDefs({"bn", "z1"}, x_node->outputs.at(0));
+ ValidateNodeDefs({"z4"}, bn_node->outputs.at(-1));
+ ValidateNodeDefs({"z1"}, bn_node->outputs.at(0));
+ // z2 and z3 are bn.var + bn.var, so they appear twice in bn's output port 2.
+ ValidateNodeDefs({"z2", "z3", "z2", "z3"}, bn_node->outputs.at(2));
+}
+
+TEST_F(VirtualSchedulerTest, Variable) {
+ // Init.
+ CreateGrapplerItemWithConv2DAndVariable();
+ InitScheduler();
+
+ // Run the scheduler.
+ RunScheduler("");
+
+ const auto& device_states = scheduler_->GetDeviceStates();
+ const auto& cpu_state = device_states.at(kCPU0);
+
+ // There is one Conv2D that takes x and f, but f is variable, so it should be
+ // in persistent nodes.
+ // f is variable.
+ ValidateMemoryUsageSnapshot({"f"}, 0 /* port_num_expected */,
+ cpu_state.persistent_nodes);
+ // Only x in peak memory usage snapshot.
+ ValidateMemoryUsageSnapshot({"x"}, 0 /* port_num_expected */,
+ cpu_state.mem_usage_snapshot_at_peak);
+}
} // end namespace grappler
} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc
index 6f5310f501..51146011b0 100644
--- a/tensorflow/core/grappler/op_types.cc
+++ b/tensorflow/core/grappler/op_types.cc
@@ -45,18 +45,33 @@ bool IsMerge(const NodeDef& node) {
return op == "Merge";
}
+bool IsNoOp(const NodeDef& node) {
+ const auto op = node.op();
+ return op == "NoOp";
+}
+
bool IsPlaceholder(const NodeDef& node) {
const auto op = node.op();
return op == "Placeholder" || op == "PlaceholderV2" ||
op == "PlaceholderWithDefault";
}
+bool IsRecv(const NodeDef& node) {
+ const auto op = node.op();
+ return op == "_Recv";
+}
+
bool IsReduction(const NodeDef& node) {
const auto& op = node.op();
return op == "Sum" || op == "Prod" || op == "Min" || op == "Max" ||
op == "Mean" || op == "Any" || op == "All";
}
+bool IsSend(const NodeDef& node) {
+ const auto op = node.op();
+ return op == "_Send";
+}
+
bool IsSwitch(const NodeDef& node) {
const auto& op = node.op();
return op == "Switch";
diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h
index 483ef5c577..b2102c688d 100644
--- a/tensorflow/core/grappler/op_types.h
+++ b/tensorflow/core/grappler/op_types.h
@@ -26,8 +26,11 @@ bool IsConstant(const NodeDef& node);
bool IsDequeueOp(const NodeDef& node);
bool IsIdentity(const NodeDef& node);
bool IsMerge(const NodeDef& node);
+bool IsNoOp(const NodeDef& node);
bool IsPlaceholder(const NodeDef& node);
+bool IsRecv(const NodeDef& node);
bool IsReduction(const NodeDef& node);
+bool IsSend(const NodeDef& node);
bool IsSwitch(const NodeDef& node);
bool IsTranspose(const NodeDef& node);
bool IsVariable(const NodeDef& node);