aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/cc/tools/freeze_saved_model_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/cc/tools/freeze_saved_model_test.cc')
-rw-r--r--tensorflow/cc/tools/freeze_saved_model_test.cc25
1 files changed, 25 insertions, 0 deletions
diff --git a/tensorflow/cc/tools/freeze_saved_model_test.cc b/tensorflow/cc/tools/freeze_saved_model_test.cc
index cd35fd3b95..e265a68e54 100644
--- a/tensorflow/cc/tools/freeze_saved_model_test.cc
+++ b/tensorflow/cc/tools/freeze_saved_model_test.cc
@@ -351,6 +351,31 @@ TEST_F(FreezeTest, GraphDefWithNoVariables) {
GraphDefEqual(frozen_graph_def, graph_def);
}
+TEST_F(FreezeTest, GraphDefWithMultiOutputOperation) {
+ // Tensors from operations with multiple outputs get tensor suffixes when used
+ // in input fields of following nodes, i.e. split:0, split:1.
+ // Test that we traverse those correctly.
+ SavedModelBundle saved_model_bundle;
+ GraphDef graph_def;
+ Scope scope = Scope::NewRootScope();
+ Output a = ops::Const(scope.WithOpName("a"), {10.0f, 10.0f}, {2});
+ Output axis = ops::Const(scope.WithOpName("axis"), 0, {});
+ OutputList split = ops::Split(scope.WithOpName("split"), axis, a, 2).output;
+ Output b = ops::Const(scope.WithOpName("b"), 10.0f, {});
+ Output c = ops::Mul(scope.WithOpName("c"), split[1], b);
+ TF_ASSERT_OK(scope.ToGraphDef(&graph_def));
+ TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle(graph_def, {"c:0"}, "",
+ &saved_model_bundle));
+
+ GraphDef frozen_graph_def;
+ std::unordered_set<string> inputs;
+ std::unordered_set<string> outputs;
+ TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs,
+ &outputs));
+
+ GraphDefEqual(frozen_graph_def, graph_def);
+}
+
TEST_F(FreezeTest, GraphDefWithoutDependentVariables) {
TestFreezeGraphWithoutDependentVariables(false);
}