aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c/c_api_experimental.cc
diff options
context:
space:
mode:
authorGravatar Mingsheng Hong <hongm@google.com>2018-03-19 22:29:33 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-19 22:33:26 -0700
commit1a8258e0593270a8e2370517dff8faafce40a687 (patch)
treee097a315ba6c21f8f2ac72c5af17564d73efbc4e /tensorflow/c/c_api_experimental.cc
parent407ddd1c0539cfc5d33ab2629230eab5a958b7d4 (diff)
Added infeed support for experimental C APIs associated with TPU graph rewrite.
This initial design of the C API is different from (and mostly higher level than) the python API counterparts for infeed, in that the python API has explicit graph construction APIs for generating infeed enqueue/dequeue ops (e.g. split_inputs_and_generate_enqueue_ops() and generate_dequeue_op()), while the C API takes an input graph and redirects all input nodes to feed the infeed enqueue. One requirement/restriction is that the input nodes in the TF graph (e.g. Placeholder) must specify their tensor shapes, for infeed enqueue and dequeue nodes to properly compile with XLA. The API for more general shape support will be designed and implemented later. PiperOrigin-RevId: 189693028
Diffstat (limited to 'tensorflow/c/c_api_experimental.cc')
-rw-r--r--tensorflow/c/c_api_experimental.cc204
1 files changed, 152 insertions, 52 deletions
diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc
index f6d8949bb0..eb17e16d3e 100644
--- a/tensorflow/c/c_api_experimental.cc
+++ b/tensorflow/c/c_api_experimental.cc
@@ -26,6 +26,7 @@ using tensorflow::Node;
using tensorflow::NodeBuilder;
using tensorflow::NodeDef;
using tensorflow::Status;
+using tensorflow::string;
namespace {
@@ -38,12 +39,28 @@ TF_Operation* ToTF_Operation(Node* node) {
// Graph rewrite algorithm (modeled after the python TPU graph rewrite path):
//
-// 1. For each input node I, feed it to a new TPUReplicatedInput node, which in
-// turn feeds a new Identity node N, and store the mapping I->N.
+// 1. For each input node I, with C being the consumer node of I's output:
//
-// 2. Rewrite all existing graph nodes by adding a attribute on TPU cluster. For
-// each node reading some input node I, rewire it to read from N instead based
-// on the I->N mapping in step #1.
+// a) When infeed is not specified, feed I to a new TPUReplicatedInput node
+// (both running on CPU), which in turn feeds a new Identity node N, and N feeds
+// C (both running on TPU).
+//
+// b) Otherwise, feed I to a new InfeedEnqueueTuple node IE, both running on
+// CPU. Also set an InfeedDequeueTuple node ID to feed C, both running on
+// TPU.
+//
+// In case b), if we have multiple input nodes, they all feed into the same
+// InfeedEnqueueTuple node, so that the graph has a single pair of infeed
+// enqueue and dequeue nodes. The list of output tensors from the dequeue node
+// can go to different consumer nodes. For example, say the original graph has
+// input nodes I1 and I2 respectively feeding nodes C1 and C2. After the rewrite
+// with infeed ops, we will have: I1 and I2 feed a single infeed enqueue node
+// IE, and a corresponding infeed dequeue node ID produces a list of two
+// tensors, respectively feeding C1 and C2.
+//
+// 2. Rewrite all existing graph nodes by adding an attribute on TPU
+// cluster. For each node C reading some input node I, rewire it to read from a
+// new input node generated in step #1 above.
//
// 3. For each output node O, feed it to a new Identity node, which in turn
// feeds a new TPUReplicatedOutput node, which in turn feeds a new Identity node
@@ -66,7 +83,8 @@ class GraphRewriter {
for (int i = 0; i < num_input_nodes; ++i) {
// Will fill in the value part later when we create the associated new
// input node.
- input_node_map_[input_nodes[i].oper->node.name()] = nullptr;
+ input_node_map_[input_nodes[i].oper->node.name()] =
+ NodeBuilder::NodeOut(nullptr, -1);
}
// Grab all existing nodes for the upcoming rewrite, before mutating the
@@ -84,19 +102,24 @@ class GraphRewriter {
// On success, sets `config_op` and `shutdown_op` to the corresponding
// "ConfigureDistributedTPU" and "ShutdownDistributedTPU" nodes added to the
// graph.
- tensorflow::Status Rewrite(TF_Output* new_output_nodes, TF_Output* config_op,
- TF_Output* shutdown_op)
+ tensorflow::Status Rewrite(TF_Output* new_output_nodes,
+ TF_Operation** infeed_enqueue_node,
+ TF_Output* config_op, TF_Output* shutdown_op)
EXCLUSIVE_LOCKS_REQUIRED(graph_->mu) {
- TF_RETURN_IF_ERROR(ProcessInputNodes());
+ TF_RETURN_IF_ERROR(ProcessInputNodes(infeed_enqueue_node));
return RewriteGraphAndAddOutputNodes(new_output_nodes, config_op,
shutdown_op);
}
private:
- // Synthensizes new nodes for the input nodes, and creates a replicated
- // metadata node.
- tensorflow::Status ProcessInputNodes() EXCLUSIVE_LOCKS_REQUIRED(graph_->mu) {
+ // Synthesizes new graph nodes (infeed enqueue or TPU replicated input
+ // nodes) for the input nodes, and creates a replicated metadata node.
+ //
+ // When `infeed_enqueue_node` is non-NULL and there are some input nodes,
+ // also adds the infeed dequeue node.
+ tensorflow::Status ProcessInputNodes(TF_Operation** infeed_enqueue_node)
+ EXCLUSIVE_LOCKS_REQUIRED(graph_->mu) {
Node* metadata_node;
TF_RETURN_IF_ERROR(
NodeBuilder(metadata_node_name_.c_str(), "TPUReplicateMetadata")
@@ -104,34 +127,85 @@ class GraphRewriter {
.Attr("_tpu_replicate", cluster_name_.c_str())
.Finalize(&graph_->graph, &metadata_node));
- for (int i = 0; i < input_node_map_.size(); ++i) {
- VLOG(1) << "Handling input node " << input_nodes_[i].oper->node.name();
- Node* replicated_input_node;
- {
- std::string replicated_input_name("TPUReplicate/input" +
- std::to_string(i));
- NodeBuilder::NodeOut input(&input_nodes_[i].oper->node,
- input_nodes_[i].index);
- std::vector<NodeBuilder::NodeOut> input_list;
- input_list.push_back(input);
+ Node* dequeue_node = nullptr;
+ // Be deterministic in the corner case where `use_infeed` below is false.
+ if (infeed_enqueue_node) *infeed_enqueue_node = nullptr;
+ const bool use_infeed =
+ infeed_enqueue_node != nullptr && !input_node_map_.empty();
+ if (use_infeed) {
+ std::vector<NodeBuilder::NodeOut> new_input_list;
+ new_input_list.reserve(input_node_map_.size());
+ std::vector<tensorflow::DataType> input_dtypes;
+ input_dtypes.reserve(input_node_map_.size());
+ std::vector<tensorflow::TensorShape> input_shapes;
+ input_shapes.reserve(input_node_map_.size());
+ for (int i = 0; i < input_node_map_.size(); ++i) {
+ Node& input_node = input_nodes_[i].oper->node;
+ new_input_list.push_back(
+ NodeBuilder::NodeOut(&input_node, input_nodes_[i].index));
+ input_dtypes.push_back(input_node.output_type(input_nodes_[i].index));
+ tensorflow::TensorShapeProto shape;
TF_RETURN_IF_ERROR(
- NodeBuilder(replicated_input_name.c_str(), "TPUReplicatedInput")
- // This op requires an input list.
- .Input(input_list)
- .Finalize(&graph_->graph, &replicated_input_node));
+ tensorflow::GetNodeAttr(input_node.attrs(), "shape", &shape));
+ VLOG(1) << "Input node " << i << " has shape " << shape.DebugString();
+ input_shapes.push_back(shape);
}
+ // Enqueue always runs on CPU.
+ Node* enqueue_node;
+ TF_RETURN_IF_ERROR(NodeBuilder("InfeedEnqueueTuple", "InfeedEnqueueTuple")
+ .Input(new_input_list)
+ .Device("/device:CPU:0")
+ .Attr("device_ordinal", 0)
+ .Attr("dtypes", input_dtypes)
+ .Attr("shapes", input_shapes)
+ .Finalize(&graph_->graph, &enqueue_node));
+ *infeed_enqueue_node = ToTF_Operation(enqueue_node);
+ // The dequeue node should be put onto the "_tpu_replicate" cluster.
+ TF_RETURN_IF_ERROR(
+ NodeBuilder("TPUReplicate/InfeedDequeueTuple", "InfeedDequeueTuple")
+ .ControlInput(metadata_node)
+ .Attr("_tpu_replicate", cluster_name_.c_str())
+ .Attr("dtypes", input_dtypes)
+ .Attr("shapes", input_shapes)
+ .Finalize(&graph_->graph, &dequeue_node));
+ }
- {
- Node* new_input_node;
- const std::string new_input_name("TPUReplicate/replicated_input_" +
- std::to_string(i));
- TF_RETURN_IF_ERROR(NodeBuilder(new_input_name.c_str(), "Identity")
- .Input(replicated_input_node, 0)
- .ControlInput(metadata_node)
- .Attr("_tpu_replicate", cluster_name_.c_str())
- .Finalize(&graph_->graph, &new_input_node));
- DCHECK_GT(input_node_map_.count(input_nodes_[i].oper->node.name()), 0);
- input_node_map_[input_nodes_[i].oper->node.name()] = new_input_node;
+ for (int i = 0; i < input_node_map_.size(); ++i) {
+ VLOG(1) << "Handling input node " << input_nodes_[i].oper->node.name();
+ if (use_infeed) {
+ DCHECK(dequeue_node);
+ input_node_map_[input_nodes_[i].oper->node.name()] =
+ NodeBuilder::NodeOut(dequeue_node, i);
+ } else {
+ Node* replicated_input_node;
+ {
+ std::string replicated_input_name("TPUReplicate/input" +
+ std::to_string(i));
+ NodeBuilder::NodeOut input(&input_nodes_[i].oper->node,
+ input_nodes_[i].index);
+ std::vector<NodeBuilder::NodeOut> input_list;
+ input_list.push_back(input);
+ TF_RETURN_IF_ERROR(
+ NodeBuilder(replicated_input_name.c_str(), "TPUReplicatedInput")
+ // This op requires an input list.
+ .Input(input_list)
+ .Finalize(&graph_->graph, &replicated_input_node));
+ }
+
+ {
+ Node* new_input_node;
+ const std::string new_input_name("TPUReplicate/replicated_input_" +
+ std::to_string(i));
+ TF_RETURN_IF_ERROR(NodeBuilder(new_input_name.c_str(), "Identity")
+ .Input(replicated_input_node, 0)
+ .ControlInput(metadata_node)
+ .Attr("_tpu_replicate", cluster_name_.c_str())
+ .Finalize(&graph_->graph, &new_input_node));
+ DCHECK_GT(input_node_map_.count(input_nodes_[i].oper->node.name()),
+ 0);
+ input_node_map_[input_nodes_[i].oper->node.name()] =
+ NodeBuilder::NodeOut(new_input_node, 0);
+ }
}
}
return Status::OK();
@@ -163,7 +237,9 @@ class GraphRewriter {
}
const NodeDef& old_def = n->def();
- Node* new_node;
+ // Let node C be the consumer of `n`'s output in the original graph.
+ // This new node will feed into C in the rewritten graph.
+ NodeBuilder::NodeOut new_node;
if (input_node_map_.count(n->name())) {
new_node = input_node_map_[n->name()];
} else {
@@ -173,10 +249,19 @@ class GraphRewriter {
new_def.set_name(new_node_name);
new_def.clear_input();
for (int i = 0; i < old_def.input_size(); ++i) {
- const std::string& old_input_name = old_def.input(i);
- const std::string new_input_name =
+ const string old_input_name = old_def.input(i);
+ // When there are multiple input nodes that get mapped to the same
+ // infeed dequeue node, use different output ports of the dequeue
+ // node. e.g. Say in the original graph, input I1 feeds C1, and I2
+ // feeds C2. After the rewrite, I1 and I2 both feed a new infeed
+ // enqueue node, and the corresponding dequeue node has its output
+ // port 0 feeding C1, and output port 1 feeding C2. Note C1 and C2
+ // could be the same node (e.g. an Add that takes 2 inputs).
+ const string new_input_name =
input_node_map_.count(old_input_name) > 0
- ? std::string(input_node_map_[old_input_name]->name())
+ ? tensorflow::strings::StrCat(
+ input_node_map_[old_input_name].node->name(), ":",
+ input_node_map_[old_input_name].index)
: "TPUReplicate/" + old_input_name;
new_def.add_input(new_input_name);
}
@@ -192,11 +277,12 @@ class GraphRewriter {
}
tensorflow::AddNodeAttr("_tpu_replicate", cluster_name_.c_str(),
&new_def);
- new_node = graph_->graph.AddNode(new_def, &s);
+ new_node = NodeBuilder::NodeOut(graph_->graph.AddNode(new_def, &s), 0);
if (!s.ok()) {
return s;
}
- VLOG(1) << "The rewritten node node is " << new_node->DebugString();
+ VLOG(1) << "The rewritten node node is "
+ << new_node.node->DebugString();
}
if (output_node_map_.count(n->name()) > 0) {
@@ -206,7 +292,17 @@ class GraphRewriter {
const PortIndexPair& pair = it->second;
Node* out_identity_node;
{
- VLOG(1) << "Handling its output port " << pair.port
+ // If this output node is also an input, use the input_node_map_'s
+ // stored port, which would also work for an infeed dequeue op.
+ // Otherwise use pair.port.
+ // An example of the former: Say the graph has input nodes I1 and
+ // I2, and the output nodes are also I1 and I2. In the rewritten
+ // graph with infeed, the 2 output nodes will both come from a
+ // single infeed dequeue node ID, with output ports respectively
+ // set to 0 and 1.
+ const int output_port =
+ input_node_map_.count(n->name()) ? new_node.index : pair.port;
+ VLOG(1) << "Handling its output port " << output_port
<< " at output index " << pair.index;
std::string output_node_name = "TPUReplicate/Identity";
if (pair.index > 0) {
@@ -214,7 +310,7 @@ class GraphRewriter {
}
TF_RETURN_IF_ERROR(
NodeBuilder(output_node_name.c_str(), "Identity")
- .Input(new_node, pair.port)
+ .Input(new_node.node, output_port)
.Device(!old_def.device().empty()
? old_def.device()
: tensorflow::strings::StrCat(
@@ -289,16 +385,18 @@ class GraphRewriter {
// Keep mappings from the current input nodes to newly created input nodes,
// which we will use to rewrite existing nodes that read these
// inputs. e.g. A node that reads input node PlaceHolder could be rewired to
- // read the created TPUReplicate/replicated_input_0 node.
- std::unordered_map<std::string, Node*> input_node_map_;
+ // read the created TPUReplicate/replicated_input_0 node or some output port
+ // of the created TPUReplicate/InfeedDequeueTuple node. Because of the latter
+ // case, we the map entries store NodeBuilder::NodeOut, and not just Node*.
+ std::unordered_map<std::string, NodeBuilder::NodeOut> input_node_map_;
std::vector<Node*> nodes_to_rewrite_;
// Map from name to set{(output port, output tensor idx)}.
- // e.g. Say ther are 3 output tensors, respectively produced by (node 0,
+ // e.g. Say there are 3 output tensors, respectively produced by (node 0,
// port 0), (node 0, port 1), (node 1, port 0). Then the mapping entries
// are: node 0 -> {(port 0, idx 0), (port 1, idx 1)} node 1 -> {(port 0, idx
- // 2)} Based on these mappings, we will generated 3 new output nodes.
+ // 2)} Based on these mappings, we will generate 3 new output nodes.
struct PortIndexPair {
int port;
int index;
@@ -331,7 +429,9 @@ TF_Output TF_SetupTPUExecution(TF_Session* session, int num_input_nodes,
const TF_Output* input_nodes,
int num_output_nodes,
const TF_Output* output_nodes,
- TF_Output* new_output_nodes, TF_Status* status) {
+ TF_Output* new_output_nodes,
+ TF_Operation** infeed_enqueue_node,
+ TF_Status* status) {
TF_Output config_op, shutdown_op;
{
auto graph = session->graph;
@@ -341,8 +441,8 @@ TF_Output TF_SetupTPUExecution(TF_Session* session, int num_input_nodes,
<< graph->graph.ToGraphDefDebug().DebugString();
GraphRewriter rewriter(graph, num_input_nodes, input_nodes,
num_output_nodes, output_nodes);
- status->status =
- rewriter.Rewrite(new_output_nodes, &config_op, &shutdown_op);
+ status->status = rewriter.Rewrite(new_output_nodes, infeed_enqueue_node,
+ &config_op, &shutdown_op);
if (!status->status.ok()) {
return shutdown_op;
}