about summary refs log tree commit diff homepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/c/c_api.h10
-rw-r--r--tensorflow/core/graph/graph_constructor.cc10
-rw-r--r--tensorflow/core/graph/graph_constructor.h14
-rw-r--r--tensorflow/core/graph/graph_constructor_test.cc48
-rw-r--r--tensorflow/core/graph/tensor_id.cc5
-rw-r--r--tensorflow/core/graph/tensor_id.h33
6 files changed, 104 insertions, 16 deletions
diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h
index c859434745..1eb75ef11f 100644
--- a/tensorflow/c/c_api.h
+++ b/tensorflow/c/c_api.h
@@ -894,7 +894,8 @@ TF_CAPI_EXPORT extern void TF_DeleteImportGraphDefOptions(
TF_ImportGraphDefOptions* opts);
// Set the prefix to be prepended to the names of nodes in `graph_def` that will
-// be imported into `graph`.
+// be imported into `graph`. `prefix` is copied and has no lifetime
+// requirements.
TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsSetPrefix(
TF_ImportGraphDefOptions* opts, const char* prefix);
@@ -915,6 +916,7 @@ TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsSetUniquifyPrefix(
// Set any imported nodes with input `src_name:src_index` to have that input
// replaced with `dst`. `src_name` refers to a node in the graph to be imported,
// `dst` references a node already existing in the graph being imported into.
+// `src_name` is copied and has no lifetime requirements.
TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsAddInputMapping(
TF_ImportGraphDefOptions* opts, const char* src_name, int src_index,
TF_Output dst);
@@ -922,7 +924,7 @@ TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsAddInputMapping(
// Set any imported nodes with control input `src_name` to have that input
// replaced with `dst`. `src_name` refers to a node in the graph to be imported,
// `dst` references an operation already existing in the graph being imported
-// into.
+// into. `src_name` is copied and has no lifetime requirements.
TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsRemapControlDependency(
TF_ImportGraphDefOptions* opts, const char* src_name, TF_Operation* dst);
@@ -934,6 +936,7 @@ TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsAddControlDependency(
// Add an output in `graph_def` to be returned via the `return_outputs` output
// parameter of TF_GraphImportGraphDef(). If the output is remapped via an input
// mapping, the corresponding existing tensor in `graph` will be returned.
+// `oper_name` is copied and has no lifetime requirements.
TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsAddReturnOutput(
TF_ImportGraphDefOptions* opts, const char* oper_name, int index);
@@ -943,7 +946,8 @@ TF_CAPI_EXPORT extern int TF_ImportGraphDefOptionsNumReturnOutputs(
const TF_ImportGraphDefOptions* opts);
// Add an operation in `graph_def` to be returned via the `return_opers` output
-// parameter of TF_GraphImportGraphDef().
+// parameter of TF_GraphImportGraphDef(). `oper_name` is copied and has no
+// lifetime requirements.
TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsAddReturnOperation(
TF_ImportGraphDefOptions* opts, const char* oper_name);
diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc
index 418a49b5db..add26f3b71 100644
--- a/tensorflow/core/graph/graph_constructor.cc
+++ b/tensorflow/core/graph/graph_constructor.cc
@@ -79,10 +79,10 @@ class GraphConstructor {
: in.prefix + "/"),
uniquify_names(in.uniquify_names),
uniquify_prefix(in.uniquify_prefix),
- input_map(in.input_map),
+ input_map(in.input_map.begin(), in.input_map.end()),
skip_mapped_nodes(in.skip_mapped_nodes),
control_dependencies(in.control_dependencies),
- return_tensors(in.return_tensors),
+ return_tensors(in.return_tensors.begin(), in.return_tensors.end()),
return_nodes(in.return_nodes),
importing(true),
validate_colocation_constraints(in.validate_colocation_constraints),
@@ -121,7 +121,7 @@ class GraphConstructor {
const FunctionDefLibrary* library, Graph* g, ShapeRefiner* refiner,
std::vector<std::pair<Node*, int>>* return_tensors,
std::vector<Node*>* return_nodes,
- std::vector<TensorId>* missing_unused_input_map_keys) {
+ std::vector<SafeTensorId>* missing_unused_input_map_keys) {
if (versions) {
TF_RETURN_IF_ERROR(CheckVersions(*versions, TF_GRAPH_DEF_VERSION,
TF_GRAPH_DEF_VERSION_MIN_PRODUCER,
@@ -142,7 +142,7 @@ class GraphConstructor {
ShapeRefiner* refiner,
std::vector<std::pair<Node*, int>>* return_tensors,
std::vector<Node*>* return_nodes,
- std::vector<TensorId>* missing_unused_input_map_keys)
+ std::vector<SafeTensorId>* missing_unused_input_map_keys)
: opts_(opts),
node_defs_(node_defs),
versions_(versions),
@@ -251,7 +251,7 @@ class GraphConstructor {
std::vector<Node*>* return_nodes_;
// May be null. Not owned.
- std::vector<TensorId>* missing_unused_input_map_keys_;
+ std::vector<SafeTensorId>* missing_unused_input_map_keys_;
// Intermediate datastructure used to populate
// `missing_unused_input_map_keys_`.
diff --git a/tensorflow/core/graph/graph_constructor.h b/tensorflow/core/graph/graph_constructor.h
index b03d655fe6..889359a68a 100644
--- a/tensorflow/core/graph/graph_constructor.h
+++ b/tensorflow/core/graph/graph_constructor.h
@@ -81,14 +81,14 @@ struct ImportGraphDefOptions {
// corresponding to `input_map` keys will be remapped to the nodes in `g`
// corresponding to the values.
//
- // Keys should not include `prefix`, i.e., a key TensorId's name should be the
- // name as it originally appears in `gdef`.
+ // Keys should not include `prefix`, i.e., a key ID's name should be the name
+ // as it originally appears in `gdef`.
//
// If this is non-empty, ImportGraphDef must be called with the shape refiner
// used to create the existing nodes referenced in `input_map`.
// TODO(skyewm): can we remove this requirement? How do we access the original
// shape refiner?
- std::map<TensorId, TensorId> input_map;
+ std::map<SafeTensorId, SafeTensorId> input_map;
// If true, nodes that will have all output edges removed because of
// overrides in `input_map` will not be imported.
@@ -107,12 +107,12 @@ struct ImportGraphDefOptions {
// caller must pass a results object to `ImportGraphDef()`. The
// `return_tensors` field will be populated with the imported nodes in `g`.
//
- // Entries should not include `prefix`, i.e., each TensorId's name should be
- // the name as it originally appears in `gdef`.
+ // Entries should not include `prefix`, i.e., each ID's name should be the
+ // name as it originally appears in `gdef`.
//
// If this contains a tensor that's also being remapped via `input_map`, the
// corresponding existing tensor in `g` will be returned.
- std::vector<TensorId> return_tensors;
+ std::vector<SafeTensorId> return_tensors;
// The names of nodes in `gdef` that will be returned via the
// ImportGraphDefResults output parameter of `ImportGraphDef()`. If this list
@@ -155,7 +155,7 @@ struct ImportGraphDefResults {
// Keys in ImportGraphDefOptions::input_map that don't appear in `gdef` and
// weren't used as an input to any node in `gdef`. These keys are likely due
// to typos, and callers may wish to treat their existence as an error.
- std::vector<TensorId> missing_unused_input_map_keys;
+ std::vector<SafeTensorId> missing_unused_input_map_keys;
};
// Adds the graph in GraphDef `gdef` into an existing Graph `*g`.
diff --git a/tensorflow/core/graph/graph_constructor_test.cc b/tensorflow/core/graph/graph_constructor_test.cc
index 6309870190..e338840eeb 100644
--- a/tensorflow/core/graph/graph_constructor_test.cc
+++ b/tensorflow/core/graph/graph_constructor_test.cc
@@ -1502,7 +1502,8 @@ TEST_F(GraphConstructorTest, ImportGraphDef_InputMapMissingUnusedKeys) {
opts, &refiner, &results);
ASSERT_EQ(results.missing_unused_input_map_keys.size(), 1);
- EXPECT_EQ(results.missing_unused_input_map_keys[0], TensorId("new_input", 2));
+ EXPECT_EQ(results.missing_unused_input_map_keys[0],
+ SafeTensorId("new_input", 2));
}
TEST_F(GraphConstructorTest, ImportGraphDef_InputMapWithUnboundInput) {
@@ -2748,6 +2749,51 @@ TEST_F(GraphConstructorTest, ImportGraphDef_NestedFunctionDefs) {
EXPECT_EQ(outputs[0].scalar<float>()(), 3.0);
}
+// NOTE(skyewm): the C API depends on this behavior, but it's easier to write
+// the test here.
+TEST_F(GraphConstructorTest, ImportGraphDef_OptionsMemMgmt) {
+ ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, graph_.op_registry());
+
+ // Populate graph with node we'll use in input map
+ ExpectOK("node { name: 'input' op: 'TestInput' }", ImportGraphDefOptions(),
+ &refiner);
+
+ // Add some strings to ImportGraphDefOptions and then rewrite the buffers.
+ char buf1[100];
+ char buf2[100];
+ char buf3[100];
+ snprintf(buf1, sizeof(buf1), "input");
+ snprintf(buf2, sizeof(buf2), "new_input");
+ snprintf(buf3, sizeof(buf3), "t1");
+
+ ImportGraphDefOptions opts;
+ opts.input_map[TensorId(buf2, 0)] = TensorId(buf1, 0);
+ opts.return_tensors.push_back(TensorId(buf3, 0));
+
+ snprintf(buf1, sizeof(buf1), "xxxxxxxxxxxxxxxxxxxx");
+ snprintf(buf2, sizeof(buf2), "xxxxxxxxxxxxxxxxxxxx");
+ snprintf(buf3, sizeof(buf3), "xxxxxxxxxxxxxxxxxxxx");
+
+ // Import some new nodes using opts.
+ ImportGraphDefResults results;
+ ExpectOK(
+ R"EOF(
+ node { name: 'new_input' op: 'TestInput' }
+ node { name: 't1' op: 'TestMul' input: [ 'new_input:0', 'new_input:1' ] }
+ )EOF",
+ opts, &refiner, &results);
+
+ EXPECT_TRUE(HasNode("input"));
+ EXPECT_TRUE(HasNode("new_input"));
+ EXPECT_TRUE(HasNode("t1"));
+
+ EXPECT_TRUE(HasEdge("input", 0, "t1", 0));
+ EXPECT_TRUE(HasEdge("new_input", 1, "t1", 1));
+
+ ASSERT_EQ(results.return_tensors.size(), 1);
+ EXPECT_EQ(results.return_tensors[0].first->name(), "t1");
+}
+
TEST_F(GraphConstructorTest, CopyGraph) {
const int v = TF_GRAPH_DEF_VERSION;
const int bad = v + 17;
diff --git a/tensorflow/core/graph/tensor_id.cc b/tensorflow/core/graph/tensor_id.cc
index 8af1936d64..80c76df255 100644
--- a/tensorflow/core/graph/tensor_id.cc
+++ b/tensorflow/core/graph/tensor_id.cc
@@ -22,6 +22,11 @@ limitations under the License.
namespace tensorflow {
+TensorId::TensorId(const SafeTensorId& id) : TensorId(id.first, id.second) {}
+
+SafeTensorId::SafeTensorId(const TensorId& id)
+ : SafeTensorId(id.first.ToString(), id.second) {}
+
TensorId ParseTensorName(const string& name) {
return ParseTensorName(StringPiece(name.data(), name.size()));
}
diff --git a/tensorflow/core/graph/tensor_id.h b/tensorflow/core/graph/tensor_id.h
index c27120f7e6..bf13fc78a6 100644
--- a/tensorflow/core/graph/tensor_id.h
+++ b/tensorflow/core/graph/tensor_id.h
@@ -25,6 +25,8 @@ limitations under the License.
namespace tensorflow {
+struct SafeTensorId;
+
// Identifier for a tensor within a step.
// first == operation_name, second == output_index
// Note: does not own backing storage for name.
@@ -34,6 +36,11 @@ struct TensorId : public std::pair<StringPiece, int> {
// Inherit the set of constructors.
using Base::pair;
+ // NOTE(skyewm): this is required on some platforms. I'm not sure why the
+ // using statement above isn't always sufficient.
+ TensorId() : Base() {}
+ TensorId(const SafeTensorId& id);
+
string ToString() const {
if (second == Graph::kControlSlot) return strings::StrCat("^", first);
return strings::StrCat(first, ":", second);
@@ -50,6 +57,32 @@ struct TensorId : public std::pair<StringPiece, int> {
TensorId ParseTensorName(const string& name);
TensorId ParseTensorName(StringPiece name);
+// Same as TensorId, except owns the backing storage for the op name. This makes
+// the memory management simpler at the expense of a copy.
+struct SafeTensorId : public std::pair<string, int> {
+ typedef std::pair<string, int> Base;
+
+ // Inherit the set of constructors.
+ using Base::pair;
+
+ // NOTE(skyewm): this is required on some platforms. I'm not sure why the
+ // using statement above isn't always sufficient.
+ SafeTensorId() : Base() {}
+ SafeTensorId(const TensorId& id);
+
+ string ToString() const {
+ if (second == Graph::kControlSlot) return strings::StrCat("^", first);
+ return strings::StrCat(first, ":", second);
+ }
+
+ struct Hasher {
+ public:
+ std::size_t operator()(const TensorId& x) const {
+ return Hash32(x.first.data(), x.first.size(), x.second);
+ }
+ };
+};
+
} // namespace tensorflow
#endif // TENSORFLOW_GRAPH_TENSOR_ID_H_