aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/cc/tools
diff options
context:
space:
mode:
authorGravatar Suharsh Sivakumar <suharshs@google.com>2018-05-09 17:30:30 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-09 17:33:41 -0700
commit20387e460ad8b72cb4ac9f6bda00394f2a404f3f (patch)
tree72fcd7a7d828b2205d1d1e0e2f2d98262d90552f /tensorflow/cc/tools
parent930974af4d8e24958c75286c31dc7e0ee67e75ba (diff)
Fix FreezeSavedModel to handle traversal of operations with multiple outputs.
PiperOrigin-RevId: 196055377
Diffstat (limited to 'tensorflow/cc/tools')
-rw-r--r--tensorflow/cc/tools/freeze_saved_model.cc16
-rw-r--r--tensorflow/cc/tools/freeze_saved_model_test.cc25
2 files changed, 35 insertions, 6 deletions
diff --git a/tensorflow/cc/tools/freeze_saved_model.cc b/tensorflow/cc/tools/freeze_saved_model.cc
index 4ddddcb586..2a859d6472 100644
--- a/tensorflow/cc/tools/freeze_saved_model.cc
+++ b/tensorflow/cc/tools/freeze_saved_model.cc
@@ -71,6 +71,12 @@ void GetNodeNameToNodeDefMap(
}
}
+// Strips off the tensor part of the tensor_name to get the node_name.
+const string GetNodeNameFromTensorName(const string& tensor_name) {
+ std::vector<string> tensor_name_parts = str_util::Split(tensor_name, ':');
+ return tensor_name_parts[0];
+}
+
// Gets the set of node names needed by `outputs` and the corresponding set of
// variable nodes to convert.
void GetReachableNodesAndVariables(
@@ -83,10 +89,8 @@ void GetReachableNodesAndVariables(
new std::unordered_set<string>({"Variable", "VariableV2", "VarHandleOp"});
std::queue<string> nodes_to_visit;
- for (const string& tensor_name : outputs) {
- // We need to strip off the tensor part to get the node name.
- std::vector<string> tensor_name_parts = str_util::Split(tensor_name, ':');
- nodes_to_visit.push(tensor_name_parts[0]);
+ for (const string& output_tensor_name : outputs) {
+ nodes_to_visit.push(GetNodeNameFromTensorName(output_tensor_name));
}
// We do a traversal backwards from the outputs specified in the MetaGraphDef.
while (!nodes_to_visit.empty()) {
@@ -100,8 +104,8 @@ void GetReachableNodesAndVariables(
if (kVariableTypes->find(node->op()) != kVariableTypes->end()) {
variable_node_names->insert(node->name());
}
- for (const string& input : node->input()) {
- nodes_to_visit.push(input);
+ for (const string& input_tensor_name : node->input()) {
+ nodes_to_visit.push(GetNodeNameFromTensorName(input_tensor_name));
}
}
}
diff --git a/tensorflow/cc/tools/freeze_saved_model_test.cc b/tensorflow/cc/tools/freeze_saved_model_test.cc
index cd35fd3b95..e265a68e54 100644
--- a/tensorflow/cc/tools/freeze_saved_model_test.cc
+++ b/tensorflow/cc/tools/freeze_saved_model_test.cc
@@ -351,6 +351,31 @@ TEST_F(FreezeTest, GraphDefWithNoVariables) {
GraphDefEqual(frozen_graph_def, graph_def);
}
+TEST_F(FreezeTest, GraphDefWithMultiOutputOperation) {
+ // Tensors from operations with multiple outputs get tensor suffixes when used
+ // in input fields of following nodes, i.e. split:0, split:1.
+ // Test that we traverse those correctly.
+ SavedModelBundle saved_model_bundle;
+ GraphDef graph_def;
+ Scope scope = Scope::NewRootScope();
+ Output a = ops::Const(scope.WithOpName("a"), {10.0f, 10.0f}, {2});
+ Output axis = ops::Const(scope.WithOpName("axis"), 0, {});
+ OutputList split = ops::Split(scope.WithOpName("split"), axis, a, 2).output;
+ Output b = ops::Const(scope.WithOpName("b"), 10.0f, {});
+ Output c = ops::Mul(scope.WithOpName("c"), split[1], b);
+ TF_ASSERT_OK(scope.ToGraphDef(&graph_def));
+ TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle(graph_def, {"c:0"}, "",
+ &saved_model_bundle));
+
+ GraphDef frozen_graph_def;
+ std::unordered_set<string> inputs;
+ std::unordered_set<string> outputs;
+ TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs,
+ &outputs));
+
+ GraphDefEqual(frozen_graph_def, graph_def);
+}
+
TEST_F(FreezeTest, GraphDefWithoutDependentVariables) {
TestFreezeGraphWithoutDependentVariables(false);
}