aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/cc/tools/freeze_saved_model_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/cc/tools/freeze_saved_model_test.cc')
-rw-r--r--tensorflow/cc/tools/freeze_saved_model_test.cc307
1 files changed, 307 insertions, 0 deletions
diff --git a/tensorflow/cc/tools/freeze_saved_model_test.cc b/tensorflow/cc/tools/freeze_saved_model_test.cc
new file mode 100644
index 0000000000..57244a4f0a
--- /dev/null
+++ b/tensorflow/cc/tools/freeze_saved_model_test.cc
@@ -0,0 +1,307 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/cc/tools/freeze_saved_model.h"
+
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/framework/function_testlib.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/framework/versions.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/public/session.h"
+#include "tensorflow/core/public/session_options.h"
+
+namespace tensorflow {
+namespace {
+
+class FreezeTest : public ::testing::Test {
+ protected:
+ void GraphDefEqual(const GraphDef& actual, const GraphDef& expected) {
+ EXPECT_EQ(actual.ShortDebugString(), expected.ShortDebugString());
+ }
+
+ // Builds a SignatureDef with the provided `inputs` and `outputs`.
+ SignatureDef BuildSignatureDef(const std::unordered_set<string>& inputs,
+ const std::unordered_set<string>& outputs) {
+ SignatureDef signature_def;
+ for (const string& input : inputs) {
+ (*signature_def.mutable_inputs())[input].set_name(input);
+ }
+ for (const string& output : outputs) {
+ (*signature_def.mutable_outputs())[output].set_name(output);
+ }
+ return signature_def;
+ }
+
+ // Adds `signature_def` to `saved_model_bundle` under `key`.
+ void AddSignatureDefToSavedModelBundle(const SignatureDef& signature_def,
+ const string& key,
+ SavedModelBundle* saved_model_bundle) {
+ MetaGraphDef* meta_graph_def = &saved_model_bundle->meta_graph_def;
+ (*meta_graph_def->mutable_signature_def())[key] = signature_def;
+ }
+
+ // Adds an initialized session to `saved_model_bundle` using `graph_def` and
+ // initializing with `init_node`.
+ Status InitializeSavedModelBundleSession(
+ const GraphDef& graph_def, const string& init_node,
+ SavedModelBundle* saved_model_bundle) {
+ SessionOptions session_options;
+ saved_model_bundle->session.reset(NewSession(session_options));
+ TF_RETURN_IF_ERROR(saved_model_bundle->session->Create(graph_def));
+ if (!init_node.empty()) {
+ std::vector<Tensor> outputs;
+ return saved_model_bundle->session->Run(
+ /* inputs */ {}, /* output_tensors */ {}, {init_node}, &outputs);
+ }
+ return Status::OK();
+ }
+
+ // Adds `graph_def` to `saved_model_bundle` and intializes a session with
+ // `init_node`.
+ Status AddGraphDefToSavedModelBundle(const GraphDef& graph_def,
+ const string& init_node,
+ SavedModelBundle* saved_model_bundle) {
+ MetaGraphDef* meta_graph_def = &saved_model_bundle->meta_graph_def;
+ *meta_graph_def->mutable_graph_def() = graph_def;
+ return InitializeSavedModelBundleSession(graph_def, init_node,
+ saved_model_bundle);
+ }
+
+ // Adds `graph_def` and `outputs` as the GraphDef and SignatureDef in
+ // `saved_model_bundle` and initializes a session with `init_node`.
+ Status AddGraphDefWithOutputsToSavedModelBundle(
+ const GraphDef& graph_def, const std::unordered_set<string>& outputs,
+ const string& init_node, SavedModelBundle* saved_model_bundle) {
+ SignatureDef signature_def =
+ BuildSignatureDef(std::unordered_set<string>(), outputs);
+ AddSignatureDefToSavedModelBundle(signature_def, "signature_def",
+ saved_model_bundle);
+ return AddGraphDefToSavedModelBundle(graph_def, init_node,
+ saved_model_bundle);
+ }
+
+ // Runs and compares the outputs of `tensor_name` on both the
+ // `unfrozen_session` and the `frozen_graph_def.
+ void RunAndCompareFrozenAndUnfrozenGraphs(Session* unfrozen_session,
+ const GraphDef& frozen_graph_def,
+ const string& tensor_name) {
+ std::vector<Tensor> unfrozen_outputs;
+ TF_ASSERT_OK(unfrozen_session->Run(/* inputs */ {}, {tensor_name},
+ /* targets */ {}, &unfrozen_outputs));
+
+ SessionOptions session_options;
+ std::unique_ptr<Session> frozen_session(NewSession(session_options));
+ TF_ASSERT_OK(frozen_session->Create(frozen_graph_def));
+ std::vector<Tensor> frozen_outputs;
+ TF_ASSERT_OK(frozen_session->Run(/* inputs */ {}, {tensor_name},
+ /* targets */ {}, &frozen_outputs));
+
+ test::ExpectTensorEqual<float>(unfrozen_outputs[0], frozen_outputs[0]);
+ }
+};
+
+TEST_F(FreezeTest, InputsAndOutputsSingleSignatureDef) {
+ // Test that inputs and outputs get correctly populated for a single
+ // SignatureDef.
+ SavedModelBundle saved_model_bundle;
+ std::unordered_set<string> expected_inputs = {"input0:0", "input1:0"};
+ std::unordered_set<string> expected_outputs = {"output0:0", "output1:0"};
+ SignatureDef signature_def =
+ BuildSignatureDef(expected_inputs, expected_outputs);
+ AddSignatureDefToSavedModelBundle(signature_def, "signature_def",
+ &saved_model_bundle);
+ GraphDef frozen_graph_def;
+ std::unordered_set<string> inputs;
+ std::unordered_set<string> outputs;
+ TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs,
+ &outputs));
+ EXPECT_EQ(expected_inputs, inputs);
+ EXPECT_EQ(expected_outputs, outputs);
+}
+
+TEST_F(FreezeTest, InputsAndOutputsMultipleSignatureDefs) {
+ // Test that inputs and outputs get correctly merged and populated when
+ // multiple SignatureDefs are provided.
+ SavedModelBundle saved_model_bundle;
+ SignatureDef signature_def_0 = BuildSignatureDef({"input0:0"}, {"output0:0"});
+ SignatureDef signature_def_1 = BuildSignatureDef({"input1:0"}, {"output1:0"});
+ AddSignatureDefToSavedModelBundle(signature_def_0, "signature_def_0",
+ &saved_model_bundle);
+ AddSignatureDefToSavedModelBundle(signature_def_1, "signature_def_1",
+ &saved_model_bundle);
+ GraphDef frozen_graph_def;
+ std::unordered_set<string> inputs;
+ std::unordered_set<string> outputs;
+ TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs,
+ &outputs));
+ std::unordered_set<string> expected_inputs = {"input0:0", "input1:0"};
+ std::unordered_set<string> expected_outputs = {"output0:0", "output1:0"};
+ EXPECT_EQ(expected_inputs, inputs);
+ EXPECT_EQ(expected_outputs, outputs);
+}
+
+TEST_F(FreezeTest, GraphDefVersionsAndLibrary) {
+ // Test that GraphDef versions and library are copied correctly into the
+ // frozen graph.
+ SavedModelBundle saved_model_bundle;
+ GraphDef graph_def;
+ graph_def.mutable_versions()->set_producer(1234);
+ graph_def.mutable_versions()->set_min_consumer(1234);
+ *graph_def.mutable_library()->add_function() = test::function::NonZero();
+ TF_ASSERT_OK(
+ AddGraphDefToSavedModelBundle(graph_def, "", &saved_model_bundle));
+
+ GraphDef frozen_graph_def;
+ std::unordered_set<string> inputs;
+ std::unordered_set<string> outputs;
+ TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs,
+ &outputs));
+
+ GraphDefEqual(frozen_graph_def, graph_def);
+}
+
+TEST_F(FreezeTest, GraphDefWithNoVariables) {
+ // Test freezing a graph with no variables.
+ SavedModelBundle saved_model_bundle;
+ GraphDef graph_def;
+ Scope scope = Scope::NewRootScope();
+ Output a = ops::Const(scope.WithOpName("a"), 10.0f, {});
+ Output b = ops::Const(scope.WithOpName("b"), 10.0f, {});
+ Output c = ops::Mul(scope.WithOpName("c"), a, b);
+ TF_ASSERT_OK(scope.ToGraphDef(&graph_def));
+ TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle(graph_def, {"c:0"}, "",
+ &saved_model_bundle));
+
+ GraphDef frozen_graph_def;
+ std::unordered_set<string> inputs;
+ std::unordered_set<string> outputs;
+ TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs,
+ &outputs));
+
+ GraphDefEqual(frozen_graph_def, graph_def);
+}
+
+TEST_F(FreezeTest, GraphDefWithVariablesNotNeededByOutputs) {
+ // Test freezing a graph with variables that are not needed by the outputs in
+ // the SignatureDef. The resulting graph shouldn't be frozen, but
+ // non-dependent nodes should be pruned.
+ SavedModelBundle saved_model_bundle;
+ GraphDef graph_def;
+ Scope scope = Scope::NewRootScope();
+ Output a = ops::Const(scope.WithOpName("a"), 10.0f, {});
+ Output b = ops::Const(scope.WithOpName("b"), 10.0f, {});
+ Output c = ops::Mul(scope.WithOpName("c"), a, b);
+ Output var = ops::Variable(scope.WithOpName("var"), {}, DataType::DT_FLOAT);
+ Output assign = ops::Assign(scope.WithOpName("assign"), var, a);
+ TF_ASSERT_OK(scope.ToGraphDef(&graph_def));
+ // "c" isnt dependent on the variable, so nothing should be frozen.
+ TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle(
+ graph_def, {"c:0"}, assign.name(), &saved_model_bundle));
+
+ GraphDef frozen_graph_def;
+ std::unordered_set<string> inputs;
+ std::unordered_set<string> outputs;
+ TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs,
+ &outputs));
+
+ GraphDef expected_graph_def;
+ Scope expected_scope = Scope::NewRootScope();
+ Output expected_a = ops::Const(expected_scope.WithOpName("a"), 10.0f, {});
+ Output expected_b = ops::Const(expected_scope.WithOpName("b"), 10.0f, {});
+ Output expected_c =
+ ops::Mul(expected_scope.WithOpName("c"), expected_a, expected_b);
+ TF_ASSERT_OK(expected_scope.ToGraphDef(&expected_graph_def));
+
+ GraphDefEqual(frozen_graph_def, expected_graph_def);
+
+ RunAndCompareFrozenAndUnfrozenGraphs(saved_model_bundle.session.get(),
+ frozen_graph_def, "c:0");
+}
+
+TEST_F(FreezeTest, GraphDefWithVariablesNeededByOutputs) {
+ // Test freezing a graph with variables that are needed by outputs in the
+ // SignatureDef. The variables should be frozen.
+ SavedModelBundle saved_model_bundle;
+ GraphDef graph_def;
+ Scope scope = Scope::NewRootScope();
+ Output a = ops::Const(scope.WithOpName("a"), 10.0f, {});
+ Output var = ops::Variable(scope.WithOpName("var"), {}, DataType::DT_FLOAT);
+ Output c = ops::Mul(scope.WithOpName("c"), a, var);
+ Output assign = ops::Assign(scope.WithOpName("assign"), var, a);
+ TF_ASSERT_OK(scope.ToGraphDef(&graph_def));
+ // "c" isnt dependent on the variable, so nothing should be frozen.
+ TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle(
+ graph_def, {"c:0"}, assign.name(), &saved_model_bundle));
+
+ GraphDef frozen_graph_def;
+ std::unordered_set<string> inputs;
+ std::unordered_set<string> outputs;
+ TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs,
+ &outputs));
+
+ // There should be 3 nodes in the resulting graph_def, and none should be
+ // variables.
+ EXPECT_EQ(frozen_graph_def.node_size(), 3);
+ for (const NodeDef& node : frozen_graph_def.node()) {
+ EXPECT_NE(node.op(), "Variable") << node.name();
+ EXPECT_NE(node.op(), "VariableV2") << node.name();
+ }
+
+ RunAndCompareFrozenAndUnfrozenGraphs(saved_model_bundle.session.get(),
+ frozen_graph_def, "c:0");
+}
+
+TEST_F(FreezeTest, GraphDefWithVariablesNeededAndNotNeededByOutputs) {
+ // Test freezing a graph with some variables that are needed and not needed by
+ // the outputs in the SignatureDef. The resulting graph should only freeze
+ // dependent variables.
+ SavedModelBundle saved_model_bundle;
+ GraphDef graph_def;
+ Scope scope = Scope::NewRootScope();
+ Output a = ops::Const(scope.WithOpName("a"), 10.0f, {});
+ Output var = ops::Variable(scope.WithOpName("var"), {}, DataType::DT_FLOAT);
+ Output c = ops::Mul(scope.WithOpName("c"), a, var);
+ Output assign = ops::Assign(scope.WithOpName("assign"), var, a);
+ Output var_1 =
+ ops::Variable(scope.WithOpName("var_1"), {}, DataType::DT_FLOAT);
+ Output assign_1 = ops::Assign(scope.WithOpName("assign_1"), var, a);
+ TF_ASSERT_OK(scope.ToGraphDef(&graph_def));
+ // "c" isnt dependent on the variable, so nothing should be frozen.
+ TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle(
+ graph_def, {"c:0"}, assign.name(), &saved_model_bundle));
+
+ GraphDef frozen_graph_def;
+ std::unordered_set<string> inputs;
+ std::unordered_set<string> outputs;
+ TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs,
+ &outputs));
+
+ // There should be 3 nodes in the resulting graph_def, and none should be
+ // variables.
+ EXPECT_EQ(frozen_graph_def.node_size(), 3);
+ for (const NodeDef& node : frozen_graph_def.node()) {
+ EXPECT_NE(node.op(), "Variable") << node.name();
+ EXPECT_NE(node.op(), "VariableV2") << node.name();
+ }
+
+ RunAndCompareFrozenAndUnfrozenGraphs(saved_model_bundle.session.get(),
+ frozen_graph_def, "c:0");
+}
+
+} // namespace
+} // namespace tensorflow