diff options
Diffstat (limited to 'tensorflow/core/graph/graph_constructor.h')
-rw-r--r-- | tensorflow/core/graph/graph_constructor.h | 60 |
1 files changed, 42 insertions, 18 deletions
diff --git a/tensorflow/core/graph/graph_constructor.h b/tensorflow/core/graph/graph_constructor.h index a8f9f2b245..6cd9347d96 100644 --- a/tensorflow/core/graph/graph_constructor.h +++ b/tensorflow/core/graph/graph_constructor.h @@ -72,8 +72,6 @@ struct ImportGraphDefOptions { // used to create the existing nodes referenced in `input_map`. // TODO(skyewm): can we remove this requirement? How do we access the original // shape refiner? - // - // TODO(skyewm): add functionality to retrieve unused `input_map` keys std::map<TensorId, TensorId> input_map; // If true, nodes that will have all output edges removed because of @@ -88,10 +86,10 @@ struct ImportGraphDefOptions { // other nodes in `gdef`. std::vector<string> control_dependencies; - // Tensors in `gdef` that will be returned via the `return_tensors` output - // parameter of `ImportGraphDef()`. If this list is non-empty, the caller must - // pass an empty vector to `ImportGraphDef()`. The vector will be populated - // with the imported nodes in `g`. + // Tensors in `gdef` that will be returned via the ImportGraphDefResults + // output parameter of `ImportGraphDef()`. If this list is non-empty, the + // caller must pass a results object to `ImportGraphDef()`. The + // `return_tensors` field will be populated with the imported nodes in `g`. // // Entries should not include `prefix`, i.e., each TensorId's name should be // the name as it originally appears in `gdef`. @@ -100,12 +98,43 @@ struct ImportGraphDefOptions { // corresponding existing tensor in `g` will be returned. std::vector<TensorId> return_tensors; + // The names of nodes in `gdef` that will be returned via the + // ImportGraphDefResults output parameter of `ImportGraphDef()`. If this list + // is non-empty, the caller must pass a results object to + // `ImportGraphDef()`. The `return_nodes` field will be populated with the + // imported nodes in `g`. + // + // Entries should not include `prefix`, i.e., each node's name should be the + // name as it originally appears in `gdef`. + // + // Unlike `return_tensors`, `input_map` has no effect on the nodes + // returned. `return_nodes` must be empty if `skip_mapped_nodes` is true. + // TODO(skyewm): make this work with `skip_mapped_nodes` if there's a need. + std::vector<StringPiece> return_nodes; + // TODO(ashankar): Enable handling of GraphDefs produced by newer binaries // with ops that are not defined in the binary calling ImportGraphDef. // Similar to the producer_op_list argument to import_graph_def in the // python API. }; +// Optional results that may be returned by ImportGraphDef. +struct ImportGraphDefResults { + // The requested tensors associated with + // ImportGraphDefOptions::return_tensors. Note that the index may be different + // than the requested index if the returned tensor has been remapped according + // to `input_map`. + typedef int Index; + std::vector<std::pair<Node*, Index>> return_tensors; + + // The requested nodes associated with ImportGraphDefOptions::return_nodes. + std::vector<Node*> return_nodes; + + // Keys in ImportGraphDefOptions::input_map that weren't used as an input to + // any node in`gdef`. + std::vector<TensorId> unused_input_map_keys; +}; + // Adds the graph in GraphDef `gdef` into an existing Graph `*g`. // // On error, returns non-OK and leaves `*g` unmodified. @@ -115,21 +144,16 @@ struct ImportGraphDefOptions { // allows the caller to validate shapes of those nodes (since // ShapeRefiner::AddNode must be called in topological order). // -// Each `return_tensors` entry is the requested node and output index. The index -// is included in case the returned tensor has been remapped according to -// `input_map`. -// -// If `unused_input_map_keys` is non-null, it should be empty and will be -// populated with any keys in `opts.input_map` that aren't used as an input to -// any node in `gdef`. +// `results` must be non-null if `opts.return_tensors` or `opts.result_nodes` is +// non-empty. It can also be set to fetch the unused input map keys. If it's +// non-null, all the vector fields must be empty. // // TODO(ashankar): Push this mechanism and get rid of Session::Extend() // as a means of enhancing an existing Graph. -extern Status ImportGraphDef( - const ImportGraphDefOptions& opts, const GraphDef& gdef, Graph* g, - ShapeRefiner* refiner, - std::vector<std::pair<Node*, int>>* return_tensors = nullptr, - std::vector<TensorId>* unused_input_map_keys = nullptr); +extern Status ImportGraphDef(const ImportGraphDefOptions& opts, + const GraphDef& gdef, Graph* g, + ShapeRefiner* refiner, + ImportGraphDefResults* results = nullptr); // Make a copy of "src" into "*dest". // |