diff options
-rw-r--r-- | tensorflow/c/c_api.h | 10 | ||||
-rw-r--r-- | tensorflow/core/graph/graph_constructor.cc | 10 | ||||
-rw-r--r-- | tensorflow/core/graph/graph_constructor.h | 14 | ||||
-rw-r--r-- | tensorflow/core/graph/graph_constructor_test.cc | 48 | ||||
-rw-r--r-- | tensorflow/core/graph/tensor_id.cc | 5 | ||||
-rw-r--r-- | tensorflow/core/graph/tensor_id.h | 33 |
6 files changed, 104 insertions, 16 deletions
diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index c859434745..1eb75ef11f 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -894,7 +894,8 @@ TF_CAPI_EXPORT extern void TF_DeleteImportGraphDefOptions( TF_ImportGraphDefOptions* opts); // Set the prefix to be prepended to the names of nodes in `graph_def` that will -// be imported into `graph`. +// be imported into `graph`. `prefix` is copied and has no lifetime +// requirements. TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsSetPrefix( TF_ImportGraphDefOptions* opts, const char* prefix); @@ -915,6 +916,7 @@ TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsSetUniquifyPrefix( // Set any imported nodes with input `src_name:src_index` to have that input // replaced with `dst`. `src_name` refers to a node in the graph to be imported, // `dst` references a node already existing in the graph being imported into. +// `src_name` is copied and has no lifetime requirements. TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsAddInputMapping( TF_ImportGraphDefOptions* opts, const char* src_name, int src_index, TF_Output dst); @@ -922,7 +924,7 @@ TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsAddInputMapping( // Set any imported nodes with control input `src_name` to have that input // replaced with `dst`. `src_name` refers to a node in the graph to be imported, // `dst` references an operation already existing in the graph being imported -// into. +// into. `src_name` is copied and has no lifetime requirements. TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsRemapControlDependency( TF_ImportGraphDefOptions* opts, const char* src_name, TF_Operation* dst); @@ -934,6 +936,7 @@ TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsAddControlDependency( // Add an output in `graph_def` to be returned via the `return_outputs` output // parameter of TF_GraphImportGraphDef(). If the output is remapped via an input // mapping, the corresponding existing tensor in `graph` will be returned. +// `oper_name` is copied and has no lifetime requirements. TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsAddReturnOutput( TF_ImportGraphDefOptions* opts, const char* oper_name, int index); @@ -943,7 +946,8 @@ TF_CAPI_EXPORT extern int TF_ImportGraphDefOptionsNumReturnOutputs( const TF_ImportGraphDefOptions* opts); // Add an operation in `graph_def` to be returned via the `return_opers` output -// parameter of TF_GraphImportGraphDef(). +// parameter of TF_GraphImportGraphDef(). `oper_name` is copied and has no +// lifetime requirements. TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsAddReturnOperation( TF_ImportGraphDefOptions* opts, const char* oper_name); diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc index 418a49b5db..add26f3b71 100644 --- a/tensorflow/core/graph/graph_constructor.cc +++ b/tensorflow/core/graph/graph_constructor.cc @@ -79,10 +79,10 @@ class GraphConstructor { : in.prefix + "/"), uniquify_names(in.uniquify_names), uniquify_prefix(in.uniquify_prefix), - input_map(in.input_map), + input_map(in.input_map.begin(), in.input_map.end()), skip_mapped_nodes(in.skip_mapped_nodes), control_dependencies(in.control_dependencies), - return_tensors(in.return_tensors), + return_tensors(in.return_tensors.begin(), in.return_tensors.end()), return_nodes(in.return_nodes), importing(true), validate_colocation_constraints(in.validate_colocation_constraints), @@ -121,7 +121,7 @@ class GraphConstructor { const FunctionDefLibrary* library, Graph* g, ShapeRefiner* refiner, std::vector<std::pair<Node*, int>>* return_tensors, std::vector<Node*>* return_nodes, - std::vector<TensorId>* missing_unused_input_map_keys) { + std::vector<SafeTensorId>* missing_unused_input_map_keys) { if (versions) { TF_RETURN_IF_ERROR(CheckVersions(*versions, TF_GRAPH_DEF_VERSION, TF_GRAPH_DEF_VERSION_MIN_PRODUCER, @@ -142,7 +142,7 @@ class GraphConstructor { ShapeRefiner* refiner, std::vector<std::pair<Node*, int>>* return_tensors, std::vector<Node*>* return_nodes, - std::vector<TensorId>* missing_unused_input_map_keys) + std::vector<SafeTensorId>* missing_unused_input_map_keys) : opts_(opts), node_defs_(node_defs), versions_(versions), @@ -251,7 +251,7 @@ class GraphConstructor { std::vector<Node*>* return_nodes_; // May be null. Not owned. - std::vector<TensorId>* missing_unused_input_map_keys_; + std::vector<SafeTensorId>* missing_unused_input_map_keys_; // Intermediate datastructure used to populate // `missing_unused_input_map_keys_`. diff --git a/tensorflow/core/graph/graph_constructor.h b/tensorflow/core/graph/graph_constructor.h index b03d655fe6..889359a68a 100644 --- a/tensorflow/core/graph/graph_constructor.h +++ b/tensorflow/core/graph/graph_constructor.h @@ -81,14 +81,14 @@ struct ImportGraphDefOptions { // corresponding to `input_map` keys will be remapped to the nodes in `g` // corresponding to the values. // - // Keys should not include `prefix`, i.e., a key TensorId's name should be the - // name as it originally appears in `gdef`. + // Keys should not include `prefix`, i.e., a key ID's name should be the name + // as it originally appears in `gdef`. // // If this is non-empty, ImportGraphDef must be called with the shape refiner // used to create the existing nodes referenced in `input_map`. // TODO(skyewm): can we remove this requirement? How do we access the original // shape refiner? - std::map<TensorId, TensorId> input_map; + std::map<SafeTensorId, SafeTensorId> input_map; // If true, nodes that will have all output edges removed because of // overrides in `input_map` will not be imported. @@ -107,12 +107,12 @@ struct ImportGraphDefOptions { // caller must pass a results object to `ImportGraphDef()`. The // `return_tensors` field will be populated with the imported nodes in `g`. // - // Entries should not include `prefix`, i.e., each TensorId's name should be - // the name as it originally appears in `gdef`. + // Entries should not include `prefix`, i.e., each ID's name should be the + // name as it originally appears in `gdef`. // // If this contains a tensor that's also being remapped via `input_map`, the // corresponding existing tensor in `g` will be returned. - std::vector<TensorId> return_tensors; + std::vector<SafeTensorId> return_tensors; // The names of nodes in `gdef` that will be returned via the // ImportGraphDefResults output parameter of `ImportGraphDef()`. If this list @@ -155,7 +155,7 @@ struct ImportGraphDefResults { // Keys in ImportGraphDefOptions::input_map that don't appear in `gdef` and // weren't used as an input to any node in `gdef`. These keys are likely due // to typos, and callers may wish to treat their existence as an error. - std::vector<TensorId> missing_unused_input_map_keys; + std::vector<SafeTensorId> missing_unused_input_map_keys; }; // Adds the graph in GraphDef `gdef` into an existing Graph `*g`. diff --git a/tensorflow/core/graph/graph_constructor_test.cc b/tensorflow/core/graph/graph_constructor_test.cc index 6309870190..e338840eeb 100644 --- a/tensorflow/core/graph/graph_constructor_test.cc +++ b/tensorflow/core/graph/graph_constructor_test.cc @@ -1502,7 +1502,8 @@ TEST_F(GraphConstructorTest, ImportGraphDef_InputMapMissingUnusedKeys) { opts, &refiner, &results); ASSERT_EQ(results.missing_unused_input_map_keys.size(), 1); - EXPECT_EQ(results.missing_unused_input_map_keys[0], TensorId("new_input", 2)); + EXPECT_EQ(results.missing_unused_input_map_keys[0], + SafeTensorId("new_input", 2)); } TEST_F(GraphConstructorTest, ImportGraphDef_InputMapWithUnboundInput) { @@ -2748,6 +2749,51 @@ TEST_F(GraphConstructorTest, ImportGraphDef_NestedFunctionDefs) { EXPECT_EQ(outputs[0].scalar<float>()(), 3.0); } +// NOTE(skyewm): the C API depends on this behavior, but it's easier to write +// the test here. +TEST_F(GraphConstructorTest, ImportGraphDef_OptionsMemMgmt) { + ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, graph_.op_registry()); + + // Populate graph with node we'll use in input map + ExpectOK("node { name: 'input' op: 'TestInput' }", ImportGraphDefOptions(), + &refiner); + + // Add some strings to ImportGraphDefOptions and then rewrite the buffers. + char buf1[100]; + char buf2[100]; + char buf3[100]; + snprintf(buf1, sizeof(buf1), "input"); + snprintf(buf2, sizeof(buf2), "new_input"); + snprintf(buf3, sizeof(buf3), "t1"); + + ImportGraphDefOptions opts; + opts.input_map[TensorId(buf2, 0)] = TensorId(buf1, 0); + opts.return_tensors.push_back(TensorId(buf3, 0)); + + snprintf(buf1, sizeof(buf1), "xxxxxxxxxxxxxxxxxxxx"); + snprintf(buf2, sizeof(buf2), "xxxxxxxxxxxxxxxxxxxx"); + snprintf(buf3, sizeof(buf3), "xxxxxxxxxxxxxxxxxxxx"); + + // Import some new nodes using opts. + ImportGraphDefResults results; + ExpectOK( + R"EOF( + node { name: 'new_input' op: 'TestInput' } + node { name: 't1' op: 'TestMul' input: [ 'new_input:0', 'new_input:1' ] } + )EOF", + opts, &refiner, &results); + + EXPECT_TRUE(HasNode("input")); + EXPECT_TRUE(HasNode("new_input")); + EXPECT_TRUE(HasNode("t1")); + + EXPECT_TRUE(HasEdge("input", 0, "t1", 0)); + EXPECT_TRUE(HasEdge("new_input", 1, "t1", 1)); + + ASSERT_EQ(results.return_tensors.size(), 1); + EXPECT_EQ(results.return_tensors[0].first->name(), "t1"); +} + TEST_F(GraphConstructorTest, CopyGraph) { const int v = TF_GRAPH_DEF_VERSION; const int bad = v + 17; diff --git a/tensorflow/core/graph/tensor_id.cc b/tensorflow/core/graph/tensor_id.cc index 8af1936d64..80c76df255 100644 --- a/tensorflow/core/graph/tensor_id.cc +++ b/tensorflow/core/graph/tensor_id.cc @@ -22,6 +22,11 @@ limitations under the License. namespace tensorflow { +TensorId::TensorId(const SafeTensorId& id) : TensorId(id.first, id.second) {} + +SafeTensorId::SafeTensorId(const TensorId& id) + : SafeTensorId(id.first.ToString(), id.second) {} + TensorId ParseTensorName(const string& name) { return ParseTensorName(StringPiece(name.data(), name.size())); } diff --git a/tensorflow/core/graph/tensor_id.h b/tensorflow/core/graph/tensor_id.h index c27120f7e6..bf13fc78a6 100644 --- a/tensorflow/core/graph/tensor_id.h +++ b/tensorflow/core/graph/tensor_id.h @@ -25,6 +25,8 @@ limitations under the License. namespace tensorflow { +struct SafeTensorId; + // Identifier for a tensor within a step. // first == operation_name, second == output_index // Note: does not own backing storage for name. @@ -34,6 +36,11 @@ struct TensorId : public std::pair<StringPiece, int> { // Inherit the set of constructors. using Base::pair; + // NOTE(skyewm): this is required on some platforms. I'm not sure why the + // using statement above isn't always sufficient. + TensorId() : Base() {} + TensorId(const SafeTensorId& id); + string ToString() const { if (second == Graph::kControlSlot) return strings::StrCat("^", first); return strings::StrCat(first, ":", second); @@ -50,6 +57,32 @@ struct TensorId : public std::pair<StringPiece, int> { TensorId ParseTensorName(const string& name); TensorId ParseTensorName(StringPiece name); +// Same as TensorId, except owns the backing storage for the op name. This makes +// the memory management simpler at the expense of a copy. +struct SafeTensorId : public std::pair<string, int> { + typedef std::pair<string, int> Base; + + // Inherit the set of constructors. + using Base::pair; + + // NOTE(skyewm): this is required on some platforms. I'm not sure why the + // using statement above isn't always sufficient. + SafeTensorId() : Base() {} + SafeTensorId(const TensorId& id); + + string ToString() const { + if (second == Graph::kControlSlot) return strings::StrCat("^", first); + return strings::StrCat(first, ":", second); + } + + struct Hasher { + public: + std::size_t operator()(const TensorId& x) const { + return Hash32(x.first.data(), x.first.size(), x.second); + } + }; +}; + } // namespace tensorflow #endif // TENSORFLOW_GRAPH_TENSOR_ID_H_ |