aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/graph/graph_constructor.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/graph/graph_constructor.h')
-rw-r--r--tensorflow/core/graph/graph_constructor.h60
1 files changed, 42 insertions, 18 deletions
diff --git a/tensorflow/core/graph/graph_constructor.h b/tensorflow/core/graph/graph_constructor.h
index a8f9f2b245..6cd9347d96 100644
--- a/tensorflow/core/graph/graph_constructor.h
+++ b/tensorflow/core/graph/graph_constructor.h
@@ -72,8 +72,6 @@ struct ImportGraphDefOptions {
// used to create the existing nodes referenced in `input_map`.
// TODO(skyewm): can we remove this requirement? How do we access the original
// shape refiner?
- //
- // TODO(skyewm): add functionality to retrieve unused `input_map` keys
std::map<TensorId, TensorId> input_map;
// If true, nodes that will have all output edges removed because of
@@ -88,10 +86,10 @@ struct ImportGraphDefOptions {
// other nodes in `gdef`.
std::vector<string> control_dependencies;
- // Tensors in `gdef` that will be returned via the `return_tensors` output
- // parameter of `ImportGraphDef()`. If this list is non-empty, the caller must
- // pass an empty vector to `ImportGraphDef()`. The vector will be populated
- // with the imported nodes in `g`.
+ // Tensors in `gdef` that will be returned via the ImportGraphDefResults
+ // output parameter of `ImportGraphDef()`. If this list is non-empty, the
+ // caller must pass a results object to `ImportGraphDef()`. The
+ // `return_tensors` field will be populated with the imported nodes in `g`.
//
// Entries should not include `prefix`, i.e., each TensorId's name should be
// the name as it originally appears in `gdef`.
@@ -100,12 +98,43 @@ struct ImportGraphDefOptions {
// corresponding existing tensor in `g` will be returned.
std::vector<TensorId> return_tensors;
+ // The names of nodes in `gdef` that will be returned via the
+ // ImportGraphDefResults output parameter of `ImportGraphDef()`. If this list
+ // is non-empty, the caller must pass a results object to
+ // `ImportGraphDef()`. The `return_nodes` field will be populated with the
+ // imported nodes in `g`.
+ //
+ // Entries should not include `prefix`, i.e., each node's name should be the
+ // name as it originally appears in `gdef`.
+ //
+ // Unlike `return_tensors`, `input_map` has no effect on the nodes
+ // returned. `return_nodes` must be empty if `skip_mapped_nodes` is true.
+ // TODO(skyewm): make this work with `skip_mapped_nodes` if there's a need.
+ std::vector<StringPiece> return_nodes;
+
// TODO(ashankar): Enable handling of GraphDefs produced by newer binaries
// with ops that are not defined in the binary calling ImportGraphDef.
// Similar to the producer_op_list argument to import_graph_def in the
// python API.
};
+// Optional results that may be returned by ImportGraphDef.
+struct ImportGraphDefResults {
+ // The requested tensors associated with
+ // ImportGraphDefOptions::return_tensors. Note that the index may be different
+ // than the requested index if the returned tensor has been remapped according
+ // to `input_map`.
+ typedef int Index;
+ std::vector<std::pair<Node*, Index>> return_tensors;
+
+ // The requested nodes associated with ImportGraphDefOptions::return_nodes.
+ std::vector<Node*> return_nodes;
+
+ // Keys in ImportGraphDefOptions::input_map that weren't used as an input to
+ // any node in`gdef`.
+ std::vector<TensorId> unused_input_map_keys;
+};
+
// Adds the graph in GraphDef `gdef` into an existing Graph `*g`.
//
// On error, returns non-OK and leaves `*g` unmodified.
@@ -115,21 +144,16 @@ struct ImportGraphDefOptions {
// allows the caller to validate shapes of those nodes (since
// ShapeRefiner::AddNode must be called in topological order).
//
-// Each `return_tensors` entry is the requested node and output index. The index
-// is included in case the returned tensor has been remapped according to
-// `input_map`.
-//
-// If `unused_input_map_keys` is non-null, it should be empty and will be
-// populated with any keys in `opts.input_map` that aren't used as an input to
-// any node in `gdef`.
+// `results` must be non-null if `opts.return_tensors` or `opts.result_nodes` is
+// non-empty. It can also be set to fetch the unused input map keys. If it's
+// non-null, all the vector fields must be empty.
//
// TODO(ashankar): Push this mechanism and get rid of Session::Extend()
// as a means of enhancing an existing Graph.
-extern Status ImportGraphDef(
- const ImportGraphDefOptions& opts, const GraphDef& gdef, Graph* g,
- ShapeRefiner* refiner,
- std::vector<std::pair<Node*, int>>* return_tensors = nullptr,
- std::vector<TensorId>* unused_input_map_keys = nullptr);
+extern Status ImportGraphDef(const ImportGraphDefOptions& opts,
+ const GraphDef& gdef, Graph* g,
+ ShapeRefiner* refiner,
+ ImportGraphDefResults* results = nullptr);
// Make a copy of "src" into "*dest".
//