 tensorflow/c/c_api_experimental.cc | 204
 tensorflow/c/c_api_experimental.h  |  13
 tensorflow/c/c_test_util.cc        |  10
 tensorflow/c/c_test_util.h         |   3
 4 files changed, 172 insertions(+), 58 deletions(-)
diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc
index f6d8949bb0..eb17e16d3e 100644
--- a/tensorflow/c/c_api_experimental.cc
+++ b/tensorflow/c/c_api_experimental.cc
@@ -26,6 +26,7 @@
 using tensorflow::Node;
 using tensorflow::NodeBuilder;
 using tensorflow::NodeDef;
 using tensorflow::Status;
+using tensorflow::string;
 
 namespace {
@@ -38,12 +39,28 @@ TF_Operation* ToTF_Operation(Node* node) {
 // Graph rewrite algorithm (modeled after the python TPU graph rewrite path):
 //
-// 1. For each input node I, feed it to a new TPUReplicatedInput node, which in
-// turn feeds a new Identity node N, and store the mapping I->N.
+// 1. For each input node I, with C being the consumer node of I's output:
 //
-// 2. Rewrite all existing graph nodes by adding a attribute on TPU cluster. For
-// each node reading some input node I, rewire it to read from N instead based
-// on the I->N mapping in step #1.
+// a) When infeed is not specified, feed I to a new TPUReplicatedInput node
+// (both running on CPU), which in turn feeds a new Identity node N, and N feeds
+// C (both running on TPU).
+//
+// b) Otherwise, feed I to a new InfeedEnqueueTuple node IE, both running on
+// CPU. Also set an InfeedDequeueTuple node ID to feed C, both running on
+// TPU.
+//
+// In case b), if we have multiple input nodes, they all feed into the same
+// InfeedEnqueueTuple node, so that the graph has a single pair of infeed
+// enqueue and dequeue nodes. The list of output tensors from the dequeue node
+// can go to different consumer nodes. For example, say the original graph has
+// input nodes I1 and I2 respectively feeding nodes C1 and C2. After the rewrite
+// with infeed ops, we will have: I1 and I2 feed a single infeed enqueue node
+// IE, and a corresponding infeed dequeue node ID produces a list of two
+// tensors, respectively feeding C1 and C2.
+//
+// 2. Rewrite all existing graph nodes by adding an attribute on TPU
+// cluster. For each node C reading some input node I, rewire it to read from a
+// new input node generated in step #1 above.
 //
 // 3. For each output node O, feed it to a new Identity node, which in turn
 // feeds a new TPUReplicatedOutput node, which in turn feeds a new Identity node
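Note: to make case b) above concrete, here is a minimal standalone sketch, not part of the patch, of the enqueue/dequeue pair the rewrite produces for two scalar int32 inputs. It reuses the same NodeBuilder calls as the patch; the node names, the cluster name, and the input signature are illustrative assumptions, and the control input from the TPUReplicateMetadata node is omitted for brevity.

    #include <vector>

    #include "tensorflow/core/framework/tensor_shape.h"
    #include "tensorflow/core/framework/types.pb.h"
    #include "tensorflow/core/graph/graph.h"
    #include "tensorflow/core/graph/node_builder.h"
    #include "tensorflow/core/lib/core/errors.h"

    // Wires inputs i1 and i2 (on CPU) into a single InfeedEnqueueTuple node and
    // creates the matching InfeedDequeueTuple node on the TPU cluster; output
    // ports 0 and 1 of the dequeue node then stand in for i1 and i2.
    tensorflow::Status BuildInfeedPairSketch(tensorflow::Graph* graph,
                                             tensorflow::Node* i1,
                                             tensorflow::Node* i2,
                                             tensorflow::Node** dequeue_node) {
      using tensorflow::NodeBuilder;
      std::vector<NodeBuilder::NodeOut> inputs = {NodeBuilder::NodeOut(i1, 0),
                                                  NodeBuilder::NodeOut(i2, 0)};
      std::vector<tensorflow::DataType> dtypes = {tensorflow::DT_INT32,
                                                  tensorflow::DT_INT32};
      // Two scalar (dim-0) shapes, matching the constraint noted in the patch.
      std::vector<tensorflow::TensorShape> shapes(2);
      tensorflow::Node* enqueue_node;
      TF_RETURN_IF_ERROR(NodeBuilder("InfeedEnqueueTuple", "InfeedEnqueueTuple")
                             .Input(inputs)
                             .Device("/device:CPU:0")
                             .Attr("device_ordinal", 0)
                             .Attr("dtypes", dtypes)
                             .Attr("shapes", shapes)
                             .Finalize(graph, &enqueue_node));
      return NodeBuilder("TPUReplicate/InfeedDequeueTuple", "InfeedDequeueTuple")
          .Attr("_tpu_replicate", "cluster")
          .Attr("dtypes", dtypes)
          .Attr("shapes", shapes)
          .Finalize(graph, dequeue_node);
    }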
@@ -66,7 +83,8 @@ class GraphRewriter {
     for (int i = 0; i < num_input_nodes; ++i) {
       // Will fill in the value part later when we create the associated new
       // input node.
-      input_node_map_[input_nodes[i].oper->node.name()] = nullptr;
+      input_node_map_[input_nodes[i].oper->node.name()] =
+          NodeBuilder::NodeOut(nullptr, -1);
     }
 
     // Grab all existing nodes for the upcoming rewrite, before mutating the
@@ -84,19 +102,24 @@ class GraphRewriter {
   // On success, sets `config_op` and `shutdown_op` to the corresponding
   // "ConfigureDistributedTPU" and "ShutdownDistributedTPU" nodes added to the
   // graph.
-  tensorflow::Status Rewrite(TF_Output* new_output_nodes, TF_Output* config_op,
-                             TF_Output* shutdown_op)
+  tensorflow::Status Rewrite(TF_Output* new_output_nodes,
+                             TF_Operation** infeed_enqueue_node,
+                             TF_Output* config_op, TF_Output* shutdown_op)
       EXCLUSIVE_LOCKS_REQUIRED(graph_->mu) {
-    TF_RETURN_IF_ERROR(ProcessInputNodes());
+    TF_RETURN_IF_ERROR(ProcessInputNodes(infeed_enqueue_node));
     return RewriteGraphAndAddOutputNodes(new_output_nodes, config_op,
                                          shutdown_op);
   }
 
  private:
-  // Synthensizes new nodes for the input nodes, and creates a replicated
-  // metadata node.
-  tensorflow::Status ProcessInputNodes() EXCLUSIVE_LOCKS_REQUIRED(graph_->mu) {
+  // Synthesizes new graph nodes (infeed enqueue or TPU replicated input
+  // nodes) for the input nodes, and creates a replicated metadata node.
+  //
+  // When `infeed_enqueue_node` is non-NULL and there are some input nodes,
+  // also adds the infeed dequeue node.
+  tensorflow::Status ProcessInputNodes(TF_Operation** infeed_enqueue_node)
+      EXCLUSIVE_LOCKS_REQUIRED(graph_->mu) {
     Node* metadata_node;
     TF_RETURN_IF_ERROR(
         NodeBuilder(metadata_node_name_.c_str(), "TPUReplicateMetadata")
@@ -104,34 +127,85 @@ class GraphRewriter {
             .Attr("_tpu_replicate", cluster_name_.c_str())
             .Finalize(&graph_->graph, &metadata_node));
 
-    for (int i = 0; i < input_node_map_.size(); ++i) {
-      VLOG(1) << "Handling input node " << input_nodes_[i].oper->node.name();
-      Node* replicated_input_node;
-      {
-        std::string replicated_input_name("TPUReplicate/input" +
-                                          std::to_string(i));
-        NodeBuilder::NodeOut input(&input_nodes_[i].oper->node,
-                                   input_nodes_[i].index);
-        std::vector<NodeBuilder::NodeOut> input_list;
-        input_list.push_back(input);
+    Node* dequeue_node = nullptr;
+    // Be deterministic in the corner case where `use_infeed` below is false.
+    if (infeed_enqueue_node) *infeed_enqueue_node = nullptr;
+    const bool use_infeed =
+        infeed_enqueue_node != nullptr && !input_node_map_.empty();
+    if (use_infeed) {
+      std::vector<NodeBuilder::NodeOut> new_input_list;
+      new_input_list.reserve(input_node_map_.size());
+      std::vector<tensorflow::DataType> input_dtypes;
+      input_dtypes.reserve(input_node_map_.size());
+      std::vector<tensorflow::TensorShape> input_shapes;
+      input_shapes.reserve(input_node_map_.size());
+      for (int i = 0; i < input_node_map_.size(); ++i) {
+        Node& input_node = input_nodes_[i].oper->node;
+        new_input_list.push_back(
+            NodeBuilder::NodeOut(&input_node, input_nodes_[i].index));
+        input_dtypes.push_back(input_node.output_type(input_nodes_[i].index));
+        tensorflow::TensorShapeProto shape;
         TF_RETURN_IF_ERROR(
-            NodeBuilder(replicated_input_name.c_str(), "TPUReplicatedInput")
-                // This op requires an input list.
-                .Input(input_list)
-                .Finalize(&graph_->graph, &replicated_input_node));
+            tensorflow::GetNodeAttr(input_node.attrs(), "shape", &shape));
+        VLOG(1) << "Input node " << i << " has shape " << shape.DebugString();
+        input_shapes.push_back(shape);
       }
+      // Enqueue always runs on CPU.
+      Node* enqueue_node;
+      TF_RETURN_IF_ERROR(NodeBuilder("InfeedEnqueueTuple", "InfeedEnqueueTuple")
+                             .Input(new_input_list)
+                             .Device("/device:CPU:0")
+                             .Attr("device_ordinal", 0)
+                             .Attr("dtypes", input_dtypes)
+                             .Attr("shapes", input_shapes)
+                             .Finalize(&graph_->graph, &enqueue_node));
+      *infeed_enqueue_node = ToTF_Operation(enqueue_node);
+      // The dequeue node should be put onto the "_tpu_replicate" cluster.
+      TF_RETURN_IF_ERROR(
+          NodeBuilder("TPUReplicate/InfeedDequeueTuple", "InfeedDequeueTuple")
+              .ControlInput(metadata_node)
+              .Attr("_tpu_replicate", cluster_name_.c_str())
+              .Attr("dtypes", input_dtypes)
+              .Attr("shapes", input_shapes)
+              .Finalize(&graph_->graph, &dequeue_node));
+    }
+
-
-      {
-        Node* new_input_node;
-        const std::string new_input_name("TPUReplicate/replicated_input_" +
-                                         std::to_string(i));
-        TF_RETURN_IF_ERROR(NodeBuilder(new_input_name.c_str(), "Identity")
-                               .Input(replicated_input_node, 0)
-                               .ControlInput(metadata_node)
-                               .Attr("_tpu_replicate", cluster_name_.c_str())
-                               .Finalize(&graph_->graph, &new_input_node));
-        DCHECK_GT(input_node_map_.count(input_nodes_[i].oper->node.name()), 0);
-        input_node_map_[input_nodes_[i].oper->node.name()] = new_input_node;
+    for (int i = 0; i < input_node_map_.size(); ++i) {
+      VLOG(1) << "Handling input node " << input_nodes_[i].oper->node.name();
+      if (use_infeed) {
+        DCHECK(dequeue_node);
+        input_node_map_[input_nodes_[i].oper->node.name()] =
+            NodeBuilder::NodeOut(dequeue_node, i);
+      } else {
+        Node* replicated_input_node;
+        {
+          std::string replicated_input_name("TPUReplicate/input" +
+                                            std::to_string(i));
+          NodeBuilder::NodeOut input(&input_nodes_[i].oper->node,
+                                     input_nodes_[i].index);
+          std::vector<NodeBuilder::NodeOut> input_list;
+          input_list.push_back(input);
+          TF_RETURN_IF_ERROR(
+              NodeBuilder(replicated_input_name.c_str(), "TPUReplicatedInput")
+                  // This op requires an input list.
+                  .Input(input_list)
+                  .Finalize(&graph_->graph, &replicated_input_node));
+        }
+
+        {
+          Node* new_input_node;
+          const std::string new_input_name("TPUReplicate/replicated_input_" +
+                                           std::to_string(i));
+          TF_RETURN_IF_ERROR(NodeBuilder(new_input_name.c_str(), "Identity")
+                                 .Input(replicated_input_node, 0)
+                                 .ControlInput(metadata_node)
+                                 .Attr("_tpu_replicate", cluster_name_.c_str())
+                                 .Finalize(&graph_->graph, &new_input_node));
+          DCHECK_GT(input_node_map_.count(input_nodes_[i].oper->node.name()),
+                    0);
+          input_node_map_[input_nodes_[i].oper->node.name()] =
+              NodeBuilder::NodeOut(new_input_node, 0);
+        }
+      }
     }
     return Status::OK();
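Note: the NodeOut value type of input_node_map_ is what lets both rewrite flavors share one code path. A hypothetical illustration (names invented, not from the patch) of the map contents after ProcessInputNodes:

    #include <string>
    #include <unordered_map>

    #include "tensorflow/core/graph/node_builder.h"

    // With infeed, both entries point at the single dequeue node on
    // consecutive output ports; without infeed, each entry is port 0 of its
    // own Identity node.
    void MapStateSketch(tensorflow::Node* dequeue_node, tensorflow::Node* id_x,
                        tensorflow::Node* id_y) {
      std::unordered_map<std::string, tensorflow::NodeBuilder::NodeOut> m;
      m["x"] = tensorflow::NodeBuilder::NodeOut(dequeue_node, 0);  // infeed case
      m["y"] = tensorflow::NodeBuilder::NodeOut(dequeue_node, 1);
      m["x"] = tensorflow::NodeBuilder::NodeOut(id_x, 0);  // non-infeed case
      m["y"] = tensorflow::NodeBuilder::NodeOut(id_y, 0);
    }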
@@ -163,7 +237,9 @@ class GraphRewriter {
     }
 
     const NodeDef& old_def = n->def();
-    Node* new_node;
+    // Let node C be the consumer of `n`'s output in the original graph.
+    // This new node will feed into C in the rewritten graph.
+    NodeBuilder::NodeOut new_node;
     if (input_node_map_.count(n->name())) {
       new_node = input_node_map_[n->name()];
     } else {
@@ -173,10 +249,19 @@
       new_def.set_name(new_node_name);
       new_def.clear_input();
       for (int i = 0; i < old_def.input_size(); ++i) {
-        const std::string& old_input_name = old_def.input(i);
-        const std::string new_input_name =
+        const string old_input_name = old_def.input(i);
+        // When there are multiple input nodes that get mapped to the same
+        // infeed dequeue node, use different output ports of the dequeue
+        // node. e.g. Say in the original graph, input I1 feeds C1, and I2
+        // feeds C2. After the rewrite, I1 and I2 both feed a new infeed
+        // enqueue node, and the corresponding dequeue node has its output
+        // port 0 feeding C1, and output port 1 feeding C2. Note C1 and C2
+        // could be the same node (e.g. an Add that takes 2 inputs).
+        const string new_input_name =
             input_node_map_.count(old_input_name) > 0
-                ? std::string(input_node_map_[old_input_name]->name())
+                ? tensorflow::strings::StrCat(
+                      input_node_map_[old_input_name].node->name(), ":",
+                      input_node_map_[old_input_name].index)
                 : "TPUReplicate/" + old_input_name;
         new_def.add_input(new_input_name);
       }
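Note: the StrCat call above emits NodeDef input strings of the form "<node>:<port>". A small illustrative sketch (names assumed, not from the patch) of what a rewritten consumer's inputs end up as:

    #include "tensorflow/core/framework/node_def.pb.h"

    // Two original inputs mapped to one dequeue node become port-qualified
    // input strings on the rewritten consumer. (A bare node name is shorthand
    // for port 0 in NodeDef input syntax.)
    void AddRewrittenInputsSketch(tensorflow::NodeDef* new_def) {
      new_def->add_input("TPUReplicate/InfeedDequeueTuple:0");  // was input I1
      new_def->add_input("TPUReplicate/InfeedDequeueTuple:1");  // was input I2
    }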
@@ -192,11 +277,12 @@ class GraphRewriter {
       }
 
       tensorflow::AddNodeAttr("_tpu_replicate", cluster_name_.c_str(),
                               &new_def);
-      new_node = graph_->graph.AddNode(new_def, &s);
+      new_node = NodeBuilder::NodeOut(graph_->graph.AddNode(new_def, &s), 0);
       if (!s.ok()) {
         return s;
       }
-      VLOG(1) << "The rewritten node node is " << new_node->DebugString();
+      VLOG(1) << "The rewritten node is "
+              << new_node.node->DebugString();
     }
 
     if (output_node_map_.count(n->name()) > 0) {
@@ -206,7 +292,17 @@
         const PortIndexPair& pair = it->second;
         Node* out_identity_node;
         {
-          VLOG(1) << "Handling its output port " << pair.port
+          // If this output node is also an input, use the input_node_map_'s
+          // stored port, which would also work for an infeed dequeue op.
+          // Otherwise use pair.port.
+          // An example of the former: Say the graph has input nodes I1 and
+          // I2, and the output nodes are also I1 and I2. In the rewritten
+          // graph with infeed, the 2 output nodes will both come from a
+          // single infeed dequeue node ID, with output ports respectively
+          // set to 0 and 1.
+          const int output_port =
+              input_node_map_.count(n->name()) ? new_node.index : pair.port;
+          VLOG(1) << "Handling its output port " << output_port
                   << " at output index " << pair.index;
           std::string output_node_name = "TPUReplicate/Identity";
           if (pair.index > 0) {
@@ -214,7 +310,7 @@
           }
           TF_RETURN_IF_ERROR(
               NodeBuilder(output_node_name.c_str(), "Identity")
-                  .Input(new_node, pair.port)
+                  .Input(new_node.node, output_port)
                   .Device(!old_def.device().empty()
                               ? old_def.device()
                               : tensorflow::strings::StrCat(
@@ -289,16 +385,18 @@
   // Keep mappings from the current input nodes to newly created input nodes,
   // which we will use to rewrite existing nodes that read these
   // inputs. e.g. A node that reads input node PlaceHolder could be rewired to
-  // read the created TPUReplicate/replicated_input_0 node.
-  std::unordered_map<std::string, Node*> input_node_map_;
+  // read the created TPUReplicate/replicated_input_0 node or some output port
+  // of the created TPUReplicate/InfeedDequeueTuple node. Because of the latter
+  // case, the map entries store NodeBuilder::NodeOut, and not just Node*.
+  std::unordered_map<std::string, NodeBuilder::NodeOut> input_node_map_;
 
   std::vector<Node*> nodes_to_rewrite_;
 
   // Map from name to set{(output port, output tensor idx)}.
-  // e.g. Say ther are 3 output tensors, respectively produced by (node 0,
+  // e.g. Say there are 3 output tensors, respectively produced by (node 0,
   // port 0), (node 0, port 1), (node 1, port 0). Then the mapping entries
   // are: node 0 -> {(port 0, idx 0), (port 1, idx 1)} node 1 -> {(port 0, idx
-  // 2)} Based on these mappings, we will generated 3 new output nodes.
+  // 2)} Based on these mappings, we will generate 3 new output nodes.
   struct PortIndexPair {
     int port;
     int index;
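Note: a concrete instance of the mapping documented above, as a hedged sketch; the real member stores a set of pairs per node, a vector is used here only for brevity, and the names are invented:

    #include <string>
    #include <unordered_map>
    #include <vector>

    // Three output tensors produced by (node0, port 0), (node0, port 1), and
    // (node1, port 0) yield the entries below; each pair later becomes one
    // new output node.
    struct PortIndexPairSketch {
      int port;   // output port of the producing node
      int index;  // position in the caller's output tensor list
    };
    const std::unordered_map<std::string, std::vector<PortIndexPairSketch>>
        kExampleOutputNodeMap = {
            {"node0", {{/*port=*/0, /*index=*/0}, {/*port=*/1, /*index=*/1}}},
            {"node1", {{/*port=*/0, /*index=*/2}}},
    };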
@@ -331,7 +429,9 @@ TF_Output TF_SetupTPUExecution(TF_Session* session, int num_input_nodes,
                                const TF_Output* input_nodes,
                                int num_output_nodes,
                                const TF_Output* output_nodes,
-                               TF_Output* new_output_nodes, TF_Status* status) {
+                               TF_Output* new_output_nodes,
+                               TF_Operation** infeed_enqueue_node,
+                               TF_Status* status) {
   TF_Output config_op, shutdown_op;
   {
     auto graph = session->graph;
@@ -341,8 +441,8 @@
             << graph->graph.ToGraphDefDebug().DebugString();
     GraphRewriter rewriter(graph, num_input_nodes, input_nodes,
                            num_output_nodes, output_nodes);
-    status->status =
-        rewriter.Rewrite(new_output_nodes, &config_op, &shutdown_op);
+    status->status = rewriter.Rewrite(new_output_nodes, infeed_enqueue_node,
+                                      &config_op, &shutdown_op);
     if (!status->status.ok()) {
       return shutdown_op;
     }
diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h
index af65123131..2bad278d63 100644
--- a/tensorflow/c/c_api_experimental.h
+++ b/tensorflow/c/c_api_experimental.h
@@ -63,7 +63,15 @@ TF_CAPI_EXPORT extern void TF_EnableXLACompilation(TF_SessionOptions* options,
 // Sets up TPU execution, by rewriting the graph accordingly, and initializing
 // TPU system.
 //
-// On success, returns a shutdown node to be used in a subsequent
+// When `infeed_enqueue_node` is non-NULL and there are input tensors, rewrites
+// the graph by adding the relevant infeed enqueue/dequeue ops, and returns the
+// enqueue op in `infeed_enqueue_node` on success, so that the user can run
+// that node and feed input tensors. When there are no input tensors,
+// `infeed_enqueue_node` is ignored, and the user should not run that node.
+// TODO(hongm): In the infeed case, we currently only support input tensors of
+// dim 0 shape. Lift that constraint.
+//
+// On success, also returns a shutdown node to be used in a subsequent
 // TF_ShutdownTPUExecution(), and sets the new output nodes in
 // `new_output_nodes` for caller to fetch from. Must be called exactly once
 // before TF_SessionRun().
@@ -76,7 +84,8 @@ TF_CAPI_EXPORT extern void TF_EnableXLACompilation(TF_SessionOptions* options,
 TF_CAPI_EXPORT extern TF_Output TF_SetupTPUExecution(
     TF_Session* session, int num_input_nodes, const TF_Output* input_nodes,
     int num_output_nodes, const TF_Output* output_nodes,
-    TF_Output* new_output_nodes, TF_Status* status);
+    TF_Output* new_output_nodes, TF_Operation** infeed_enqueue_node,
+    TF_Status* status);
 
 // Shuts down TPU system. For any `session` where TF_SetupTPUExecution() has
 // been successfully called, this call must be made exactly once before the
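Note: a hedged end-to-end caller sketch for the extended C API, not from the patch. It assumes a `session` whose graph has a single scalar TF_FLOAT Placeholder `x` feeding an op with output `y`, uses the FloatTensor test helper visible in the next file, and elides all error checking:

    #include "tensorflow/c/c_api.h"
    #include "tensorflow/c/c_api_experimental.h"
    #include "tensorflow/c/c_test_util.h"

    void RunOnTpuSketch(TF_Session* session, TF_Operation* x, TF_Operation* y) {
      TF_Status* s = TF_NewStatus();
      TF_Output in = {x, 0}, out = {y, 0}, new_out;
      TF_Operation* enqueue = NULL;
      TF_Output shutdown_node = TF_SetupTPUExecution(session, 1, &in, 1, &out,
                                                     &new_out, &enqueue, s);

      // Feed the placeholder and run the returned infeed enqueue op as a
      // target.
      TF_Tensor* x_val = FloatTensor(1.0f);
      const TF_Operation* targets[] = {enqueue};
      TF_SessionRun(session, NULL, &in, &x_val, 1, NULL, NULL, 0, targets, 1,
                    NULL, s);

      // Then fetch from the rewritten output, which reads the infeed dequeue.
      TF_Tensor* y_val = NULL;
      TF_SessionRun(session, NULL, NULL, NULL, 0, &new_out, &y_val, 1, NULL, 0,
                    NULL, s);
      // `shutdown_node` is for a later TF_ShutdownTPUExecution() (not shown).
      (void)shutdown_node;
    }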
diff --git a/tensorflow/c/c_test_util.cc b/tensorflow/c/c_test_util.cc
index 22f77e7b87..f3b28c1708 100644
--- a/tensorflow/c/c_test_util.cc
+++ b/tensorflow/c/c_test_util.cc
@@ -94,18 +94,22 @@ TF_Tensor* FloatTensor(float v) {
 // one cannot call ASSERT_* methods in non-void-returning functions (when
 // exceptions are disabled during compilation)
 void PlaceholderHelper(TF_Graph* graph, TF_Status* s, const char* name,
-                       TF_DataType dtype, TF_Operation** op) {
+                       TF_DataType dtype, const std::vector<int64_t>& dims,
+                       TF_Operation** op) {
   TF_OperationDescription* desc = TF_NewOperation(graph, "Placeholder", name);
   TF_SetAttrType(desc, "dtype", dtype);
+  if (!dims.empty()) {
+    TF_SetAttrShape(desc, "shape", dims.data(), dims.size());
+  }
   *op = TF_FinishOperation(desc, s);
   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
   ASSERT_NE(*op, nullptr);
 }
 
 TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s, const char* name,
-                          TF_DataType dtype) {
+                          TF_DataType dtype, const std::vector<int64_t>& dims) {
   TF_Operation* op;
-  PlaceholderHelper(graph, s, name, dtype, &op);
+  PlaceholderHelper(graph, s, name, dtype, dims, &op);
   return op;
 }
 
diff --git a/tensorflow/c/c_test_util.h b/tensorflow/c/c_test_util.h
index d87c57fd51..cd19cf8d62 100644
--- a/tensorflow/c/c_test_util.h
+++ b/tensorflow/c/c_test_util.h
@@ -48,7 +48,8 @@ TF_Tensor* FloatTensor(float v);
 
 TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s,
                           const char* name = "feed",
-                          TF_DataType dtype = TF_INT32);
+                          TF_DataType dtype = TF_INT32,
+                          const std::vector<int64_t>& dims = {});
 
 TF_Operation* Const(TF_Tensor* t, TF_Graph* graph, TF_Status* s,
                     const char* name = "const");
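Note: a usage sketch for the extended test helper (the wrapper function name is hypothetical). A non-empty `dims` sets the Placeholder's "shape" attribute; the empty default leaves the shape unset, preserving the old behavior:

    #include "tensorflow/c/c_api.h"
    #include "tensorflow/c/c_test_util.h"

    void MakeFeedsSketch(TF_Graph* graph, TF_Status* s) {
      // Old behavior: no "shape" attr is set on the placeholder.
      TF_Operation* unshaped = Placeholder(graph, s, "x");
      // New behavior: a fully specified 2x3 shape, e.g. for an infeed
      // signature.
      TF_Operation* matrix = Placeholder(graph, s, "m", TF_FLOAT, {2, 3});
      (void)unshaped;
      (void)matrix;
    }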