diff options
author | Suharsh Sivakumar <suharshs@google.com> | 2018-05-10 10:58:11 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-05-10 11:00:43 -0700 |
commit | 0013b6953547fe17865c21155bdebe4cfe656e74 (patch) | |
tree | 665b9bc4e3d9eeb7650117e13dc3c198132b6174 /tensorflow/cc | |
parent | af4cd0e87cf59c5307546a9ca41bdd457634c58d (diff) |
Traverse through control dependencies.
PiperOrigin-RevId: 196139886
Diffstat (limited to 'tensorflow/cc')
-rw-r--r-- | tensorflow/cc/tools/freeze_saved_model.cc | 6 | ||||
-rw-r--r-- | tensorflow/cc/tools/freeze_saved_model_test.cc | 25 |
2 files changed, 30 insertions, 1 deletions
diff --git a/tensorflow/cc/tools/freeze_saved_model.cc b/tensorflow/cc/tools/freeze_saved_model.cc index 2a859d6472..23e9dc40d2 100644 --- a/tensorflow/cc/tools/freeze_saved_model.cc +++ b/tensorflow/cc/tools/freeze_saved_model.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/cc/tools/freeze_saved_model.h" +#include <iostream> #include <queue> #include "tensorflow/core/framework/attr_value.pb.h" @@ -72,7 +73,10 @@ void GetNodeNameToNodeDefMap( } // Strips off the tensor part of the tensor_name to get the node_name. -const string GetNodeNameFromTensorName(const string& tensor_name) { +const string GetNodeNameFromTensorName(string tensor_name) { + if (tensor_name[0] == '^') { + tensor_name.erase(0, 1); + } std::vector<string> tensor_name_parts = str_util::Split(tensor_name, ':'); return tensor_name_parts[0]; } diff --git a/tensorflow/cc/tools/freeze_saved_model_test.cc b/tensorflow/cc/tools/freeze_saved_model_test.cc index e265a68e54..979b23c3fc 100644 --- a/tensorflow/cc/tools/freeze_saved_model_test.cc +++ b/tensorflow/cc/tools/freeze_saved_model_test.cc @@ -376,6 +376,31 @@ TEST_F(FreezeTest, GraphDefWithMultiOutputOperation) { GraphDefEqual(frozen_graph_def, graph_def); } +TEST_F(FreezeTest, GraphDefWithControlDependency) { + // Inputs that are control dependencies get tensor prefixes, + // i.e. ^control_dependency. + // Test that we traverse those correctly. + SavedModelBundle saved_model_bundle; + GraphDef graph_def; + Scope scope = Scope::NewRootScope(); + Output source = ops::Const(scope.WithOpName("source"), 10.0f, {}); + Output a = ops::Const(scope.WithOpName("a").WithControlDependencies(source), + {10.0f, 10.0f}, {2}); + Output b = ops::Const(scope.WithOpName("b"), 10.0f, {}); + Output c = ops::Mul(scope.WithOpName("c"), a, b); + TF_ASSERT_OK(scope.ToGraphDef(&graph_def)); + TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle(graph_def, {"c:0"}, "", + &saved_model_bundle)); + + GraphDef frozen_graph_def; + std::unordered_set<string> inputs; + std::unordered_set<string> outputs; + TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs, + &outputs)); + + GraphDefEqual(frozen_graph_def, graph_def); +} + TEST_F(FreezeTest, GraphDefWithoutDependentVariables) { TestFreezeGraphWithoutDependentVariables(false); } |