aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/cc
diff options
context:
space:
mode:
authorGravatar Suharsh Sivakumar <suharshs@google.com>2018-05-10 10:58:11 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-10 11:00:43 -0700
commit0013b6953547fe17865c21155bdebe4cfe656e74 (patch)
tree665b9bc4e3d9eeb7650117e13dc3c198132b6174 /tensorflow/cc
parentaf4cd0e87cf59c5307546a9ca41bdd457634c58d (diff)
Traverse through control dependencies.
PiperOrigin-RevId: 196139886
Diffstat (limited to 'tensorflow/cc')
-rw-r--r--tensorflow/cc/tools/freeze_saved_model.cc6
-rw-r--r--tensorflow/cc/tools/freeze_saved_model_test.cc25
2 files changed, 30 insertions, 1 deletions
diff --git a/tensorflow/cc/tools/freeze_saved_model.cc b/tensorflow/cc/tools/freeze_saved_model.cc
index 2a859d6472..23e9dc40d2 100644
--- a/tensorflow/cc/tools/freeze_saved_model.cc
+++ b/tensorflow/cc/tools/freeze_saved_model.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/cc/tools/freeze_saved_model.h"
+#include <iostream>
#include <queue>
#include "tensorflow/core/framework/attr_value.pb.h"
@@ -72,7 +73,10 @@ void GetNodeNameToNodeDefMap(
}
// Strips off the tensor part of the tensor_name to get the node_name.
-const string GetNodeNameFromTensorName(const string& tensor_name) {
+const string GetNodeNameFromTensorName(string tensor_name) {
+ if (tensor_name[0] == '^') {
+ tensor_name.erase(0, 1);
+ }
std::vector<string> tensor_name_parts = str_util::Split(tensor_name, ':');
return tensor_name_parts[0];
}
diff --git a/tensorflow/cc/tools/freeze_saved_model_test.cc b/tensorflow/cc/tools/freeze_saved_model_test.cc
index e265a68e54..979b23c3fc 100644
--- a/tensorflow/cc/tools/freeze_saved_model_test.cc
+++ b/tensorflow/cc/tools/freeze_saved_model_test.cc
@@ -376,6 +376,31 @@ TEST_F(FreezeTest, GraphDefWithMultiOutputOperation) {
GraphDefEqual(frozen_graph_def, graph_def);
}
+TEST_F(FreezeTest, GraphDefWithControlDependency) {
+ // Inputs that are control dependencies get tensor prefixes,
+ // i.e. ^control_dependency.
+ // Test that we traverse those correctly.
+ SavedModelBundle saved_model_bundle;
+ GraphDef graph_def;
+ Scope scope = Scope::NewRootScope();
+ Output source = ops::Const(scope.WithOpName("source"), 10.0f, {});
+ Output a = ops::Const(scope.WithOpName("a").WithControlDependencies(source),
+ {10.0f, 10.0f}, {2});
+ Output b = ops::Const(scope.WithOpName("b"), 10.0f, {});
+ Output c = ops::Mul(scope.WithOpName("c"), a, b);
+ TF_ASSERT_OK(scope.ToGraphDef(&graph_def));
+ TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle(graph_def, {"c:0"}, "",
+ &saved_model_bundle));
+
+ GraphDef frozen_graph_def;
+ std::unordered_set<string> inputs;
+ std::unordered_set<string> outputs;
+ TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs,
+ &outputs));
+
+ GraphDefEqual(frozen_graph_def, graph_def);
+}
+
TEST_F(FreezeTest, GraphDefWithoutDependentVariables) {
TestFreezeGraphWithoutDependentVariables(false);
}