aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-07-06 11:59:53 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-07-06 12:04:17 -0700
commit27b341c800eab9c20d34394e497498d90f2b6dc9 (patch)
treef5c23063c0a88c411a81bb3909d37a1b51e0eb12
parentcab048ecde3c3567b188a93f62d51af3b6b5b078 (diff)
Automated g4 rollback of changelist 161087978
PiperOrigin-RevId: 161111023
-rw-r--r--tensorflow/compiler/aot/BUILD25
-rw-r--r--tensorflow/compiler/aot/compile.cc91
-rw-r--r--tensorflow/compiler/aot/test_graph_tfunknownop.pbtxt54
-rw-r--r--tensorflow/compiler/aot/test_graph_tfunknownop2.config.pbtxt25
-rw-r--r--tensorflow/compiler/aot/test_graph_tfunknownop3.config.pbtxt26
-rw-r--r--tensorflow/compiler/aot/tfcompile.proto7
-rw-r--r--tensorflow/compiler/aot/tfcompile_util.cc106
-rw-r--r--tensorflow/compiler/aot/tfcompile_util.h15
8 files changed, 89 insertions, 260 deletions
diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD
index b12b5318ec..12179b7637 100644
--- a/tensorflow/compiler/aot/BUILD
+++ b/tensorflow/compiler/aot/BUILD
@@ -156,7 +156,7 @@ tf_library(
)
# A test of tf_library that includes a graph with an unknown op, but where
-# the compilation works because the unknown op is not needed for the fetches.
+# the compilation works because the the unknown op is not needed for the fetches.
tf_library(
name = "test_graph_tfunknownop",
testonly = 1,
@@ -166,29 +166,6 @@ tf_library(
tags = ["manual"],
)
-# A test of tf_library that includes a graph with an unknown op, but where
-# the compilation works because the op between the unknown op and the
-# fetches is a feed.
-tf_library(
- name = "test_graph_tfunknownop2",
- testonly = 1,
- config = "test_graph_tfunknownop2.config.pbtxt",
- cpp_class = "UnknownOpAddComp",
- graph = "test_graph_tfunknownop.pbtxt",
- tags = ["manual"],
-)
-
-# A test of tf_library that includes a graph with an unknown op, but where
-# the compilation works because the unknown op is fed.
-tf_library(
- name = "test_graph_tfunknownop3",
- testonly = 1,
- config = "test_graph_tfunknownop3.config.pbtxt",
- cpp_class = "UnknownOpAddComp",
- graph = "test_graph_tfunknownop.pbtxt",
- tags = ["manual"],
-)
-
# Utility library for benchmark binaries, used by the *_benchmark rules that are
# added by the tfcompile bazel macro.
cc_library(
diff --git a/tensorflow/compiler/aot/compile.cc b/tensorflow/compiler/aot/compile.cc
index 51d08eaa01..59ff14600b 100644
--- a/tensorflow/compiler/aot/compile.cc
+++ b/tensorflow/compiler/aot/compile.cc
@@ -78,51 +78,66 @@ Status DumpGraph(const MainFlags& flags, const string& name,
return WriteTextProto(Env::Default(), file, graph_def);
}
+string TensorIdToString(const TensorId& id) {
+ return strings::StrCat(id.node_name(), ":", id.output_index());
+}
+
typedef std::unordered_map<string, Node*> NodeMap;
// Each feed id identifies the positional output of some node, which may consist
-// of multiple edges. AddPlaceholdersForFeeds has already replaced each fed
-// tensor with a placeholder. For each feed tensor, replaces all edges so they
-// point from a new _Arg node instead.
+// of multiple edges. For each feed node, replaces all matching edges so that
+// they point from a new _Arg node instead.
Status AddArgNodes(Graph* graph, const NodeMap& node_map,
- const protobuf::RepeatedPtrField<Feed>& feeds,
- const std::unordered_map<string, string>& feed_remapping) {
+ const protobuf::RepeatedPtrField<Feed>& feeds) {
for (int arg_index = 0; arg_index < feeds.size(); ++arg_index) {
const Feed& feed = feeds[arg_index];
- // All feeds have been replaced by placeholders.
- const int output_index = 0;
-
- const auto remap_it = feed_remapping.find(TensorIdToString(feed.id()));
- auto node_it = node_map.find(remap_it->second);
- const Node* feed_node = node_it->second;
-
- // TODO(toddw): Invoke shape inference in AddPlaceholdersForFeeds and add a
- // "_shape" attr if we can determine it. That way the graph will be
- // initialized with whatever shapes we can infer, while the user can still
- // explicitly specify or override them.
+ const TensorId& id = feed.id();
+ auto it = node_map.find(id.node_name());
+ if (it == node_map.end()) {
+ return errors::NotFound("Can't find feed id: ", TensorIdToString(id));
+ }
+ const Node* feed_node = it->second;
+ if (id.output_index() >= feed_node->num_outputs()) {
+ return errors::InvalidArgument("Invalid feed id: ", TensorIdToString(id),
+ ", output index should be < ",
+ feed_node->num_outputs());
+ }
+ // TODO(toddw): Invoke shape inference on the graph and add a "_shape" attr
+ // if we can determine it. That way the graph will be initialized with
+ // whatever shapes we can infer, while the user can still explicitly specify
+ // or override them.
Node* arg_node = nullptr;
TF_RETURN_IF_ERROR(
NodeBuilder(strings::StrCat("_arg_", arg_index), kArgOp)
- .Attr("T", BaseType(feed_node->output_type(output_index)))
+ .Attr("T", BaseType(feed_node->output_type(id.output_index())))
.Attr("index", arg_index)
- .Attr(kFeedIdAttr, TensorIdToString(feed.id()))
+ .Attr(kFeedIdAttr, TensorIdToString(id))
.Attr(kShapeAttr, TensorShape(feed.shape()))
.Attr(kDebugNameAttr, feed.name())
.Finalize(graph, &arg_node));
-
// Collects out-edges from the feed node that have a matching edge index;
- // these will be replaced with edges from the arg node instead.
+ // these will be replaced with edges from the arg node instead. Also
+ // replaces all control edges from Placeholder feed nodes; similar code
+ // exists in subgraph::RewriteGraphForExecution.
+ // TODO(toddw): Why only replace control edges from Placeholder?
//
// We must collect the edges first and process them in a second pass, since
// removing the edge from the graph invalidates feed_node->out_edges.
std::vector<const Edge*> feed_edges;
for (const Edge* edge : feed_node->out_edges()) {
- if (edge->src_output() == output_index) {
+ if (edge->src_output() == id.output_index() ||
+ (edge->src_output() == Graph::kControlSlot &&
+ feed_node->type_string() == "Placeholder")) {
feed_edges.push_back(edge);
}
}
for (const Edge* edge : feed_edges) {
- graph->AddEdge(arg_node, 0, edge->dst(), edge->dst_input());
+ if (edge->src_output() == id.output_index()) {
+ graph->AddEdge(arg_node, 0, edge->dst(), edge->dst_input());
+ } else {
+ CHECK_EQ(edge->src_output(), Graph::kControlSlot);
+ graph->AddControlEdge(arg_node, edge->dst());
+ }
graph->RemoveEdge(edge);
}
}
@@ -164,16 +179,13 @@ Status AddRetvalNodes(Graph* graph, const NodeMap& node_map,
// fetch ids respectively), and rewrites the edges so that inputs flow from _Arg
// nodes, and outputs flow to _Retval nodes. This allows the symbolic graph
// execution to know the input and output args for the generated function.
-Status RewriteAndPruneGraph(
- Graph* graph, const Config& config,
- const std::unordered_map<string, string>& feed_remapping,
- const MainFlags& flags) {
+Status RewriteAndPruneGraph(Graph* graph, const Config& config,
+ const MainFlags& flags) {
NodeMap node_map;
for (Node* n : graph->nodes()) {
node_map[n->name()] = n;
}
- TF_RETURN_IF_ERROR(
- AddArgNodes(graph, node_map, config.feed(), feed_remapping));
+ TF_RETURN_IF_ERROR(AddArgNodes(graph, node_map, config.feed()));
std::unordered_set<const Node*> retval_nodes;
TF_RETURN_IF_ERROR(
AddRetvalNodes(graph, node_map, config.fetch(), &retval_nodes));
@@ -371,28 +383,17 @@ Status InitGraph(const GraphDef& graph_def, const Config& config,
FunctionLibraryDefinition flib_def(OpRegistry::Global(), graph_def.library());
std::unique_ptr<Graph> g(new Graph(flib_def));
- // Replace references to fed tensors with references to newly added
- // placeholders.
- GraphDef first_copy_def = graph_def;
-
- // Maps from name:port of a feed to the name:port of the placeholder to use.
- std::unordered_map<string, string> feed_remapping;
- TF_RETURN_IF_ERROR(AddPlaceholdersForFeeds(config, g->op_registry(),
- &feed_remapping, &first_copy_def));
+ GraphDef copy_def;
// Prune the GraphDef first so that unknown ops that we aren't compiling get
// filtered out.
- GraphDef second_copy_def;
- TF_RETURN_IF_ERROR(
- PruneGraphDefInto(config, first_copy_def, &second_copy_def));
-
- TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(
- &second_copy_def, *g->op_registry(), 0 /*node_offset*/));
+ TF_RETURN_IF_ERROR(PruneGraphDefInto(config, graph_def, &copy_def));
- TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(GraphConstructorOptions(),
- second_copy_def, g.get()));
+ TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(&copy_def, *g->op_registry(),
+ 0 /*node_offset*/));
TF_RETURN_IF_ERROR(
- RewriteAndPruneGraph(g.get(), config, feed_remapping, flags));
+ ConvertGraphDefToGraph(GraphConstructorOptions(), copy_def, g.get()));
+ TF_RETURN_IF_ERROR(RewriteAndPruneGraph(g.get(), config, flags));
*graph = std::move(g);
return Status::OK();
}
diff --git a/tensorflow/compiler/aot/test_graph_tfunknownop.pbtxt b/tensorflow/compiler/aot/test_graph_tfunknownop.pbtxt
index 48b881bb94..212ffbb5ff 100644
--- a/tensorflow/compiler/aot/test_graph_tfunknownop.pbtxt
+++ b/tensorflow/compiler/aot/test_graph_tfunknownop.pbtxt
@@ -6,12 +6,21 @@ node {
value {
tensor {
dtype: DT_INT32
- tensor_shape { dim { size: 1 } }
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ }
int_val: 1
}
}
}
- attr { key : "dtype" value { type: DT_INT32 } }
+ attr {
+ key : "dtype"
+ value {
+ type : DT_INT32
+ }
+ }
}
node {
name : "y_const"
@@ -21,37 +30,56 @@ node {
value {
tensor {
dtype: DT_INT32
- tensor_shape { dim { size: 1 } }
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ }
int_val: 2
}
}
}
- attr { key: "dtype" value { type: DT_INT32 } }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
}
node {
name : "x_y_sum"
op : "Add"
input : "x_const"
input : "y_const"
- attr { key : "T" value { type: DT_INT32 } }
+ attr {
+ key : "T"
+ value {
+ type: DT_INT32
+ }
+ }
}
node {
name : "z"
op : "SomeUnknownOp"
input : "x_const"
-}
-node {
- name : "z_identity"
- op : "Identity"
- input : "z:1"
- attr { key : "T" value { type: DT_INT32 } }
+ attr {
+ key : "T"
+ value {
+ type: DT_INT32
+ }
+ }
}
node {
name : "x_z_sum"
op : "Add"
input : "x_const"
- input : "z_identity"
- attr { key : "T" value { type: DT_INT32 } }
+ input : "z"
+ attr {
+ key : "T"
+ value {
+ type: DT_INT32
+ }
+ }
}
versions {
producer: 15
diff --git a/tensorflow/compiler/aot/test_graph_tfunknownop2.config.pbtxt b/tensorflow/compiler/aot/test_graph_tfunknownop2.config.pbtxt
deleted file mode 100644
index 7370ed370d..0000000000
--- a/tensorflow/compiler/aot/test_graph_tfunknownop2.config.pbtxt
+++ /dev/null
@@ -1,25 +0,0 @@
-# Text form of tensorflow.tfcompile.Config proto.
-feed {
- id { node_name: "x_const" }
- shape {
- dim { size: 1 }
- }
-}
-feed {
- id { node_name: "y_const" }
- shape {
- dim { size: 1 }
- }
-}
-feed {
- id { node_name: "z_identity"}
- shape {
- dim { size: 1 }
- }
-}
-fetch {
- id { node_name: "x_y_sum" }
-}
-fetch {
- id { node_name: "x_z_sum" }
-}
diff --git a/tensorflow/compiler/aot/test_graph_tfunknownop3.config.pbtxt b/tensorflow/compiler/aot/test_graph_tfunknownop3.config.pbtxt
deleted file mode 100644
index b2d7d54574..0000000000
--- a/tensorflow/compiler/aot/test_graph_tfunknownop3.config.pbtxt
+++ /dev/null
@@ -1,26 +0,0 @@
-# Text form of tensorflow.tfcompile.Config proto.
-feed {
- id { node_name: "x_const" }
- shape {
- dim { size: 1 }
- }
-}
-feed {
- id { node_name: "y_const" }
- shape {
- dim { size: 1 }
- }
-}
-feed {
- id { node_name: "z" output_index: 1}
- shape {
- dim { size: 1 }
- }
- type: DT_INT32
-}
-fetch {
- id { node_name: "x_y_sum" }
-}
-fetch {
- id { node_name: "x_z_sum" }
-}
diff --git a/tensorflow/compiler/aot/tfcompile.proto b/tensorflow/compiler/aot/tfcompile.proto
index cd83840d89..be3f504350 100644
--- a/tensorflow/compiler/aot/tfcompile.proto
+++ b/tensorflow/compiler/aot/tfcompile.proto
@@ -7,7 +7,6 @@ option java_multiple_files = true;
option java_package = "org.tensorflow.tfcompile";
import "tensorflow/core/framework/tensor_shape.proto";
-import "tensorflow/core/framework/types.proto";
// TensorId identifies a tensor in a TensorFlow graph, by specifying the output
// index of a particular node in the graph. If the output of the named node
@@ -24,12 +23,6 @@ message Feed {
TensorId id = 1;
TensorShapeProto shape = 2;
string name = 3; // Optional name for generated code.
-
- // Optional data type. This is not normally required, as the graph itself
- // contains this information. However, if the node being fed is an op that
- // is not linked into the tfcompile binary, then the type cannot be inferred
- // from the node; in this case, the type should be set here.
- DataType type = 4;
};
// Fetch represents a single fetch tensor in the graph, which corresponds to an
diff --git a/tensorflow/compiler/aot/tfcompile_util.cc b/tensorflow/compiler/aot/tfcompile_util.cc
index 4476baae63..8774a02128 100644
--- a/tensorflow/compiler/aot/tfcompile_util.cc
+++ b/tensorflow/compiler/aot/tfcompile_util.cc
@@ -26,7 +26,6 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
-#include "tensorflow/core/lib/strings/strcat.h"
namespace tensorflow {
namespace tfcompile {
@@ -120,105 +119,17 @@ Status ValidateConfig(const Config& config) {
return Status::OK();
}
-Status AddPlaceholdersForFeeds(
- const Config& config, const OpRegistryInterface* op_registry,
- std::unordered_map<string, string>* feed_remapping, GraphDef* graph_def) {
- struct PlaceholderInfo {
- const Feed* feed = nullptr; // point to Feed in <config>.
- string placeholder_name;
- DataType data_type = DT_INVALID;
- };
-
- // Put each fed tensor into a map by name:port. A map is used for determinism
- // when creating placeholders (genrules want deterministic output).
- std::map<string, PlaceholderInfo> placeholder_info;
- for (int i = 0; i < config.feed_size(); ++i) {
- const Feed* feed = &config.feed(i);
- const string name_port = TensorIdToString(feed->id());
- auto& info = placeholder_info[name_port];
- info.feed = feed;
- info.placeholder_name = strings::StrCat(
- "aot_feed_", feed->id().output_index(), "/", feed->id().node_name());
- (*feed_remapping)[name_port] = info.placeholder_name;
- }
-
- // Verify node exists and determine data type.
- std::unordered_map<string, const NodeDef*> name_to_node;
- for (int i = 0; i < graph_def->node_size(); ++i) {
- name_to_node[graph_def->node(i).name()] = &graph_def->node(i);
- }
- for (auto it = placeholder_info.begin(); it != placeholder_info.end(); ++it) {
- PlaceholderInfo& info = it->second;
- const TensorId& feed_id = info.feed->id();
-
- // Find the existing node and determine data type.
- auto node_it = name_to_node.find(feed_id.node_name());
- if (node_it == name_to_node.end()) {
- return errors::NotFound("Can't find feed node: ",
- TensorIdToString(feed_id));
- }
- const NodeDef* existing = node_it->second;
-
- if (info.feed->type() != DT_INVALID) {
- info.data_type = info.feed->type();
- } else {
- // Build the node in order to infer its type.
- Graph g(op_registry);
- Status status;
- Node* feed_node = g.AddNode(*existing, &status);
- TF_RETURN_IF_ERROR(status);
- info.data_type =
- BaseType(feed_node->output_type(info.feed->id().output_index()));
- }
- }
-
- // Create placeholders. Note that we could avoid creating a placeholder for
- // feeds which are already placeholders, but we omit that to avoid more cases
- // in this code.
- for (auto it = placeholder_info.begin(); it != placeholder_info.end(); ++it) {
- const PlaceholderInfo& info = it->second;
- NodeDef* d = graph_def->add_node();
- d->set_name(info.placeholder_name);
- d->set_op("PlaceholderV2");
- auto& attr_map = *d->mutable_attr();
- attr_map["dtype"].set_type(info.data_type);
- *attr_map["shape"].mutable_shape() = info.feed->shape();
- }
-
- // Rewrite references to the fed tensors to refer to the placeholder.
- for (int i = 0; i < graph_def->node_size(); ++i) {
- NodeDef* node_def = graph_def->mutable_node(i);
- for (int j = 0; j < node_def->input_size(); ++j) {
- auto id = ParseTensorName(node_def->input(j));
- auto it = placeholder_info.find(id.ToString());
- if (it != placeholder_info.end()) {
- node_def->set_input(j, it->second.placeholder_name);
- }
- }
- }
-
- return Status::OK();
-}
-
Status PruneGraphDefInto(const Config& config, const GraphDef& in,
GraphDef* out) {
*out = in;
out->clear_node();
- // Tensors needed for feeding.
- std::set<std::pair<string, int>> feed_tensors;
- for (const auto& feed_config : config.feed()) {
- feed_tensors.insert(std::make_pair(feed_config.id().node_name(),
- feed_config.id().output_index()));
- }
-
// Maps node name to reachability.
std::unordered_map<string, std::pair<bool, const NodeDef*>> node_by_name;
for (const NodeDef& node : in.node()) {
node_by_name[node.name()] = std::pair<bool, const NodeDef*>(false, &node);
}
- // Traverse.
std::queue<string> name_queue;
for (int i = 0; i < config.fetch_size(); ++i) {
name_queue.push(config.fetch(i).id().node_name());
@@ -238,19 +149,8 @@ Status PruneGraphDefInto(const Config& config, const GraphDef& in,
}
map_entry.first = true;
- // Push input nodes of the currently visited node to name_queue.
for (const string& in_edge : map_entry.second->input()) {
- auto id = ParseTensorName(in_edge);
- const string node_name = id.first.ToString();
- if (feed_tensors.find(std::make_pair(node_name, id.second)) ==
- feed_tensors.end()) {
- name_queue.push(node_name);
- } else {
- // The input tensor is from an edge that is being fed. Therefore,
- // we skip recursing down that edge, to avoid requiring nodes that
- // may not be needed (note that the input node may still be added
- // to name_queue later if one of its output edges is not being fed).
- }
+ name_queue.push(ParseTensorName(in_edge).first.ToString());
}
}
@@ -265,9 +165,5 @@ Status PruneGraphDefInto(const Config& config, const GraphDef& in,
return Status::OK();
}
-string TensorIdToString(const TensorId& id) {
- return strings::StrCat(id.node_name(), ":", id.output_index());
-}
-
} // namespace tfcompile
} // namespace tensorflow
diff --git a/tensorflow/compiler/aot/tfcompile_util.h b/tensorflow/compiler/aot/tfcompile_util.h
index 365f7b0e7b..84060c0761 100644
--- a/tensorflow/compiler/aot/tfcompile_util.h
+++ b/tensorflow/compiler/aot/tfcompile_util.h
@@ -16,11 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_AOT_TFCOMPILE_UTIL_H_
#define TENSORFLOW_COMPILER_AOT_TFCOMPILE_UTIL_H_
-#include <unordered_map>
-
#include "tensorflow/compiler/aot/tfcompile.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
-#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
@@ -34,23 +31,11 @@ Status ValidateCppIdent(StringPiece ident, StringPiece msg);
// ValidateConfig returns OK iff config is valid.
Status ValidateConfig(const Config& config);
-// Modifies <graph_def> to include placeholders for each fed tensor, and
-// update references to the fed tensors to refer to the placeholders.
-// The existing nodes referenced by the feeds are not removed or modified
-// (except where their input edges are modified by the replacement of other
-// feeds).
-Status AddPlaceholdersForFeeds(
- const Config& config, const OpRegistryInterface* op_registry,
- std::unordered_map<string, string>* feed_remapping, GraphDef* graph_def);
-
// Returns in <out> a copy of <in>, pruned to only include fetches from
// <config>.
Status PruneGraphDefInto(const Config& config, const GraphDef& in,
GraphDef* out);
-// Returns node:port for the given <id>.
-string TensorIdToString(const TensorId& id);
-
} // namespace tfcompile
} // namespace tensorflow