aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/cc/tools
diff options
context:
space:
mode:
authorGravatar Suharsh Sivakumar <suharshs@google.com>2018-01-14 03:38:40 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-14 03:42:18 -0800
commita79e97be8460ce3e1a7de2ddbc78b76151e0035a (patch)
treef376a443273aa53b1301588ad2363a0e3e6c7134 /tensorflow/cc/tools
parent9c94a3f9370f535ccaa705403c60da67dd473bea (diff)
FreezeSavedModel function: Get a frozen GraphDef, inputs, and outputs from a loaded SaveModelBundle.
#14567 PiperOrigin-RevId: 181887870
Diffstat (limited to 'tensorflow/cc/tools')
-rw-r--r--tensorflow/cc/tools/BUILD58
-rw-r--r--tensorflow/cc/tools/freeze_saved_model.cc194
-rw-r--r--tensorflow/cc/tools/freeze_saved_model.h43
-rw-r--r--tensorflow/cc/tools/freeze_saved_model_test.cc307
4 files changed, 602 insertions, 0 deletions
diff --git a/tensorflow/cc/tools/BUILD b/tensorflow/cc/tools/BUILD
new file mode 100644
index 0000000000..0a7c37383f
--- /dev/null
+++ b/tensorflow/cc/tools/BUILD
@@ -0,0 +1,58 @@
+# Description:
+# TensorFlow cc tools.
+
+package(
+ default_visibility = ["//visibility:public"],
+)
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load(
+ "//tensorflow:tensorflow.bzl",
+ "tf_cc_test",
+)
+
+cc_library(
+ name = "freeze_saved_model",
+ srcs = ["freeze_saved_model.cc"],
+ hdrs = ["freeze_saved_model.h"],
+ deps = [
+ "//tensorflow/cc/saved_model:loader",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:tensorflow",
+ ],
+)
+
+tf_cc_test(
+ name = "freeze_saved_model_test",
+ srcs = ["freeze_saved_model_test.cc"],
+ deps = [
+ ":freeze_saved_model",
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ ],
+)
+
+# -----------------------------------------------------------------------------
+# Google-internal targets.
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/cc/tools/freeze_saved_model.cc b/tensorflow/cc/tools/freeze_saved_model.cc
new file mode 100644
index 0000000000..ddf372cdef
--- /dev/null
+++ b/tensorflow/cc/tools/freeze_saved_model.cc
@@ -0,0 +1,194 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/cc/tools/freeze_saved_model.h"
+
+#include <queue>
+
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/function.pb.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/versions.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/protobuf/meta_graph.pb.h"
+
+namespace tensorflow {
+
+namespace {
+
+// Gets tensor names from tensor_info and inserts them into the set of tensor
+// names.
+void GetTensorNamesFromTensorInfo(const TensorInfo& tensor_info,
+ std::unordered_set<string>* tensor_names) {
+ if (tensor_info.has_coo_sparse()) {
+ // If the tensor is sparse we have to add all three tensors of the sparse
+ // representations.
+ const TensorInfo_CooSparse& coo_sparse = tensor_info.coo_sparse();
+ tensor_names->insert(coo_sparse.values_tensor_name());
+ tensor_names->insert(coo_sparse.indices_tensor_name());
+ tensor_names->insert(coo_sparse.dense_shape_tensor_name());
+ } else {
+ tensor_names->insert(tensor_info.name());
+ }
+}
+
+// Gets the union of all inputs and outputs of all SignatureDefs in the bundle
+void GetSignatureDefsInputsAndOutputs(
+ const SavedModelBundle& saved_model_bundle,
+ std::unordered_set<string>* inputs, std::unordered_set<string>* outputs) {
+ for (auto& sigdef_elem : saved_model_bundle.meta_graph_def.signature_def()) {
+ const SignatureDef& signature_def = sigdef_elem.second;
+ for (auto& input_elem : signature_def.inputs()) {
+ GetTensorNamesFromTensorInfo(input_elem.second, inputs);
+ }
+ for (auto& output_elem : signature_def.outputs()) {
+ GetTensorNamesFromTensorInfo(output_elem.second, outputs);
+ }
+ }
+}
+
+// Gets a map from string node name to NodeDef.
+void GetNodeNameToNodeDefMap(
+ GraphDef* graph_def,
+ std::unordered_map<string, NodeDef*>* name_to_node_map) {
+ for (size_t i = 0; i < graph_def->node_size(); i++) {
+ NodeDef* node = graph_def->mutable_node(i);
+ (*name_to_node_map)[node->name()] = node;
+ }
+}
+
+// Gets the set of node names needed by `outputs` and the corresponding set of
+// variable nodes to convert.
+void GetReachableNodesAndVariables(
+ GraphDef* graph_def, const std::unordered_set<string>& outputs,
+ std::unordered_set<string>* reachable_node_names,
+ std::unordered_set<string>* variable_node_names) {
+ // TODO(suharshs): Add support for ResourceVariables.
+ static const std::unordered_set<string>* kVariableTypes =
+ new std::unordered_set<string>({"Variable", "VariableV2"});
+ // name_to_node_map is needed to get the inputs from the NodeDef corresponding
+ // the a string node name. These inputs are used when doing our backwards
+ // traversal.
+ std::unordered_map<string, NodeDef*> name_to_node_map;
+ GetNodeNameToNodeDefMap(graph_def, &name_to_node_map);
+ std::queue<string> nodes_to_visit;
+ for (const string& tensor_name : outputs) {
+ // We need to strip off the tensor part to get the node name.
+ std::vector<string> tensor_name_parts = str_util::Split(tensor_name, ':');
+ nodes_to_visit.push(tensor_name_parts[0]);
+ }
+ // We do a traversal backwards from the outputs specified in the MetaGraphDef.
+ while (!nodes_to_visit.empty()) {
+ const string node_name = nodes_to_visit.front();
+ nodes_to_visit.pop();
+ if (reachable_node_names->find(node_name) != reachable_node_names->end()) {
+ continue;
+ }
+ reachable_node_names->insert(node_name);
+ NodeDef* node = name_to_node_map[node_name];
+ if (kVariableTypes->find(node->op()) != kVariableTypes->end()) {
+ variable_node_names->insert(node->name());
+ }
+ for (const string& input : node->input()) {
+ nodes_to_visit.push(input);
+ }
+ }
+}
+
+// Gets a map from variable name to variable value.
+Status GetVariableNameToTensorMap(
+ Session* session, std::unordered_set<string> variable_names_set,
+ std::unordered_map<string, Tensor>* variable_name_to_value_map) {
+ if (variable_names_set.empty()) {
+ return Status::OK();
+ }
+ std::vector<string> variable_names;
+ std::vector<string> tensor_names;
+ for (const string& node_name : variable_names_set) {
+ variable_names.push_back(node_name);
+ // We need to run tensors, so append ":0".
+ tensor_names.push_back(node_name + ":0");
+ }
+ std::vector<Tensor> outputs;
+ TF_RETURN_IF_ERROR(
+ session->Run(/* inputs */ {}, tensor_names, /* targets */ {}, &outputs));
+ for (size_t i = 0; i < variable_names.size(); i++) {
+ (*variable_name_to_value_map)[variable_names[i]] = outputs[i];
+ }
+ return Status::OK();
+}
+
+// Converts a Variable NodeDef into a Constant NodeDef.
+void ConvertVariableToConstant(const NodeDef& variable_node,
+ const Tensor& variable_value,
+ NodeDef* const_node) {
+ const_node->set_name(variable_node.name());
+ const_node->set_op("Const");
+ (*const_node->mutable_attr())["dtype"] = variable_node.attr().at("dtype");
+ variable_value.AsProtoTensorContent(
+ (*const_node->mutable_attr())["value"].mutable_tensor());
+}
+
+// Freezes the subgraph of all nodes needed by `outputs`.
+Status FreezeGraphDef(const SavedModelBundle& saved_model_bundle,
+ const std::unordered_set<string>& outputs,
+ GraphDef* frozen_graph_def) {
+ GraphDef graph_def = saved_model_bundle.meta_graph_def.graph_def();
+ // Copy versions and library as-is from original graph.
+ *frozen_graph_def->mutable_versions() = graph_def.versions();
+ *frozen_graph_def->mutable_library() = graph_def.library();
+ // If the graph is empty there is nothing left to do.
+ if (graph_def.node_size() == 0) {
+ return Status::OK();
+ }
+ std::unordered_set<string> reachable_node_names;
+ std::unordered_set<string> variable_node_names;
+ GetReachableNodesAndVariables(&graph_def, outputs, &reachable_node_names,
+ &variable_node_names);
+ std::unordered_map<string, Tensor> variable_to_value_map;
+ TF_RETURN_IF_ERROR(
+ GetVariableNameToTensorMap(saved_model_bundle.session.get(),
+ variable_node_names, &variable_to_value_map));
+ // We copy the nodes in the same order they were in the original graph_def.
+ for (const NodeDef& node : graph_def.node()) {
+ if (reachable_node_names.find(node.name()) == reachable_node_names.end()) {
+ continue;
+ }
+ if (variable_node_names.find(node.name()) != variable_node_names.end()) {
+ ConvertVariableToConstant(node, variable_to_value_map[node.name()],
+ frozen_graph_def->add_node());
+ } else {
+ // If the node isn't a variable, just copy the node as-is.
+ *frozen_graph_def->add_node() = node;
+ }
+ }
+ return Status::OK();
+}
+
+} // namespace
+
+Status FreezeSavedModel(const SavedModelBundle& saved_model_bundle,
+ GraphDef* frozen_graph_def,
+ std::unordered_set<string>* inputs,
+ std::unordered_set<string>* outputs) {
+ GetSignatureDefsInputsAndOutputs(saved_model_bundle, inputs, outputs);
+ TF_RETURN_IF_ERROR(
+ FreezeGraphDef(saved_model_bundle, *outputs, frozen_graph_def));
+ return Status::OK();
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/cc/tools/freeze_saved_model.h b/tensorflow/cc/tools/freeze_saved_model.h
new file mode 100644
index 0000000000..bd5e0516c8
--- /dev/null
+++ b/tensorflow/cc/tools/freeze_saved_model.h
@@ -0,0 +1,43 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CC_TOOLS_FREEZE_SAVED_MODEL_H_
+#define THIRD_PARTY_TENSORFLOW_CC_TOOLS_FREEZE_SAVED_MODEL_H_
+
+#include <unordered_set>
+
+#include "tensorflow/cc/saved_model/loader.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+// Returns a frozen GraphDef, input tensors, and output tensors from the loaded
+// SavedModelBundle.
+// `inputs` and `outputs` consist of the union of all inputs and outputs in the
+// SignatureDefs in the SavedModelBundle.
+// FreezeSavedModel sets `frozen_graph_def` to a GraphDef of all nodes needed by
+// `outputs`. All variables in the supplied SavedModelBundle are converted to
+// constants, set to the value of the variables, by running the restored Session
+// in the SavedModelBundle.
+// WARNING: Only the variable checkpoints will be reflected in the frozen
+// graph_def. All saved_model assets will be ignored.
+Status FreezeSavedModel(const SavedModelBundle& saved_model_bundle,
+ GraphDef* frozen_graph_def,
+ std::unordered_set<string>* inputs,
+ std::unordered_set<string>* outputs);
+
+} // namespace tensorflow
+
+#endif // THIRD_PARTY_TENSORFLOW_CC_TOOLS_FREEZE_SAVED_MODEL_H_
diff --git a/tensorflow/cc/tools/freeze_saved_model_test.cc b/tensorflow/cc/tools/freeze_saved_model_test.cc
new file mode 100644
index 0000000000..57244a4f0a
--- /dev/null
+++ b/tensorflow/cc/tools/freeze_saved_model_test.cc
@@ -0,0 +1,307 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/cc/tools/freeze_saved_model.h"
+
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/framework/function_testlib.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/framework/versions.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/public/session.h"
+#include "tensorflow/core/public/session_options.h"
+
+namespace tensorflow {
+namespace {
+
+class FreezeTest : public ::testing::Test {
+ protected:
+ void GraphDefEqual(const GraphDef& actual, const GraphDef& expected) {
+ EXPECT_EQ(actual.ShortDebugString(), expected.ShortDebugString());
+ }
+
+ // Builds a SignatureDef with the provided `inputs` and `outputs`.
+ SignatureDef BuildSignatureDef(const std::unordered_set<string>& inputs,
+ const std::unordered_set<string>& outputs) {
+ SignatureDef signature_def;
+ for (const string& input : inputs) {
+ (*signature_def.mutable_inputs())[input].set_name(input);
+ }
+ for (const string& output : outputs) {
+ (*signature_def.mutable_outputs())[output].set_name(output);
+ }
+ return signature_def;
+ }
+
+ // Adds `signature_def` to `saved_model_bundle` under `key`.
+ void AddSignatureDefToSavedModelBundle(const SignatureDef& signature_def,
+ const string& key,
+ SavedModelBundle* saved_model_bundle) {
+ MetaGraphDef* meta_graph_def = &saved_model_bundle->meta_graph_def;
+ (*meta_graph_def->mutable_signature_def())[key] = signature_def;
+ }
+
+ // Adds an initialized session to `saved_model_bundle` using `graph_def` and
+ // initializing with `init_node`.
+ Status InitializeSavedModelBundleSession(
+ const GraphDef& graph_def, const string& init_node,
+ SavedModelBundle* saved_model_bundle) {
+ SessionOptions session_options;
+ saved_model_bundle->session.reset(NewSession(session_options));
+ TF_RETURN_IF_ERROR(saved_model_bundle->session->Create(graph_def));
+ if (!init_node.empty()) {
+ std::vector<Tensor> outputs;
+ return saved_model_bundle->session->Run(
+ /* inputs */ {}, /* output_tensors */ {}, {init_node}, &outputs);
+ }
+ return Status::OK();
+ }
+
+ // Adds `graph_def` to `saved_model_bundle` and intializes a session with
+ // `init_node`.
+ Status AddGraphDefToSavedModelBundle(const GraphDef& graph_def,
+ const string& init_node,
+ SavedModelBundle* saved_model_bundle) {
+ MetaGraphDef* meta_graph_def = &saved_model_bundle->meta_graph_def;
+ *meta_graph_def->mutable_graph_def() = graph_def;
+ return InitializeSavedModelBundleSession(graph_def, init_node,
+ saved_model_bundle);
+ }
+
+ // Adds `graph_def` and `outputs` as the GraphDef and SignatureDef in
+ // `saved_model_bundle` and initializes a session with `init_node`.
+ Status AddGraphDefWithOutputsToSavedModelBundle(
+ const GraphDef& graph_def, const std::unordered_set<string>& outputs,
+ const string& init_node, SavedModelBundle* saved_model_bundle) {
+ SignatureDef signature_def =
+ BuildSignatureDef(std::unordered_set<string>(), outputs);
+ AddSignatureDefToSavedModelBundle(signature_def, "signature_def",
+ saved_model_bundle);
+ return AddGraphDefToSavedModelBundle(graph_def, init_node,
+ saved_model_bundle);
+ }
+
+ // Runs and compares the outputs of `tensor_name` on both the
+ // `unfrozen_session` and the `frozen_graph_def.
+ void RunAndCompareFrozenAndUnfrozenGraphs(Session* unfrozen_session,
+ const GraphDef& frozen_graph_def,
+ const string& tensor_name) {
+ std::vector<Tensor> unfrozen_outputs;
+ TF_ASSERT_OK(unfrozen_session->Run(/* inputs */ {}, {tensor_name},
+ /* targets */ {}, &unfrozen_outputs));
+
+ SessionOptions session_options;
+ std::unique_ptr<Session> frozen_session(NewSession(session_options));
+ TF_ASSERT_OK(frozen_session->Create(frozen_graph_def));
+ std::vector<Tensor> frozen_outputs;
+ TF_ASSERT_OK(frozen_session->Run(/* inputs */ {}, {tensor_name},
+ /* targets */ {}, &frozen_outputs));
+
+ test::ExpectTensorEqual<float>(unfrozen_outputs[0], frozen_outputs[0]);
+ }
+};
+
+TEST_F(FreezeTest, InputsAndOutputsSingleSignatureDef) {
+ // Test that inputs and outputs get correctly populated for a single
+ // SignatureDef.
+ SavedModelBundle saved_model_bundle;
+ std::unordered_set<string> expected_inputs = {"input0:0", "input1:0"};
+ std::unordered_set<string> expected_outputs = {"output0:0", "output1:0"};
+ SignatureDef signature_def =
+ BuildSignatureDef(expected_inputs, expected_outputs);
+ AddSignatureDefToSavedModelBundle(signature_def, "signature_def",
+ &saved_model_bundle);
+ GraphDef frozen_graph_def;
+ std::unordered_set<string> inputs;
+ std::unordered_set<string> outputs;
+ TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs,
+ &outputs));
+ EXPECT_EQ(expected_inputs, inputs);
+ EXPECT_EQ(expected_outputs, outputs);
+}
+
+TEST_F(FreezeTest, InputsAndOutputsMultipleSignatureDefs) {
+ // Test that inputs and outputs get correctly merged and populated when
+ // multiple SignatureDefs are provided.
+ SavedModelBundle saved_model_bundle;
+ SignatureDef signature_def_0 = BuildSignatureDef({"input0:0"}, {"output0:0"});
+ SignatureDef signature_def_1 = BuildSignatureDef({"input1:0"}, {"output1:0"});
+ AddSignatureDefToSavedModelBundle(signature_def_0, "signature_def_0",
+ &saved_model_bundle);
+ AddSignatureDefToSavedModelBundle(signature_def_1, "signature_def_1",
+ &saved_model_bundle);
+ GraphDef frozen_graph_def;
+ std::unordered_set<string> inputs;
+ std::unordered_set<string> outputs;
+ TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs,
+ &outputs));
+ std::unordered_set<string> expected_inputs = {"input0:0", "input1:0"};
+ std::unordered_set<string> expected_outputs = {"output0:0", "output1:0"};
+ EXPECT_EQ(expected_inputs, inputs);
+ EXPECT_EQ(expected_outputs, outputs);
+}
+
+TEST_F(FreezeTest, GraphDefVersionsAndLibrary) {
+ // Test that GraphDef versions and library are copied correctly into the
+ // frozen graph.
+ SavedModelBundle saved_model_bundle;
+ GraphDef graph_def;
+ graph_def.mutable_versions()->set_producer(1234);
+ graph_def.mutable_versions()->set_min_consumer(1234);
+ *graph_def.mutable_library()->add_function() = test::function::NonZero();
+ TF_ASSERT_OK(
+ AddGraphDefToSavedModelBundle(graph_def, "", &saved_model_bundle));
+
+ GraphDef frozen_graph_def;
+ std::unordered_set<string> inputs;
+ std::unordered_set<string> outputs;
+ TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs,
+ &outputs));
+
+ GraphDefEqual(frozen_graph_def, graph_def);
+}
+
+TEST_F(FreezeTest, GraphDefWithNoVariables) {
+ // Test freezing a graph with no variables.
+ SavedModelBundle saved_model_bundle;
+ GraphDef graph_def;
+ Scope scope = Scope::NewRootScope();
+ Output a = ops::Const(scope.WithOpName("a"), 10.0f, {});
+ Output b = ops::Const(scope.WithOpName("b"), 10.0f, {});
+ Output c = ops::Mul(scope.WithOpName("c"), a, b);
+ TF_ASSERT_OK(scope.ToGraphDef(&graph_def));
+ TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle(graph_def, {"c:0"}, "",
+ &saved_model_bundle));
+
+ GraphDef frozen_graph_def;
+ std::unordered_set<string> inputs;
+ std::unordered_set<string> outputs;
+ TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs,
+ &outputs));
+
+ GraphDefEqual(frozen_graph_def, graph_def);
+}
+
+TEST_F(FreezeTest, GraphDefWithVariablesNotNeededByOutputs) {
+ // Test freezing a graph with variables that are not needed by the outputs in
+ // the SignatureDef. The resulting graph shouldn't be frozen, but
+ // non-dependent nodes should be pruned.
+ SavedModelBundle saved_model_bundle;
+ GraphDef graph_def;
+ Scope scope = Scope::NewRootScope();
+ Output a = ops::Const(scope.WithOpName("a"), 10.0f, {});
+ Output b = ops::Const(scope.WithOpName("b"), 10.0f, {});
+ Output c = ops::Mul(scope.WithOpName("c"), a, b);
+ Output var = ops::Variable(scope.WithOpName("var"), {}, DataType::DT_FLOAT);
+ Output assign = ops::Assign(scope.WithOpName("assign"), var, a);
+ TF_ASSERT_OK(scope.ToGraphDef(&graph_def));
+ // "c" isnt dependent on the variable, so nothing should be frozen.
+ TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle(
+ graph_def, {"c:0"}, assign.name(), &saved_model_bundle));
+
+ GraphDef frozen_graph_def;
+ std::unordered_set<string> inputs;
+ std::unordered_set<string> outputs;
+ TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs,
+ &outputs));
+
+ GraphDef expected_graph_def;
+ Scope expected_scope = Scope::NewRootScope();
+ Output expected_a = ops::Const(expected_scope.WithOpName("a"), 10.0f, {});
+ Output expected_b = ops::Const(expected_scope.WithOpName("b"), 10.0f, {});
+ Output expected_c =
+ ops::Mul(expected_scope.WithOpName("c"), expected_a, expected_b);
+ TF_ASSERT_OK(expected_scope.ToGraphDef(&expected_graph_def));
+
+ GraphDefEqual(frozen_graph_def, expected_graph_def);
+
+ RunAndCompareFrozenAndUnfrozenGraphs(saved_model_bundle.session.get(),
+ frozen_graph_def, "c:0");
+}
+
+TEST_F(FreezeTest, GraphDefWithVariablesNeededByOutputs) {
+ // Test freezing a graph with variables that are needed by outputs in the
+ // SignatureDef. The variables should be frozen.
+ SavedModelBundle saved_model_bundle;
+ GraphDef graph_def;
+ Scope scope = Scope::NewRootScope();
+ Output a = ops::Const(scope.WithOpName("a"), 10.0f, {});
+ Output var = ops::Variable(scope.WithOpName("var"), {}, DataType::DT_FLOAT);
+ Output c = ops::Mul(scope.WithOpName("c"), a, var);
+ Output assign = ops::Assign(scope.WithOpName("assign"), var, a);
+ TF_ASSERT_OK(scope.ToGraphDef(&graph_def));
+ // "c" isnt dependent on the variable, so nothing should be frozen.
+ TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle(
+ graph_def, {"c:0"}, assign.name(), &saved_model_bundle));
+
+ GraphDef frozen_graph_def;
+ std::unordered_set<string> inputs;
+ std::unordered_set<string> outputs;
+ TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs,
+ &outputs));
+
+ // There should be 3 nodes in the resulting graph_def, and none should be
+ // variables.
+ EXPECT_EQ(frozen_graph_def.node_size(), 3);
+ for (const NodeDef& node : frozen_graph_def.node()) {
+ EXPECT_NE(node.op(), "Variable") << node.name();
+ EXPECT_NE(node.op(), "VariableV2") << node.name();
+ }
+
+ RunAndCompareFrozenAndUnfrozenGraphs(saved_model_bundle.session.get(),
+ frozen_graph_def, "c:0");
+}
+
+TEST_F(FreezeTest, GraphDefWithVariablesNeededAndNotNeededByOutputs) {
+ // Test freezing a graph with some variables that are needed and not needed by
+ // the outputs in the SignatureDef. The resulting graph should only freeze
+ // dependent variables.
+ SavedModelBundle saved_model_bundle;
+ GraphDef graph_def;
+ Scope scope = Scope::NewRootScope();
+ Output a = ops::Const(scope.WithOpName("a"), 10.0f, {});
+ Output var = ops::Variable(scope.WithOpName("var"), {}, DataType::DT_FLOAT);
+ Output c = ops::Mul(scope.WithOpName("c"), a, var);
+ Output assign = ops::Assign(scope.WithOpName("assign"), var, a);
+ Output var_1 =
+ ops::Variable(scope.WithOpName("var_1"), {}, DataType::DT_FLOAT);
+ Output assign_1 = ops::Assign(scope.WithOpName("assign_1"), var, a);
+ TF_ASSERT_OK(scope.ToGraphDef(&graph_def));
+ // "c" isnt dependent on the variable, so nothing should be frozen.
+ TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle(
+ graph_def, {"c:0"}, assign.name(), &saved_model_bundle));
+
+ GraphDef frozen_graph_def;
+ std::unordered_set<string> inputs;
+ std::unordered_set<string> outputs;
+ TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs,
+ &outputs));
+
+ // There should be 3 nodes in the resulting graph_def, and none should be
+ // variables.
+ EXPECT_EQ(frozen_graph_def.node_size(), 3);
+ for (const NodeDef& node : frozen_graph_def.node()) {
+ EXPECT_NE(node.op(), "Variable") << node.name();
+ EXPECT_NE(node.op(), "VariableV2") << node.name();
+ }
+
+ RunAndCompareFrozenAndUnfrozenGraphs(saved_model_bundle.session.get(),
+ frozen_graph_def, "c:0");
+}
+
+} // namespace
+} // namespace tensorflow