aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/cc/tools/freeze_saved_model.cc6
-rw-r--r--tensorflow/cc/tools/freeze_saved_model_test.cc25
2 files changed, 30 insertions, 1 deletions
diff --git a/tensorflow/cc/tools/freeze_saved_model.cc b/tensorflow/cc/tools/freeze_saved_model.cc
index 2a859d6472..23e9dc40d2 100644
--- a/tensorflow/cc/tools/freeze_saved_model.cc
+++ b/tensorflow/cc/tools/freeze_saved_model.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/cc/tools/freeze_saved_model.h"
+#include <iostream>
#include <queue>
#include "tensorflow/core/framework/attr_value.pb.h"
@@ -72,7 +73,10 @@ void GetNodeNameToNodeDefMap(
}
// Strips off the tensor part of the tensor_name to get the node_name.
-const string GetNodeNameFromTensorName(const string& tensor_name) {
+const string GetNodeNameFromTensorName(string tensor_name) {
+ if (tensor_name[0] == '^') {
+ tensor_name.erase(0, 1);
+ }
std::vector<string> tensor_name_parts = str_util::Split(tensor_name, ':');
return tensor_name_parts[0];
}
diff --git a/tensorflow/cc/tools/freeze_saved_model_test.cc b/tensorflow/cc/tools/freeze_saved_model_test.cc
index e265a68e54..979b23c3fc 100644
--- a/tensorflow/cc/tools/freeze_saved_model_test.cc
+++ b/tensorflow/cc/tools/freeze_saved_model_test.cc
@@ -376,6 +376,31 @@ TEST_F(FreezeTest, GraphDefWithMultiOutputOperation) {
GraphDefEqual(frozen_graph_def, graph_def);
}
+TEST_F(FreezeTest, GraphDefWithControlDependency) {
+ // Inputs that are control dependencies get tensor prefixes,
+ // i.e. ^control_dependency.
+ // Test that we traverse those correctly.
+ SavedModelBundle saved_model_bundle;
+ GraphDef graph_def;
+ Scope scope = Scope::NewRootScope();
+ Output source = ops::Const(scope.WithOpName("source"), 10.0f, {});
+ Output a = ops::Const(scope.WithOpName("a").WithControlDependencies(source),
+ {10.0f, 10.0f}, {2});
+ Output b = ops::Const(scope.WithOpName("b"), 10.0f, {});
+ Output c = ops::Mul(scope.WithOpName("c"), a, b);
+ TF_ASSERT_OK(scope.ToGraphDef(&graph_def));
+ TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle(graph_def, {"c:0"}, "",
+ &saved_model_bundle));
+
+ GraphDef frozen_graph_def;
+ std::unordered_set<string> inputs;
+ std::unordered_set<string> outputs;
+ TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs,
+ &outputs));
+
+ GraphDefEqual(frozen_graph_def, graph_def);
+}
+
TEST_F(FreezeTest, GraphDefWithoutDependentVariables) {
TestFreezeGraphWithoutDependentVariables(false);
}