aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-07-06 18:15:00 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-07-06 18:22:33 -0700
commit172b8740236c88acf06fc9fa01e8ca52e5482edf (patch)
tree361569bf098414cf43c7a18144890aa0b3be5df6
parent3926a3aca125c42987adc602a5cb006d97b0261e (diff)
Make the pruning of the graph in ahead-of-time compilation
also prune out dependencies of fed tensors. Remove special handling of control dependencies from the introduced placeholders. For a compilation going to XLA, those control dependencies would not really be doing anything anyway (and the current code was only handling Placeholder, not PlaceholderV2). PiperOrigin-RevId: 161157126
-rw-r--r--tensorflow/compiler/aot/BUILD25
-rw-r--r--tensorflow/compiler/aot/compile.cc91
-rw-r--r--tensorflow/compiler/aot/test_graph_tfadd.config.pbtxt2
-rw-r--r--tensorflow/compiler/aot/test_graph_tfadd.pbtxt31
-rw-r--r--tensorflow/compiler/aot/test_graph_tfunknownop.pbtxt54
-rw-r--r--tensorflow/compiler/aot/test_graph_tfunknownop2.config.pbtxt25
-rw-r--r--tensorflow/compiler/aot/test_graph_tfunknownop3.config.pbtxt26
-rw-r--r--tensorflow/compiler/aot/tfcompile.proto7
-rw-r--r--tensorflow/compiler/aot/tfcompile_util.cc117
-rw-r--r--tensorflow/compiler/aot/tfcompile_util.h15
10 files changed, 284 insertions, 109 deletions
diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD
index 12179b7637..b12b5318ec 100644
--- a/tensorflow/compiler/aot/BUILD
+++ b/tensorflow/compiler/aot/BUILD
@@ -156,7 +156,7 @@ tf_library(
)
# A test of tf_library that includes a graph with an unknown op, but where
-# the compilation works because the the unknown op is not needed for the fetches.
+# the compilation works because the unknown op is not needed for the fetches.
tf_library(
name = "test_graph_tfunknownop",
testonly = 1,
@@ -166,6 +166,29 @@ tf_library(
tags = ["manual"],
)
+# A test of tf_library that includes a graph with an unknown op, but where
+# the compilation works because the op between the unknown op and the
+# fetches is a feed.
+tf_library(
+ name = "test_graph_tfunknownop2",
+ testonly = 1,
+ config = "test_graph_tfunknownop2.config.pbtxt",
+ cpp_class = "UnknownOpAddComp",
+ graph = "test_graph_tfunknownop.pbtxt",
+ tags = ["manual"],
+)
+
+# A test of tf_library that includes a graph with an unknown op, but where
+# the compilation works because the unknown op is fed.
+tf_library(
+ name = "test_graph_tfunknownop3",
+ testonly = 1,
+ config = "test_graph_tfunknownop3.config.pbtxt",
+ cpp_class = "UnknownOpAddComp",
+ graph = "test_graph_tfunknownop.pbtxt",
+ tags = ["manual"],
+)
+
# Utility library for benchmark binaries, used by the *_benchmark rules that are
# added by the tfcompile bazel macro.
cc_library(
diff --git a/tensorflow/compiler/aot/compile.cc b/tensorflow/compiler/aot/compile.cc
index 59ff14600b..51d08eaa01 100644
--- a/tensorflow/compiler/aot/compile.cc
+++ b/tensorflow/compiler/aot/compile.cc
@@ -78,66 +78,51 @@ Status DumpGraph(const MainFlags& flags, const string& name,
return WriteTextProto(Env::Default(), file, graph_def);
}
-string TensorIdToString(const TensorId& id) {
- return strings::StrCat(id.node_name(), ":", id.output_index());
-}
-
typedef std::unordered_map<string, Node*> NodeMap;
// Each feed id identifies the positional output of some node, which may consist
-// of multiple edges. For each feed node, replaces all matching edges so that
-// they point from a new _Arg node instead.
+// of multiple edges. AddPlaceholdersForFeeds has already replaced each fed
+// tensor with a placeholder. For each feed tensor, replaces all edges so they
+// point from a new _Arg node instead.
Status AddArgNodes(Graph* graph, const NodeMap& node_map,
- const protobuf::RepeatedPtrField<Feed>& feeds) {
+ const protobuf::RepeatedPtrField<Feed>& feeds,
+ const std::unordered_map<string, string>& feed_remapping) {
for (int arg_index = 0; arg_index < feeds.size(); ++arg_index) {
const Feed& feed = feeds[arg_index];
- const TensorId& id = feed.id();
- auto it = node_map.find(id.node_name());
- if (it == node_map.end()) {
- return errors::NotFound("Can't find feed id: ", TensorIdToString(id));
- }
- const Node* feed_node = it->second;
- if (id.output_index() >= feed_node->num_outputs()) {
- return errors::InvalidArgument("Invalid feed id: ", TensorIdToString(id),
- ", output index should be < ",
- feed_node->num_outputs());
- }
- // TODO(toddw): Invoke shape inference on the graph and add a "_shape" attr
- // if we can determine it. That way the graph will be initialized with
- // whatever shapes we can infer, while the user can still explicitly specify
- // or override them.
+ // All feeds have been replaced by placeholders.
+ const int output_index = 0;
+
+ const auto remap_it = feed_remapping.find(TensorIdToString(feed.id()));
+ auto node_it = node_map.find(remap_it->second);
+ const Node* feed_node = node_it->second;
+
+ // TODO(toddw): Invoke shape inference in AddPlaceholdersForFeeds and add a
+ // "_shape" attr if we can determine it. That way the graph will be
+ // initialized with whatever shapes we can infer, while the user can still
+ // explicitly specify or override them.
Node* arg_node = nullptr;
TF_RETURN_IF_ERROR(
NodeBuilder(strings::StrCat("_arg_", arg_index), kArgOp)
- .Attr("T", BaseType(feed_node->output_type(id.output_index())))
+ .Attr("T", BaseType(feed_node->output_type(output_index)))
.Attr("index", arg_index)
- .Attr(kFeedIdAttr, TensorIdToString(id))
+ .Attr(kFeedIdAttr, TensorIdToString(feed.id()))
.Attr(kShapeAttr, TensorShape(feed.shape()))
.Attr(kDebugNameAttr, feed.name())
.Finalize(graph, &arg_node));
+
// Collects out-edges from the feed node that have a matching edge index;
- // these will be replaced with edges from the arg node instead. Also
- // replaces all control edges from Placeholder feed nodes; similar code
- // exists in subgraph::RewriteGraphForExecution.
- // TODO(toddw): Why only replace control edges from Placeholder?
+ // these will be replaced with edges from the arg node instead.
//
// We must collect the edges first and process them in a second pass, since
// removing the edge from the graph invalidates feed_node->out_edges.
std::vector<const Edge*> feed_edges;
for (const Edge* edge : feed_node->out_edges()) {
- if (edge->src_output() == id.output_index() ||
- (edge->src_output() == Graph::kControlSlot &&
- feed_node->type_string() == "Placeholder")) {
+ if (edge->src_output() == output_index) {
feed_edges.push_back(edge);
}
}
for (const Edge* edge : feed_edges) {
- if (edge->src_output() == id.output_index()) {
- graph->AddEdge(arg_node, 0, edge->dst(), edge->dst_input());
- } else {
- CHECK_EQ(edge->src_output(), Graph::kControlSlot);
- graph->AddControlEdge(arg_node, edge->dst());
- }
+ graph->AddEdge(arg_node, 0, edge->dst(), edge->dst_input());
graph->RemoveEdge(edge);
}
}
@@ -179,13 +164,16 @@ Status AddRetvalNodes(Graph* graph, const NodeMap& node_map,
// fetch ids respectively), and rewrites the edges so that inputs flow from _Arg
// nodes, and outputs flow to _Retval nodes. This allows the symbolic graph
// execution to know the input and output args for the generated function.
-Status RewriteAndPruneGraph(Graph* graph, const Config& config,
- const MainFlags& flags) {
+Status RewriteAndPruneGraph(
+ Graph* graph, const Config& config,
+ const std::unordered_map<string, string>& feed_remapping,
+ const MainFlags& flags) {
NodeMap node_map;
for (Node* n : graph->nodes()) {
node_map[n->name()] = n;
}
- TF_RETURN_IF_ERROR(AddArgNodes(graph, node_map, config.feed()));
+ TF_RETURN_IF_ERROR(
+ AddArgNodes(graph, node_map, config.feed(), feed_remapping));
std::unordered_set<const Node*> retval_nodes;
TF_RETURN_IF_ERROR(
AddRetvalNodes(graph, node_map, config.fetch(), &retval_nodes));
@@ -383,17 +371,28 @@ Status InitGraph(const GraphDef& graph_def, const Config& config,
FunctionLibraryDefinition flib_def(OpRegistry::Global(), graph_def.library());
std::unique_ptr<Graph> g(new Graph(flib_def));
- GraphDef copy_def;
+ // Replace references to fed tensors with references to newly added
+ // placeholders.
+ GraphDef first_copy_def = graph_def;
+
+ // Maps from name:port of a feed to the name:port of the placeholder to use.
+ std::unordered_map<string, string> feed_remapping;
+ TF_RETURN_IF_ERROR(AddPlaceholdersForFeeds(config, g->op_registry(),
+ &feed_remapping, &first_copy_def));
// Prune the GraphDef first so that unknown ops that we aren't compiling get
// filtered out.
- TF_RETURN_IF_ERROR(PruneGraphDefInto(config, graph_def, &copy_def));
+ GraphDef second_copy_def;
+ TF_RETURN_IF_ERROR(
+ PruneGraphDefInto(config, first_copy_def, &second_copy_def));
+
+ TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(
+ &second_copy_def, *g->op_registry(), 0 /*node_offset*/));
- TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(&copy_def, *g->op_registry(),
- 0 /*node_offset*/));
+ TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(GraphConstructorOptions(),
+ second_copy_def, g.get()));
TF_RETURN_IF_ERROR(
- ConvertGraphDefToGraph(GraphConstructorOptions(), copy_def, g.get()));
- TF_RETURN_IF_ERROR(RewriteAndPruneGraph(g.get(), config, flags));
+ RewriteAndPruneGraph(g.get(), config, feed_remapping, flags));
*graph = std::move(g);
return Status::OK();
}
diff --git a/tensorflow/compiler/aot/test_graph_tfadd.config.pbtxt b/tensorflow/compiler/aot/test_graph_tfadd.config.pbtxt
index 5625c0ab03..f2d9c34b2d 100644
--- a/tensorflow/compiler/aot/test_graph_tfadd.config.pbtxt
+++ b/tensorflow/compiler/aot/test_graph_tfadd.config.pbtxt
@@ -6,7 +6,7 @@ feed {
}
}
feed {
- id { node_name: "y_const" }
+ id { node_name: "y_reshape" }
shape {
dim { size: 1 }
}
diff --git a/tensorflow/compiler/aot/test_graph_tfadd.pbtxt b/tensorflow/compiler/aot/test_graph_tfadd.pbtxt
index 91c900e06d..665c9fe287 100644
--- a/tensorflow/compiler/aot/test_graph_tfadd.pbtxt
+++ b/tensorflow/compiler/aot/test_graph_tfadd.pbtxt
@@ -4,15 +4,7 @@ node {
attr {
key: "value"
value {
- tensor {
- dtype: DT_INT32
- tensor_shape {
- dim {
- size: 1
- }
- }
- int_val: 1
- }
+ tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } int_val: 1 }
}
}
attr {
@@ -28,15 +20,7 @@ node {
attr {
key: "value"
value {
- tensor {
- dtype: DT_INT32
- tensor_shape {
- dim {
- size: 1
- }
- }
- int_val: 2
- }
+ tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } int_val: 2 }
}
}
attr {
@@ -47,10 +31,19 @@ node {
}
}
node {
+ name : "y_reshape"
+ op : "Reshape"
+ input : "y_const"
+ input : "y_shape"
+ attr { key: "T" value { type: DT_INT32 } }
+ # Attribute TShape not specified; needs to be set to its default
+ # by tfcompile.
+}
+node {
name : "x_y_sum"
op : "Add"
input : "x_const"
- input : "y_const"
+ input : "y_reshape"
attr {
key : "T"
value {
diff --git a/tensorflow/compiler/aot/test_graph_tfunknownop.pbtxt b/tensorflow/compiler/aot/test_graph_tfunknownop.pbtxt
index 212ffbb5ff..48b881bb94 100644
--- a/tensorflow/compiler/aot/test_graph_tfunknownop.pbtxt
+++ b/tensorflow/compiler/aot/test_graph_tfunknownop.pbtxt
@@ -6,21 +6,12 @@ node {
value {
tensor {
dtype: DT_INT32
- tensor_shape {
- dim {
- size: 1
- }
- }
+ tensor_shape { dim { size: 1 } }
int_val: 1
}
}
}
- attr {
- key : "dtype"
- value {
- type : DT_INT32
- }
- }
+ attr { key : "dtype" value { type: DT_INT32 } }
}
node {
name : "y_const"
@@ -30,56 +21,37 @@ node {
value {
tensor {
dtype: DT_INT32
- tensor_shape {
- dim {
- size: 1
- }
- }
+ tensor_shape { dim { size: 1 } }
int_val: 2
}
}
}
- attr {
- key: "dtype"
- value {
- type: DT_INT32
- }
- }
+ attr { key: "dtype" value { type: DT_INT32 } }
}
node {
name : "x_y_sum"
op : "Add"
input : "x_const"
input : "y_const"
- attr {
- key : "T"
- value {
- type: DT_INT32
- }
- }
+ attr { key : "T" value { type: DT_INT32 } }
}
node {
name : "z"
op : "SomeUnknownOp"
input : "x_const"
- attr {
- key : "T"
- value {
- type: DT_INT32
- }
- }
+}
+node {
+ name : "z_identity"
+ op : "Identity"
+ input : "z:1"
+ attr { key : "T" value { type: DT_INT32 } }
}
node {
name : "x_z_sum"
op : "Add"
input : "x_const"
- input : "z"
- attr {
- key : "T"
- value {
- type: DT_INT32
- }
- }
+ input : "z_identity"
+ attr { key : "T" value { type: DT_INT32 } }
}
versions {
producer: 15
diff --git a/tensorflow/compiler/aot/test_graph_tfunknownop2.config.pbtxt b/tensorflow/compiler/aot/test_graph_tfunknownop2.config.pbtxt
new file mode 100644
index 0000000000..7370ed370d
--- /dev/null
+++ b/tensorflow/compiler/aot/test_graph_tfunknownop2.config.pbtxt
@@ -0,0 +1,25 @@
+# Text form of tensorflow.tfcompile.Config proto.
+feed {
+ id { node_name: "x_const" }
+ shape {
+ dim { size: 1 }
+ }
+}
+feed {
+ id { node_name: "y_const" }
+ shape {
+ dim { size: 1 }
+ }
+}
+feed {
+ id { node_name: "z_identity"}
+ shape {
+ dim { size: 1 }
+ }
+}
+fetch {
+ id { node_name: "x_y_sum" }
+}
+fetch {
+ id { node_name: "x_z_sum" }
+}
diff --git a/tensorflow/compiler/aot/test_graph_tfunknownop3.config.pbtxt b/tensorflow/compiler/aot/test_graph_tfunknownop3.config.pbtxt
new file mode 100644
index 0000000000..b2d7d54574
--- /dev/null
+++ b/tensorflow/compiler/aot/test_graph_tfunknownop3.config.pbtxt
@@ -0,0 +1,26 @@
+# Text form of tensorflow.tfcompile.Config proto.
+feed {
+ id { node_name: "x_const" }
+ shape {
+ dim { size: 1 }
+ }
+}
+feed {
+ id { node_name: "y_const" }
+ shape {
+ dim { size: 1 }
+ }
+}
+feed {
+ id { node_name: "z" output_index: 1}
+ shape {
+ dim { size: 1 }
+ }
+ type: DT_INT32
+}
+fetch {
+ id { node_name: "x_y_sum" }
+}
+fetch {
+ id { node_name: "x_z_sum" }
+}
diff --git a/tensorflow/compiler/aot/tfcompile.proto b/tensorflow/compiler/aot/tfcompile.proto
index be3f504350..cd83840d89 100644
--- a/tensorflow/compiler/aot/tfcompile.proto
+++ b/tensorflow/compiler/aot/tfcompile.proto
@@ -7,6 +7,7 @@ option java_multiple_files = true;
option java_package = "org.tensorflow.tfcompile";
import "tensorflow/core/framework/tensor_shape.proto";
+import "tensorflow/core/framework/types.proto";
// TensorId identifies a tensor in a TensorFlow graph, by specifying the output
// index of a particular node in the graph. If the output of the named node
@@ -23,6 +24,12 @@ message Feed {
TensorId id = 1;
TensorShapeProto shape = 2;
string name = 3; // Optional name for generated code.
+
+ // Optional data type. This is not normally required, as the graph itself
+ // contains this information. However, if the node being fed is an op that
+ // is not linked into the tfcompile binary, then the type cannot be inferred
+ // from the node; in this case, the type should be set here.
+ DataType type = 4;
};
// Fetch represents a single fetch tensor in the graph, which corresponds to an
diff --git a/tensorflow/compiler/aot/tfcompile_util.cc b/tensorflow/compiler/aot/tfcompile_util.cc
index 8774a02128..e6a4705b6c 100644
--- a/tensorflow/compiler/aot/tfcompile_util.cc
+++ b/tensorflow/compiler/aot/tfcompile_util.cc
@@ -20,12 +20,14 @@ limitations under the License.
#include <unordered_map>
#include "tensorflow/compiler/aot/tfcompile.pb.h"
+#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/strings/strcat.h"
namespace tensorflow {
namespace tfcompile {
@@ -119,17 +121,115 @@ Status ValidateConfig(const Config& config) {
return Status::OK();
}
+Status AddPlaceholdersForFeeds(
+ const Config& config, const OpRegistryInterface* op_registry,
+ std::unordered_map<string, string>* feed_remapping, GraphDef* graph_def) {
+ struct PlaceholderInfo {
+ const Feed* feed = nullptr; // point to Feed in <config>.
+ string placeholder_name;
+ DataType data_type = DT_INVALID;
+ };
+
+ // Put each fed tensor into a map by name:port. A map is used for determinism
+ // when creating placeholders (genrules want deterministic output).
+ std::map<string, PlaceholderInfo> placeholder_info;
+ for (int i = 0; i < config.feed_size(); ++i) {
+ const Feed* feed = &config.feed(i);
+ const string name_port = TensorIdToString(feed->id());
+ auto& info = placeholder_info[name_port];
+ info.feed = feed;
+ info.placeholder_name = strings::StrCat(
+ "aot_feed_", feed->id().output_index(), "/", feed->id().node_name());
+ (*feed_remapping)[name_port] = info.placeholder_name;
+ }
+
+ // Verify node exists and determine data type.
+ std::unordered_map<string, const NodeDef*> name_to_node;
+ for (int i = 0; i < graph_def->node_size(); ++i) {
+ name_to_node[graph_def->node(i).name()] = &graph_def->node(i);
+ }
+ for (auto it = placeholder_info.begin(); it != placeholder_info.end(); ++it) {
+ PlaceholderInfo& info = it->second;
+ const TensorId& feed_id = info.feed->id();
+
+ // Find the existing node and determine data type.
+ auto node_it = name_to_node.find(feed_id.node_name());
+ if (node_it == name_to_node.end()) {
+ return errors::NotFound("Can't find feed node: ",
+ TensorIdToString(feed_id));
+ }
+ const NodeDef* existing = node_it->second;
+
+ if (info.feed->type() != DT_INVALID) {
+ info.data_type = info.feed->type();
+ } else {
+ // Build the node in order to infer its type.
+
+ // Must first add default attrs as well, so do this in a copied GraphDef.
+ GraphDef gd;
+ *gd.mutable_versions() = graph_def->versions();
+ *gd.add_node() = *existing;
+ TF_RETURN_IF_ERROR(
+ AddDefaultAttrsToGraphDef(&gd, *op_registry, 0 /*node_offset*/));
+
+ // Now build the node from the copied node def.
+ Graph g(op_registry);
+ g.set_versions(graph_def->versions());
+ Status status;
+ Node* feed_node = g.AddNode(gd.node(0), &status);
+ TF_RETURN_IF_ERROR(status);
+ info.data_type =
+ BaseType(feed_node->output_type(info.feed->id().output_index()));
+ }
+ }
+
+ // Create placeholders. Note that we could avoid creating a placeholder for
+ // feeds which are already placeholders, but we omit that to avoid more cases
+ // in this code.
+ for (auto it = placeholder_info.begin(); it != placeholder_info.end(); ++it) {
+ const PlaceholderInfo& info = it->second;
+ NodeDef* d = graph_def->add_node();
+ d->set_name(info.placeholder_name);
+ d->set_op("PlaceholderV2");
+ auto& attr_map = *d->mutable_attr();
+ attr_map["dtype"].set_type(info.data_type);
+ *attr_map["shape"].mutable_shape() = info.feed->shape();
+ }
+
+  // Rewrite references to the fed tensors to refer to the placeholders.
+ for (int i = 0; i < graph_def->node_size(); ++i) {
+ NodeDef* node_def = graph_def->mutable_node(i);
+ for (int j = 0; j < node_def->input_size(); ++j) {
+ auto id = ParseTensorName(node_def->input(j));
+ auto it = placeholder_info.find(id.ToString());
+ if (it != placeholder_info.end()) {
+ node_def->set_input(j, it->second.placeholder_name);
+ }
+ }
+ }
+
+ return Status::OK();
+}
+
Status PruneGraphDefInto(const Config& config, const GraphDef& in,
GraphDef* out) {
*out = in;
out->clear_node();
+ // Tensors needed for feeding.
+ std::set<std::pair<string, int>> feed_tensors;
+ for (const auto& feed_config : config.feed()) {
+ feed_tensors.insert(std::make_pair(feed_config.id().node_name(),
+ feed_config.id().output_index()));
+ }
+
// Maps node name to reachability.
std::unordered_map<string, std::pair<bool, const NodeDef*>> node_by_name;
for (const NodeDef& node : in.node()) {
node_by_name[node.name()] = std::pair<bool, const NodeDef*>(false, &node);
}
+ // Traverse.
std::queue<string> name_queue;
for (int i = 0; i < config.fetch_size(); ++i) {
name_queue.push(config.fetch(i).id().node_name());
@@ -149,8 +249,19 @@ Status PruneGraphDefInto(const Config& config, const GraphDef& in,
}
map_entry.first = true;
+ // Push input nodes of the currently visited node to name_queue.
for (const string& in_edge : map_entry.second->input()) {
- name_queue.push(ParseTensorName(in_edge).first.ToString());
+ auto id = ParseTensorName(in_edge);
+ const string node_name = id.first.ToString();
+ if (feed_tensors.find(std::make_pair(node_name, id.second)) ==
+ feed_tensors.end()) {
+ name_queue.push(node_name);
+ } else {
+ // The input tensor is from an edge that is being fed. Therefore,
+ // we skip recursing down that edge, to avoid requiring nodes that
+ // may not be needed (note that the input node may still be added
+ // to name_queue later if one of its output edges is not being fed).
+ }
}
}
@@ -165,5 +276,9 @@ Status PruneGraphDefInto(const Config& config, const GraphDef& in,
return Status::OK();
}
+string TensorIdToString(const TensorId& id) {
+ return strings::StrCat(id.node_name(), ":", id.output_index());
+}
+
} // namespace tfcompile
} // namespace tensorflow
diff --git a/tensorflow/compiler/aot/tfcompile_util.h b/tensorflow/compiler/aot/tfcompile_util.h
index 84060c0761..365f7b0e7b 100644
--- a/tensorflow/compiler/aot/tfcompile_util.h
+++ b/tensorflow/compiler/aot/tfcompile_util.h
@@ -16,8 +16,11 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_AOT_TFCOMPILE_UTIL_H_
#define TENSORFLOW_COMPILER_AOT_TFCOMPILE_UTIL_H_
+#include <unordered_map>
+
#include "tensorflow/compiler/aot/tfcompile.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
@@ -31,11 +34,23 @@ Status ValidateCppIdent(StringPiece ident, StringPiece msg);
// ValidateConfig returns OK iff config is valid.
Status ValidateConfig(const Config& config);
+// Modifies <graph_def> to include placeholders for each fed tensor, and
+// updates references to the fed tensors to refer to the placeholders.
+// The existing nodes referenced by the feeds are not removed or modified
+// (except where their input edges are modified by the replacement of other
+// feeds).
+Status AddPlaceholdersForFeeds(
+ const Config& config, const OpRegistryInterface* op_registry,
+ std::unordered_map<string, string>* feed_remapping, GraphDef* graph_def);
+
// Returns in <out> a copy of <in>, pruned to only include fetches from
// <config>.
Status PruneGraphDefInto(const Config& config, const GraphDef& in,
GraphDef* out);
+// Returns node:port for the given <id>.
+string TensorIdToString(const TensorId& id);
+
} // namespace tfcompile
} // namespace tensorflow