aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/c/c_api_experimental.cc204
-rw-r--r--tensorflow/c/c_api_experimental.h13
-rw-r--r--tensorflow/c/c_test_util.cc10
-rw-r--r--tensorflow/c/c_test_util.h3
4 files changed, 172 insertions, 58 deletions
diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc
index f6d8949bb0..eb17e16d3e 100644
--- a/tensorflow/c/c_api_experimental.cc
+++ b/tensorflow/c/c_api_experimental.cc
@@ -26,6 +26,7 @@ using tensorflow::Node;
using tensorflow::NodeBuilder;
using tensorflow::NodeDef;
using tensorflow::Status;
+using tensorflow::string;
namespace {
@@ -38,12 +39,28 @@ TF_Operation* ToTF_Operation(Node* node) {
// Graph rewrite algorithm (modeled after the python TPU graph rewrite path):
//
-// 1. For each input node I, feed it to a new TPUReplicatedInput node, which in
-// turn feeds a new Identity node N, and store the mapping I->N.
+// 1. For each input node I, with C being the consumer node of I's output:
//
-// 2. Rewrite all existing graph nodes by adding a attribute on TPU cluster. For
-// each node reading some input node I, rewire it to read from N instead based
-// on the I->N mapping in step #1.
+// a) When infeed is not specified, feed I to a new TPUReplicatedInput node
+// (both running on CPU), which in turn feeds a new Identity node N, and N feeds
+// C (both running on TPU).
+//
+// b) Otherwise, feed I to a new InfeedEnqueueTuple node IE, both running on
+// CPU. Also set an InfeedDequeueTuple node ID to feed C, both running on
+// TPU.
+//
+// In case b), if we have multiple input nodes, they all feed into the same
+// InfeedEnqueueTuple node, so that the graph has a single pair of infeed
+// enqueue and dequeue nodes. The list of output tensors from the dequeue node
+// can go to different consumer nodes. For example, say the original graph has
+// input nodes I1 and I2 respectively feeding nodes C1 and C2. After the rewrite
+// with infeed ops, we will have: I1 and I2 feed a single infeed enqueue node
+// IE, and a corresponding infeed dequeue node ID produces a list of two
+// tensors, respectively feeding C1 and C2.
+//
+// 2. Rewrite all existing graph nodes by adding an attribute on TPU
+// cluster. For each node C reading some input node I, rewire it to read from a
+// new input node generated in step #1 above.
//
// 3. For each output node O, feed it to a new Identity node, which in turn
// feeds a new TPUReplicatedOutput node, which in turn feeds a new Identity node
@@ -66,7 +83,8 @@ class GraphRewriter {
for (int i = 0; i < num_input_nodes; ++i) {
// Will fill in the value part later when we create the associated new
// input node.
- input_node_map_[input_nodes[i].oper->node.name()] = nullptr;
+ input_node_map_[input_nodes[i].oper->node.name()] =
+ NodeBuilder::NodeOut(nullptr, -1);
}
// Grab all existing nodes for the upcoming rewrite, before mutating the
@@ -84,19 +102,24 @@ class GraphRewriter {
// On success, sets `config_op` and `shutdown_op` to the corresponding
// "ConfigureDistributedTPU" and "ShutdownDistributedTPU" nodes added to the
// graph.
- tensorflow::Status Rewrite(TF_Output* new_output_nodes, TF_Output* config_op,
- TF_Output* shutdown_op)
+ tensorflow::Status Rewrite(TF_Output* new_output_nodes,
+ TF_Operation** infeed_enqueue_node,
+ TF_Output* config_op, TF_Output* shutdown_op)
EXCLUSIVE_LOCKS_REQUIRED(graph_->mu) {
- TF_RETURN_IF_ERROR(ProcessInputNodes());
+ TF_RETURN_IF_ERROR(ProcessInputNodes(infeed_enqueue_node));
return RewriteGraphAndAddOutputNodes(new_output_nodes, config_op,
shutdown_op);
}
private:
- // Synthensizes new nodes for the input nodes, and creates a replicated
- // metadata node.
- tensorflow::Status ProcessInputNodes() EXCLUSIVE_LOCKS_REQUIRED(graph_->mu) {
+ // Synthesizes new graph nodes (infeed enqueue or TPU replicated input
+ // nodes) for the input nodes, and creates a replicated metadata node.
+ //
+ // When `infeed_enqueue_node` is non-NULL and there are some input nodes,
+ // also adds the infeed dequeue node.
+ tensorflow::Status ProcessInputNodes(TF_Operation** infeed_enqueue_node)
+ EXCLUSIVE_LOCKS_REQUIRED(graph_->mu) {
Node* metadata_node;
TF_RETURN_IF_ERROR(
NodeBuilder(metadata_node_name_.c_str(), "TPUReplicateMetadata")
@@ -104,34 +127,85 @@ class GraphRewriter {
.Attr("_tpu_replicate", cluster_name_.c_str())
.Finalize(&graph_->graph, &metadata_node));
- for (int i = 0; i < input_node_map_.size(); ++i) {
- VLOG(1) << "Handling input node " << input_nodes_[i].oper->node.name();
- Node* replicated_input_node;
- {
- std::string replicated_input_name("TPUReplicate/input" +
- std::to_string(i));
- NodeBuilder::NodeOut input(&input_nodes_[i].oper->node,
- input_nodes_[i].index);
- std::vector<NodeBuilder::NodeOut> input_list;
- input_list.push_back(input);
+ Node* dequeue_node = nullptr;
+ // Be deterministic in the corner case where `use_infeed` below is false.
+ if (infeed_enqueue_node) *infeed_enqueue_node = nullptr;
+ const bool use_infeed =
+ infeed_enqueue_node != nullptr && !input_node_map_.empty();
+ if (use_infeed) {
+ std::vector<NodeBuilder::NodeOut> new_input_list;
+ new_input_list.reserve(input_node_map_.size());
+ std::vector<tensorflow::DataType> input_dtypes;
+ input_dtypes.reserve(input_node_map_.size());
+ std::vector<tensorflow::TensorShape> input_shapes;
+ input_shapes.reserve(input_node_map_.size());
+ for (int i = 0; i < input_node_map_.size(); ++i) {
+ Node& input_node = input_nodes_[i].oper->node;
+ new_input_list.push_back(
+ NodeBuilder::NodeOut(&input_node, input_nodes_[i].index));
+ input_dtypes.push_back(input_node.output_type(input_nodes_[i].index));
+ tensorflow::TensorShapeProto shape;
TF_RETURN_IF_ERROR(
- NodeBuilder(replicated_input_name.c_str(), "TPUReplicatedInput")
- // This op requires an input list.
- .Input(input_list)
- .Finalize(&graph_->graph, &replicated_input_node));
+ tensorflow::GetNodeAttr(input_node.attrs(), "shape", &shape));
+ VLOG(1) << "Input node " << i << " has shape " << shape.DebugString();
+ input_shapes.push_back(shape);
}
+ // Enqueue always runs on CPU.
+ Node* enqueue_node;
+ TF_RETURN_IF_ERROR(NodeBuilder("InfeedEnqueueTuple", "InfeedEnqueueTuple")
+ .Input(new_input_list)
+ .Device("/device:CPU:0")
+ .Attr("device_ordinal", 0)
+ .Attr("dtypes", input_dtypes)
+ .Attr("shapes", input_shapes)
+ .Finalize(&graph_->graph, &enqueue_node));
+ *infeed_enqueue_node = ToTF_Operation(enqueue_node);
+ // The dequeue node should be put onto the "_tpu_replicate" cluster.
+ TF_RETURN_IF_ERROR(
+ NodeBuilder("TPUReplicate/InfeedDequeueTuple", "InfeedDequeueTuple")
+ .ControlInput(metadata_node)
+ .Attr("_tpu_replicate", cluster_name_.c_str())
+ .Attr("dtypes", input_dtypes)
+ .Attr("shapes", input_shapes)
+ .Finalize(&graph_->graph, &dequeue_node));
+ }
- {
- Node* new_input_node;
- const std::string new_input_name("TPUReplicate/replicated_input_" +
- std::to_string(i));
- TF_RETURN_IF_ERROR(NodeBuilder(new_input_name.c_str(), "Identity")
- .Input(replicated_input_node, 0)
- .ControlInput(metadata_node)
- .Attr("_tpu_replicate", cluster_name_.c_str())
- .Finalize(&graph_->graph, &new_input_node));
- DCHECK_GT(input_node_map_.count(input_nodes_[i].oper->node.name()), 0);
- input_node_map_[input_nodes_[i].oper->node.name()] = new_input_node;
+ for (int i = 0; i < input_node_map_.size(); ++i) {
+ VLOG(1) << "Handling input node " << input_nodes_[i].oper->node.name();
+ if (use_infeed) {
+ DCHECK(dequeue_node);
+ input_node_map_[input_nodes_[i].oper->node.name()] =
+ NodeBuilder::NodeOut(dequeue_node, i);
+ } else {
+ Node* replicated_input_node;
+ {
+ std::string replicated_input_name("TPUReplicate/input" +
+ std::to_string(i));
+ NodeBuilder::NodeOut input(&input_nodes_[i].oper->node,
+ input_nodes_[i].index);
+ std::vector<NodeBuilder::NodeOut> input_list;
+ input_list.push_back(input);
+ TF_RETURN_IF_ERROR(
+ NodeBuilder(replicated_input_name.c_str(), "TPUReplicatedInput")
+ // This op requires an input list.
+ .Input(input_list)
+ .Finalize(&graph_->graph, &replicated_input_node));
+ }
+
+ {
+ Node* new_input_node;
+ const std::string new_input_name("TPUReplicate/replicated_input_" +
+ std::to_string(i));
+ TF_RETURN_IF_ERROR(NodeBuilder(new_input_name.c_str(), "Identity")
+ .Input(replicated_input_node, 0)
+ .ControlInput(metadata_node)
+ .Attr("_tpu_replicate", cluster_name_.c_str())
+ .Finalize(&graph_->graph, &new_input_node));
+ DCHECK_GT(input_node_map_.count(input_nodes_[i].oper->node.name()),
+ 0);
+ input_node_map_[input_nodes_[i].oper->node.name()] =
+ NodeBuilder::NodeOut(new_input_node, 0);
+ }
}
}
return Status::OK();
@@ -163,7 +237,9 @@ class GraphRewriter {
}
const NodeDef& old_def = n->def();
- Node* new_node;
+ // Let node C be the consumer of `n`'s output in the original graph.
+ // This new node will feed into C in the rewritten graph.
+ NodeBuilder::NodeOut new_node;
if (input_node_map_.count(n->name())) {
new_node = input_node_map_[n->name()];
} else {
@@ -173,10 +249,19 @@ class GraphRewriter {
new_def.set_name(new_node_name);
new_def.clear_input();
for (int i = 0; i < old_def.input_size(); ++i) {
- const std::string& old_input_name = old_def.input(i);
- const std::string new_input_name =
+ const string old_input_name = old_def.input(i);
+ // When there are multiple input nodes that get mapped to the same
+ // infeed dequeue node, use different output ports of the dequeue
+ // node. e.g. Say in the original graph, input I1 feeds C1, and I2
+ // feeds C2. After the rewrite, I1 and I2 both feed a new infeed
+ // enqueue node, and the corresponding dequeue node has its output
+ // port 0 feeding C1, and output port 1 feeding C2. Note C1 and C2
+ // could be the same node (e.g. an Add that takes 2 inputs).
+ const string new_input_name =
input_node_map_.count(old_input_name) > 0
- ? std::string(input_node_map_[old_input_name]->name())
+ ? tensorflow::strings::StrCat(
+ input_node_map_[old_input_name].node->name(), ":",
+ input_node_map_[old_input_name].index)
: "TPUReplicate/" + old_input_name;
new_def.add_input(new_input_name);
}
@@ -192,11 +277,12 @@ class GraphRewriter {
}
tensorflow::AddNodeAttr("_tpu_replicate", cluster_name_.c_str(),
&new_def);
- new_node = graph_->graph.AddNode(new_def, &s);
+ new_node = NodeBuilder::NodeOut(graph_->graph.AddNode(new_def, &s), 0);
if (!s.ok()) {
return s;
}
- VLOG(1) << "The rewritten node node is " << new_node->DebugString();
+ VLOG(1) << "The rewritten node node is "
+ << new_node.node->DebugString();
}
if (output_node_map_.count(n->name()) > 0) {
@@ -206,7 +292,17 @@ class GraphRewriter {
const PortIndexPair& pair = it->second;
Node* out_identity_node;
{
- VLOG(1) << "Handling its output port " << pair.port
+ // If this output node is also an input, use the input_node_map_'s
+ // stored port, which would also work for an infeed dequeue op.
+ // Otherwise use pair.port.
+ // An example of the former: Say the graph has input nodes I1 and
+ // I2, and the output nodes are also I1 and I2. In the rewritten
+ // graph with infeed, the 2 output nodes will both come from a
+ // single infeed dequeue node ID, with output ports respectively
+ // set to 0 and 1.
+ const int output_port =
+ input_node_map_.count(n->name()) ? new_node.index : pair.port;
+ VLOG(1) << "Handling its output port " << output_port
<< " at output index " << pair.index;
std::string output_node_name = "TPUReplicate/Identity";
if (pair.index > 0) {
@@ -214,7 +310,7 @@ class GraphRewriter {
}
TF_RETURN_IF_ERROR(
NodeBuilder(output_node_name.c_str(), "Identity")
- .Input(new_node, pair.port)
+ .Input(new_node.node, output_port)
.Device(!old_def.device().empty()
? old_def.device()
: tensorflow::strings::StrCat(
@@ -289,16 +385,18 @@ class GraphRewriter {
// Keep mappings from the current input nodes to newly created input nodes,
// which we will use to rewrite existing nodes that read these
// inputs. e.g. A node that reads input node PlaceHolder could be rewired to
- // read the created TPUReplicate/replicated_input_0 node.
- std::unordered_map<std::string, Node*> input_node_map_;
+ // read the created TPUReplicate/replicated_input_0 node or some output port
+ // of the created TPUReplicate/InfeedDequeueTuple node. Because of the latter
+ // case, we the map entries store NodeBuilder::NodeOut, and not just Node*.
+ std::unordered_map<std::string, NodeBuilder::NodeOut> input_node_map_;
std::vector<Node*> nodes_to_rewrite_;
// Map from name to set{(output port, output tensor idx)}.
- // e.g. Say ther are 3 output tensors, respectively produced by (node 0,
+ // e.g. Say there are 3 output tensors, respectively produced by (node 0,
// port 0), (node 0, port 1), (node 1, port 0). Then the mapping entries
// are: node 0 -> {(port 0, idx 0), (port 1, idx 1)} node 1 -> {(port 0, idx
- // 2)} Based on these mappings, we will generated 3 new output nodes.
+ // 2)} Based on these mappings, we will generate 3 new output nodes.
struct PortIndexPair {
int port;
int index;
@@ -331,7 +429,9 @@ TF_Output TF_SetupTPUExecution(TF_Session* session, int num_input_nodes,
const TF_Output* input_nodes,
int num_output_nodes,
const TF_Output* output_nodes,
- TF_Output* new_output_nodes, TF_Status* status) {
+ TF_Output* new_output_nodes,
+ TF_Operation** infeed_enqueue_node,
+ TF_Status* status) {
TF_Output config_op, shutdown_op;
{
auto graph = session->graph;
@@ -341,8 +441,8 @@ TF_Output TF_SetupTPUExecution(TF_Session* session, int num_input_nodes,
<< graph->graph.ToGraphDefDebug().DebugString();
GraphRewriter rewriter(graph, num_input_nodes, input_nodes,
num_output_nodes, output_nodes);
- status->status =
- rewriter.Rewrite(new_output_nodes, &config_op, &shutdown_op);
+ status->status = rewriter.Rewrite(new_output_nodes, infeed_enqueue_node,
+ &config_op, &shutdown_op);
if (!status->status.ok()) {
return shutdown_op;
}
diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h
index af65123131..2bad278d63 100644
--- a/tensorflow/c/c_api_experimental.h
+++ b/tensorflow/c/c_api_experimental.h
@@ -63,7 +63,15 @@ TF_CAPI_EXPORT extern void TF_EnableXLACompilation(TF_SessionOptions* options,
// Sets up TPU execution, by rewriting the graph accordingly, and initializing
// TPU system.
//
-// On success, returns a shutdown node to be used in a subsequent
+// When `infeed_enqueue_node` is non-NULL and there are input tensors, rewrites
+// the graph by adding the relevant infeed enqueue/dequeue ops, and returns the
+// enqueue op in `infeed_enqueue_node` on success, so that user can run that
+// node and feed input tensors. When there are no input tensors,
+// `infeed_enqueue_node` is ignored, and user should not run that node later.
+// TODO(hongm): In this case, we currently only support input tensors of dim 0
+// shape. Lift that constraint.
+//
+// On success, also returns a shutdown node to be used in a subsequent
// TF_ShutdownTPUExecution(), and sets the new output nodes in
// `new_output_nodes` for caller to fetch from. Must be called exactly once
// before TF_SessionRun().
@@ -76,7 +84,8 @@ TF_CAPI_EXPORT extern void TF_EnableXLACompilation(TF_SessionOptions* options,
TF_CAPI_EXPORT extern TF_Output TF_SetupTPUExecution(
TF_Session* session, int num_input_nodes, const TF_Output* input_nodes,
int num_output_nodes, const TF_Output* output_nodes,
- TF_Output* new_output_nodes, TF_Status* status);
+ TF_Output* new_output_nodes, TF_Operation** infeed_enqueue_node,
+ TF_Status* status);
// Shuts down TPU system. For any `session` where TF_SetupTPUExecution() has
// been successfully called, this call must be made exactly once before the
diff --git a/tensorflow/c/c_test_util.cc b/tensorflow/c/c_test_util.cc
index 22f77e7b87..f3b28c1708 100644
--- a/tensorflow/c/c_test_util.cc
+++ b/tensorflow/c/c_test_util.cc
@@ -94,18 +94,22 @@ TF_Tensor* FloatTensor(float v) {
// one cannot call ASSERT_* methods in non-void-returning functions (when
// exceptions are disabled during compilation)
void PlaceholderHelper(TF_Graph* graph, TF_Status* s, const char* name,
- TF_DataType dtype, TF_Operation** op) {
+ TF_DataType dtype, const std::vector<int64_t>& dims,
+ TF_Operation** op) {
TF_OperationDescription* desc = TF_NewOperation(graph, "Placeholder", name);
TF_SetAttrType(desc, "dtype", dtype);
+ if (!dims.empty()) {
+ TF_SetAttrShape(desc, "shape", dims.data(), dims.size());
+ }
*op = TF_FinishOperation(desc, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
ASSERT_NE(*op, nullptr);
}
TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s, const char* name,
- TF_DataType dtype) {
+ TF_DataType dtype, const std::vector<int64_t>& dims) {
TF_Operation* op;
- PlaceholderHelper(graph, s, name, dtype, &op);
+ PlaceholderHelper(graph, s, name, dtype, dims, &op);
return op;
}
diff --git a/tensorflow/c/c_test_util.h b/tensorflow/c/c_test_util.h
index d87c57fd51..cd19cf8d62 100644
--- a/tensorflow/c/c_test_util.h
+++ b/tensorflow/c/c_test_util.h
@@ -48,7 +48,8 @@ TF_Tensor* FloatTensor(float v);
TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s,
const char* name = "feed",
- TF_DataType dtype = TF_INT32);
+ TF_DataType dtype = TF_INT32,
+ const std::vector<int64_t>& dims = {});
TF_Operation* Const(TF_Tensor* t, TF_Graph* graph, TF_Status* s,
const char* name = "const");