diff options
author | 2017-07-06 18:15:00 -0700 | |
---|---|---|
committer | 2017-07-06 18:22:33 -0700 | |
commit | 172b8740236c88acf06fc9fa01e8ca52e5482edf (patch) | |
tree | 361569bf098414cf43c7a18144890aa0b3be5df6 | |
parent | 3926a3aca125c42987adc602a5cb006d97b0261e (diff) |
Make the pruning of the graph in ahead-of-time compilation
also prune out dependencies of fed tensors.
Remove special handling of control dependencies from the introduced
placeholders. For a compilation going to XLA, those control dependencies would
not really be doing anything anyway (and the current code was only handling
Placeholder, not PlaceholderV2).
PiperOrigin-RevId: 161157126
-rw-r--r-- | tensorflow/compiler/aot/BUILD | 25 | ||||
-rw-r--r-- | tensorflow/compiler/aot/compile.cc | 91 | ||||
-rw-r--r-- | tensorflow/compiler/aot/test_graph_tfadd.config.pbtxt | 2 | ||||
-rw-r--r-- | tensorflow/compiler/aot/test_graph_tfadd.pbtxt | 31 | ||||
-rw-r--r-- | tensorflow/compiler/aot/test_graph_tfunknownop.pbtxt | 54 | ||||
-rw-r--r-- | tensorflow/compiler/aot/test_graph_tfunknownop2.config.pbtxt | 25 | ||||
-rw-r--r-- | tensorflow/compiler/aot/test_graph_tfunknownop3.config.pbtxt | 26 | ||||
-rw-r--r-- | tensorflow/compiler/aot/tfcompile.proto | 7 | ||||
-rw-r--r-- | tensorflow/compiler/aot/tfcompile_util.cc | 117 | ||||
-rw-r--r-- | tensorflow/compiler/aot/tfcompile_util.h | 15 |
10 files changed, 284 insertions, 109 deletions
diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD index 12179b7637..b12b5318ec 100644 --- a/tensorflow/compiler/aot/BUILD +++ b/tensorflow/compiler/aot/BUILD @@ -156,7 +156,7 @@ tf_library( ) # A test of tf_library that includes a graph with an unknown op, but where -# the compilation works because the the unknown op is not needed for the fetches. +# the compilation works because the unknown op is not needed for the fetches. tf_library( name = "test_graph_tfunknownop", testonly = 1, @@ -166,6 +166,29 @@ tf_library( tags = ["manual"], ) +# A test of tf_library that includes a graph with an unknown op, but where +# the compilation works because the op between the unknown op and the +# fetches is a feed. +tf_library( + name = "test_graph_tfunknownop2", + testonly = 1, + config = "test_graph_tfunknownop2.config.pbtxt", + cpp_class = "UnknownOpAddComp", + graph = "test_graph_tfunknownop.pbtxt", + tags = ["manual"], +) + +# A test of tf_library that includes a graph with an unknown op, but where +# the compilation works because the unknown op is fed. +tf_library( + name = "test_graph_tfunknownop3", + testonly = 1, + config = "test_graph_tfunknownop3.config.pbtxt", + cpp_class = "UnknownOpAddComp", + graph = "test_graph_tfunknownop.pbtxt", + tags = ["manual"], +) + # Utility library for benchmark binaries, used by the *_benchmark rules that are # added by the tfcompile bazel macro. 
cc_library( diff --git a/tensorflow/compiler/aot/compile.cc b/tensorflow/compiler/aot/compile.cc index 59ff14600b..51d08eaa01 100644 --- a/tensorflow/compiler/aot/compile.cc +++ b/tensorflow/compiler/aot/compile.cc @@ -78,66 +78,51 @@ Status DumpGraph(const MainFlags& flags, const string& name, return WriteTextProto(Env::Default(), file, graph_def); } -string TensorIdToString(const TensorId& id) { - return strings::StrCat(id.node_name(), ":", id.output_index()); -} - typedef std::unordered_map<string, Node*> NodeMap; // Each feed id identifies the positional output of some node, which may consist -// of multiple edges. For each feed node, replaces all matching edges so that -// they point from a new _Arg node instead. +// of multiple edges. AddPlaceholdersForFeeds has already replaced each fed +// tensor with a placeholder. For each feed tensor, replaces all edges so they +// point from a new _Arg node instead. Status AddArgNodes(Graph* graph, const NodeMap& node_map, - const protobuf::RepeatedPtrField<Feed>& feeds) { + const protobuf::RepeatedPtrField<Feed>& feeds, + const std::unordered_map<string, string>& feed_remapping) { for (int arg_index = 0; arg_index < feeds.size(); ++arg_index) { const Feed& feed = feeds[arg_index]; - const TensorId& id = feed.id(); - auto it = node_map.find(id.node_name()); - if (it == node_map.end()) { - return errors::NotFound("Can't find feed id: ", TensorIdToString(id)); - } - const Node* feed_node = it->second; - if (id.output_index() >= feed_node->num_outputs()) { - return errors::InvalidArgument("Invalid feed id: ", TensorIdToString(id), - ", output index should be < ", - feed_node->num_outputs()); - } - // TODO(toddw): Invoke shape inference on the graph and add a "_shape" attr - // if we can determine it. That way the graph will be initialized with - // whatever shapes we can infer, while the user can still explicitly specify - // or override them. + // All feeds have been replaced by placeholders. 
+ const int output_index = 0; + + const auto remap_it = feed_remapping.find(TensorIdToString(feed.id())); + auto node_it = node_map.find(remap_it->second); + const Node* feed_node = node_it->second; + + // TODO(toddw): Invoke shape inference in AddPlaceholdersForFeeds and add a + // "_shape" attr if we can determine it. That way the graph will be + // initialized with whatever shapes we can infer, while the user can still + // explicitly specify or override them. Node* arg_node = nullptr; TF_RETURN_IF_ERROR( NodeBuilder(strings::StrCat("_arg_", arg_index), kArgOp) - .Attr("T", BaseType(feed_node->output_type(id.output_index()))) + .Attr("T", BaseType(feed_node->output_type(output_index))) .Attr("index", arg_index) - .Attr(kFeedIdAttr, TensorIdToString(id)) + .Attr(kFeedIdAttr, TensorIdToString(feed.id())) .Attr(kShapeAttr, TensorShape(feed.shape())) .Attr(kDebugNameAttr, feed.name()) .Finalize(graph, &arg_node)); + // Collects out-edges from the feed node that have a matching edge index; - // these will be replaced with edges from the arg node instead. Also - // replaces all control edges from Placeholder feed nodes; similar code - // exists in subgraph::RewriteGraphForExecution. - // TODO(toddw): Why only replace control edges from Placeholder? + // these will be replaced with edges from the arg node instead. // // We must collect the edges first and process them in a second pass, since // removing the edge from the graph invalidates feed_node->out_edges. 
std::vector<const Edge*> feed_edges; for (const Edge* edge : feed_node->out_edges()) { - if (edge->src_output() == id.output_index() || - (edge->src_output() == Graph::kControlSlot && - feed_node->type_string() == "Placeholder")) { + if (edge->src_output() == output_index) { feed_edges.push_back(edge); } } for (const Edge* edge : feed_edges) { - if (edge->src_output() == id.output_index()) { - graph->AddEdge(arg_node, 0, edge->dst(), edge->dst_input()); - } else { - CHECK_EQ(edge->src_output(), Graph::kControlSlot); - graph->AddControlEdge(arg_node, edge->dst()); - } + graph->AddEdge(arg_node, 0, edge->dst(), edge->dst_input()); graph->RemoveEdge(edge); } } @@ -179,13 +164,16 @@ Status AddRetvalNodes(Graph* graph, const NodeMap& node_map, // fetch ids respectively), and rewrites the edges so that inputs flow from _Arg // nodes, and outputs flow to _Retval nodes. This allows the symbolic graph // execution to know the input and output args for the generated function. -Status RewriteAndPruneGraph(Graph* graph, const Config& config, - const MainFlags& flags) { +Status RewriteAndPruneGraph( + Graph* graph, const Config& config, + const std::unordered_map<string, string>& feed_remapping, + const MainFlags& flags) { NodeMap node_map; for (Node* n : graph->nodes()) { node_map[n->name()] = n; } - TF_RETURN_IF_ERROR(AddArgNodes(graph, node_map, config.feed())); + TF_RETURN_IF_ERROR( + AddArgNodes(graph, node_map, config.feed(), feed_remapping)); std::unordered_set<const Node*> retval_nodes; TF_RETURN_IF_ERROR( AddRetvalNodes(graph, node_map, config.fetch(), &retval_nodes)); @@ -383,17 +371,28 @@ Status InitGraph(const GraphDef& graph_def, const Config& config, FunctionLibraryDefinition flib_def(OpRegistry::Global(), graph_def.library()); std::unique_ptr<Graph> g(new Graph(flib_def)); - GraphDef copy_def; + // Replace references to fed tensors with references to newly added + // placeholders. 
+ GraphDef first_copy_def = graph_def; + + // Maps from name:port of a feed to the name:port of the placeholder to use. + std::unordered_map<string, string> feed_remapping; + TF_RETURN_IF_ERROR(AddPlaceholdersForFeeds(config, g->op_registry(), + &feed_remapping, &first_copy_def)); // Prune the GraphDef first so that unknown ops that we aren't compiling get // filtered out. - TF_RETURN_IF_ERROR(PruneGraphDefInto(config, graph_def, &copy_def)); + GraphDef second_copy_def; + TF_RETURN_IF_ERROR( + PruneGraphDefInto(config, first_copy_def, &second_copy_def)); + + TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef( + &second_copy_def, *g->op_registry(), 0 /*node_offset*/)); - TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(&copy_def, *g->op_registry(), - 0 /*node_offset*/)); + TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(GraphConstructorOptions(), + second_copy_def, g.get())); TF_RETURN_IF_ERROR( - ConvertGraphDefToGraph(GraphConstructorOptions(), copy_def, g.get())); - TF_RETURN_IF_ERROR(RewriteAndPruneGraph(g.get(), config, flags)); + RewriteAndPruneGraph(g.get(), config, feed_remapping, flags)); *graph = std::move(g); return Status::OK(); } diff --git a/tensorflow/compiler/aot/test_graph_tfadd.config.pbtxt b/tensorflow/compiler/aot/test_graph_tfadd.config.pbtxt index 5625c0ab03..f2d9c34b2d 100644 --- a/tensorflow/compiler/aot/test_graph_tfadd.config.pbtxt +++ b/tensorflow/compiler/aot/test_graph_tfadd.config.pbtxt @@ -6,7 +6,7 @@ feed { } } feed { - id { node_name: "y_const" } + id { node_name: "y_reshape" } shape { dim { size: 1 } } diff --git a/tensorflow/compiler/aot/test_graph_tfadd.pbtxt b/tensorflow/compiler/aot/test_graph_tfadd.pbtxt index 91c900e06d..665c9fe287 100644 --- a/tensorflow/compiler/aot/test_graph_tfadd.pbtxt +++ b/tensorflow/compiler/aot/test_graph_tfadd.pbtxt @@ -4,15 +4,7 @@ node { attr { key: "value" value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } + tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } int_val: 1 
} } } attr { @@ -28,15 +20,7 @@ node { attr { key: "value" value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 2 - } + tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } int_val: 2 } } } attr { @@ -47,10 +31,19 @@ node { } } node { + name : "y_reshape" + op : "Reshape" + input : "y_const" + input : "y_shape" + attr { key: "T" value { type: DT_INT32 } } + # Attribute TShape not specified; needs to be set to its default + # by tfcompile. +} +node { name : "x_y_sum" op : "Add" input : "x_const" - input : "y_const" + input : "y_reshape" attr { key : "T" value { diff --git a/tensorflow/compiler/aot/test_graph_tfunknownop.pbtxt b/tensorflow/compiler/aot/test_graph_tfunknownop.pbtxt index 212ffbb5ff..48b881bb94 100644 --- a/tensorflow/compiler/aot/test_graph_tfunknownop.pbtxt +++ b/tensorflow/compiler/aot/test_graph_tfunknownop.pbtxt @@ -6,21 +6,12 @@ node { value { tensor { dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } + tensor_shape { dim { size: 1 } } int_val: 1 } } } - attr { - key : "dtype" - value { - type : DT_INT32 - } - } + attr { key : "dtype" value { type: DT_INT32 } } } node { name : "y_const" @@ -30,56 +21,37 @@ node { value { tensor { dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } + tensor_shape { dim { size: 1 } } int_val: 2 } } } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } + attr { key: "dtype" value { type: DT_INT32 } } } node { name : "x_y_sum" op : "Add" input : "x_const" input : "y_const" - attr { - key : "T" - value { - type: DT_INT32 - } - } + attr { key : "T" value { type: DT_INT32 } } } node { name : "z" op : "SomeUnknownOp" input : "x_const" - attr { - key : "T" - value { - type: DT_INT32 - } - } +} +node { + name : "z_identity" + op : "Identity" + input : "z:1" + attr { key : "T" value { type: DT_INT32 } } } node { name : "x_z_sum" op : "Add" input : "x_const" - input : "z" - attr { - key : "T" - value { - type: DT_INT32 - } - } + input : "z_identity" + attr { 
key : "T" value { type: DT_INT32 } } } versions { producer: 15 diff --git a/tensorflow/compiler/aot/test_graph_tfunknownop2.config.pbtxt b/tensorflow/compiler/aot/test_graph_tfunknownop2.config.pbtxt new file mode 100644 index 0000000000..7370ed370d --- /dev/null +++ b/tensorflow/compiler/aot/test_graph_tfunknownop2.config.pbtxt @@ -0,0 +1,25 @@ +# Text form of tensorflow.tfcompile.Config proto. +feed { + id { node_name: "x_const" } + shape { + dim { size: 1 } + } +} +feed { + id { node_name: "y_const" } + shape { + dim { size: 1 } + } +} +feed { + id { node_name: "z_identity"} + shape { + dim { size: 1 } + } +} +fetch { + id { node_name: "x_y_sum" } +} +fetch { + id { node_name: "x_z_sum" } +} diff --git a/tensorflow/compiler/aot/test_graph_tfunknownop3.config.pbtxt b/tensorflow/compiler/aot/test_graph_tfunknownop3.config.pbtxt new file mode 100644 index 0000000000..b2d7d54574 --- /dev/null +++ b/tensorflow/compiler/aot/test_graph_tfunknownop3.config.pbtxt @@ -0,0 +1,26 @@ +# Text form of tensorflow.tfcompile.Config proto. +feed { + id { node_name: "x_const" } + shape { + dim { size: 1 } + } +} +feed { + id { node_name: "y_const" } + shape { + dim { size: 1 } + } +} +feed { + id { node_name: "z" output_index: 1} + shape { + dim { size: 1 } + } + type: DT_INT32 +} +fetch { + id { node_name: "x_y_sum" } +} +fetch { + id { node_name: "x_z_sum" } +} diff --git a/tensorflow/compiler/aot/tfcompile.proto b/tensorflow/compiler/aot/tfcompile.proto index be3f504350..cd83840d89 100644 --- a/tensorflow/compiler/aot/tfcompile.proto +++ b/tensorflow/compiler/aot/tfcompile.proto @@ -7,6 +7,7 @@ option java_multiple_files = true; option java_package = "org.tensorflow.tfcompile"; import "tensorflow/core/framework/tensor_shape.proto"; +import "tensorflow/core/framework/types.proto"; // TensorId identifies a tensor in a TensorFlow graph, by specifying the output // index of a particular node in the graph. 
If the output of the named node @@ -23,6 +24,12 @@ message Feed { TensorId id = 1; TensorShapeProto shape = 2; string name = 3; // Optional name for generated code. + + // Optional data type. This is not normally required, as the graph itself + // contains this information. However, if the node being fed is an op that + // is not linked into the tfcompile binary, then the type cannot be inferred + // from the node; in this case, the type should be set here. + DataType type = 4; }; // Fetch represents a single fetch tensor in the graph, which corresponds to an diff --git a/tensorflow/compiler/aot/tfcompile_util.cc b/tensorflow/compiler/aot/tfcompile_util.cc index 8774a02128..e6a4705b6c 100644 --- a/tensorflow/compiler/aot/tfcompile_util.cc +++ b/tensorflow/compiler/aot/tfcompile_util.cc @@ -20,12 +20,14 @@ limitations under the License. #include <unordered_map> #include "tensorflow/compiler/aot/tfcompile.pb.h" +#include "tensorflow/core/framework/graph_def_util.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/graph/tensor_id.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/strings/strcat.h" namespace tensorflow { namespace tfcompile { @@ -119,17 +121,115 @@ Status ValidateConfig(const Config& config) { return Status::OK(); } +Status AddPlaceholdersForFeeds( + const Config& config, const OpRegistryInterface* op_registry, + std::unordered_map<string, string>* feed_remapping, GraphDef* graph_def) { + struct PlaceholderInfo { + const Feed* feed = nullptr; // point to Feed in <config>. + string placeholder_name; + DataType data_type = DT_INVALID; + }; + + // Put each fed tensor into a map by name:port. A map is used for determinism + // when creating placeholders (genrules want deterministic output). 
+ std::map<string, PlaceholderInfo> placeholder_info; + for (int i = 0; i < config.feed_size(); ++i) { + const Feed* feed = &config.feed(i); + const string name_port = TensorIdToString(feed->id()); + auto& info = placeholder_info[name_port]; + info.feed = feed; + info.placeholder_name = strings::StrCat( + "aot_feed_", feed->id().output_index(), "/", feed->id().node_name()); + (*feed_remapping)[name_port] = info.placeholder_name; + } + + // Verify node exists and determine data type. + std::unordered_map<string, const NodeDef*> name_to_node; + for (int i = 0; i < graph_def->node_size(); ++i) { + name_to_node[graph_def->node(i).name()] = &graph_def->node(i); + } + for (auto it = placeholder_info.begin(); it != placeholder_info.end(); ++it) { + PlaceholderInfo& info = it->second; + const TensorId& feed_id = info.feed->id(); + + // Find the existing node and determine data type. + auto node_it = name_to_node.find(feed_id.node_name()); + if (node_it == name_to_node.end()) { + return errors::NotFound("Can't find feed node: ", + TensorIdToString(feed_id)); + } + const NodeDef* existing = node_it->second; + + if (info.feed->type() != DT_INVALID) { + info.data_type = info.feed->type(); + } else { + // Build the node in order to infer its type. + + // Must first add default attrs as well, so do this in a copied GraphDef. + GraphDef gd; + *gd.mutable_versions() = graph_def->versions(); + *gd.add_node() = *existing; + TF_RETURN_IF_ERROR( + AddDefaultAttrsToGraphDef(&gd, *op_registry, 0 /*node_offset*/)); + + // Now build the node from the copied node def. + Graph g(op_registry); + g.set_versions(graph_def->versions()); + Status status; + Node* feed_node = g.AddNode(gd.node(0), &status); + TF_RETURN_IF_ERROR(status); + info.data_type = + BaseType(feed_node->output_type(info.feed->id().output_index())); + } + } + + // Create placeholders. 
Note that we could avoid creating a placeholder for + // feeds which are already placeholders, but we omit that to avoid more cases + // in this code. + for (auto it = placeholder_info.begin(); it != placeholder_info.end(); ++it) { + const PlaceholderInfo& info = it->second; + NodeDef* d = graph_def->add_node(); + d->set_name(info.placeholder_name); + d->set_op("PlaceholderV2"); + auto& attr_map = *d->mutable_attr(); + attr_map["dtype"].set_type(info.data_type); + *attr_map["shape"].mutable_shape() = info.feed->shape(); + } + + // Rewrite references to the fed tensors to refer to the placeholder. + for (int i = 0; i < graph_def->node_size(); ++i) { + NodeDef* node_def = graph_def->mutable_node(i); + for (int j = 0; j < node_def->input_size(); ++j) { + auto id = ParseTensorName(node_def->input(j)); + auto it = placeholder_info.find(id.ToString()); + if (it != placeholder_info.end()) { + node_def->set_input(j, it->second.placeholder_name); + } + } + } + + return Status::OK(); +} + Status PruneGraphDefInto(const Config& config, const GraphDef& in, GraphDef* out) { *out = in; out->clear_node(); + // Tensors needed for feeding. + std::set<std::pair<string, int>> feed_tensors; + for (const auto& feed_config : config.feed()) { + feed_tensors.insert(std::make_pair(feed_config.id().node_name(), + feed_config.id().output_index())); + } + // Maps node name to reachability. std::unordered_map<string, std::pair<bool, const NodeDef*>> node_by_name; for (const NodeDef& node : in.node()) { node_by_name[node.name()] = std::pair<bool, const NodeDef*>(false, &node); } + // Traverse. std::queue<string> name_queue; for (int i = 0; i < config.fetch_size(); ++i) { name_queue.push(config.fetch(i).id().node_name()); @@ -149,8 +249,19 @@ Status PruneGraphDefInto(const Config& config, const GraphDef& in, } map_entry.first = true; + // Push input nodes of the currently visited node to name_queue. 
for (const string& in_edge : map_entry.second->input()) { - name_queue.push(ParseTensorName(in_edge).first.ToString()); + auto id = ParseTensorName(in_edge); + const string node_name = id.first.ToString(); + if (feed_tensors.find(std::make_pair(node_name, id.second)) == + feed_tensors.end()) { + name_queue.push(node_name); + } else { + // The input tensor is from an edge that is being fed. Therefore, + // we skip recursing down that edge, to avoid requiring nodes that + // may not be needed (note that the input node may still be added + // to name_queue later if one of its output edges is not being fed). + } } } @@ -165,5 +276,9 @@ Status PruneGraphDefInto(const Config& config, const GraphDef& in, return Status::OK(); } +string TensorIdToString(const TensorId& id) { + return strings::StrCat(id.node_name(), ":", id.output_index()); +} + } // namespace tfcompile } // namespace tensorflow diff --git a/tensorflow/compiler/aot/tfcompile_util.h b/tensorflow/compiler/aot/tfcompile_util.h index 84060c0761..365f7b0e7b 100644 --- a/tensorflow/compiler/aot/tfcompile_util.h +++ b/tensorflow/compiler/aot/tfcompile_util.h @@ -16,8 +16,11 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_AOT_TFCOMPILE_UTIL_H_ #define TENSORFLOW_COMPILER_AOT_TFCOMPILE_UTIL_H_ +#include <unordered_map> + #include "tensorflow/compiler/aot/tfcompile.pb.h" #include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/op.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/stringpiece.h" @@ -31,11 +34,23 @@ Status ValidateCppIdent(StringPiece ident, StringPiece msg); // ValidateConfig returns OK iff config is valid. Status ValidateConfig(const Config& config); +// Modifies <graph_def> to include placeholders for each fed tensor, and +// update references to the fed tensors to refer to the placeholders. 
+// The existing nodes referenced by the feeds are not removed or modified +// (except where their input edges are modified by the replacement of other +// feeds). +Status AddPlaceholdersForFeeds( + const Config& config, const OpRegistryInterface* op_registry, + std::unordered_map<string, string>* feed_remapping, GraphDef* graph_def); + // Returns in <out> a copy of <in>, pruned to only include fetches from // <config>. Status PruneGraphDefInto(const Config& config, const GraphDef& in, GraphDef* out); +// Returns node:port for the given <id>. +string TensorIdToString(const TensorId& id); + } // namespace tfcompile } // namespace tensorflow |