diff options
author | Suharsh Sivakumar <suharshs@google.com> | 2018-01-14 03:38:40 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-01-14 03:42:18 -0800 |
commit | a79e97be8460ce3e1a7de2ddbc78b76151e0035a (patch) | |
tree | f376a443273aa53b1301588ad2363a0e3e6c7134 /tensorflow/cc/tools | |
parent | 9c94a3f9370f535ccaa705403c60da67dd473bea (diff) |
FreezeSavedModel function: Get a frozen GraphDef, inputs, and outputs from a loaded SaveModelBundle.
#14567
PiperOrigin-RevId: 181887870
Diffstat (limited to 'tensorflow/cc/tools')
-rw-r--r-- | tensorflow/cc/tools/BUILD | 58 | ||||
-rw-r--r-- | tensorflow/cc/tools/freeze_saved_model.cc | 194 | ||||
-rw-r--r-- | tensorflow/cc/tools/freeze_saved_model.h | 43 | ||||
-rw-r--r-- | tensorflow/cc/tools/freeze_saved_model_test.cc | 307 |
4 files changed, 602 insertions, 0 deletions
diff --git a/tensorflow/cc/tools/BUILD b/tensorflow/cc/tools/BUILD new file mode 100644 index 0000000000..0a7c37383f --- /dev/null +++ b/tensorflow/cc/tools/BUILD @@ -0,0 +1,58 @@ +# Description: +# TensorFlow cc tools. + +package( + default_visibility = ["//visibility:public"], +) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load( + "//tensorflow:tensorflow.bzl", + "tf_cc_test", +) + +cc_library( + name = "freeze_saved_model", + srcs = ["freeze_saved_model.cc"], + hdrs = ["freeze_saved_model.h"], + deps = [ + "//tensorflow/cc/saved_model:loader", + "//tensorflow/core:core_cpu", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:tensorflow", + ], +) + +tf_cc_test( + name = "freeze_saved_model_test", + srcs = ["freeze_saved_model_test.cc"], + deps = [ + ":freeze_saved_model", + "//tensorflow/cc:cc_ops", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + +# ----------------------------------------------------------------------------- +# Google-internal targets. + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/cc/tools/freeze_saved_model.cc b/tensorflow/cc/tools/freeze_saved_model.cc new file mode 100644 index 0000000000..ddf372cdef --- /dev/null +++ b/tensorflow/cc/tools/freeze_saved_model.cc @@ -0,0 +1,194 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/cc/tools/freeze_saved_model.h" + +#include <queue> + +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/versions.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" + +namespace tensorflow { + +namespace { + +// Gets tensor names from tensor_info and inserts them into the set of tensor +// names. +void GetTensorNamesFromTensorInfo(const TensorInfo& tensor_info, + std::unordered_set<string>* tensor_names) { + if (tensor_info.has_coo_sparse()) { + // If the tensor is sparse we have to add all three tensors of the sparse + // representations. + const TensorInfo_CooSparse& coo_sparse = tensor_info.coo_sparse(); + tensor_names->insert(coo_sparse.values_tensor_name()); + tensor_names->insert(coo_sparse.indices_tensor_name()); + tensor_names->insert(coo_sparse.dense_shape_tensor_name()); + } else { + tensor_names->insert(tensor_info.name()); + } +} + +// Gets the union of all inputs and outputs of all SignatureDefs in the bundle +void GetSignatureDefsInputsAndOutputs( + const SavedModelBundle& saved_model_bundle, + std::unordered_set<string>* inputs, std::unordered_set<string>* outputs) { + for (auto& sigdef_elem : saved_model_bundle.meta_graph_def.signature_def()) { + const SignatureDef& signature_def = sigdef_elem.second; + for (auto& input_elem : signature_def.inputs()) { + GetTensorNamesFromTensorInfo(input_elem.second, inputs); + } + for (auto& output_elem : signature_def.outputs()) { + GetTensorNamesFromTensorInfo(output_elem.second, outputs); + } + } +} + +// Gets a map from string node name to NodeDef. +void GetNodeNameToNodeDefMap( + GraphDef* graph_def, + std::unordered_map<string, NodeDef*>* name_to_node_map) { + for (size_t i = 0; i < graph_def->node_size(); i++) { + NodeDef* node = graph_def->mutable_node(i); + (*name_to_node_map)[node->name()] = node; + } +} + +// Gets the set of node names needed by `outputs` and the corresponding set of +// variable nodes to convert. +void GetReachableNodesAndVariables( + GraphDef* graph_def, const std::unordered_set<string>& outputs, + std::unordered_set<string>* reachable_node_names, + std::unordered_set<string>* variable_node_names) { + // TODO(suharshs): Add support for ResourceVariables. + static const std::unordered_set<string>* kVariableTypes = + new std::unordered_set<string>({"Variable", "VariableV2"}); + // name_to_node_map is needed to get the inputs from the NodeDef corresponding + // the a string node name. These inputs are used when doing our backwards + // traversal. + std::unordered_map<string, NodeDef*> name_to_node_map; + GetNodeNameToNodeDefMap(graph_def, &name_to_node_map); + std::queue<string> nodes_to_visit; + for (const string& tensor_name : outputs) { + // We need to strip off the tensor part to get the node name. + std::vector<string> tensor_name_parts = str_util::Split(tensor_name, ':'); + nodes_to_visit.push(tensor_name_parts[0]); + } + // We do a traversal backwards from the outputs specified in the MetaGraphDef. + while (!nodes_to_visit.empty()) { + const string node_name = nodes_to_visit.front(); + nodes_to_visit.pop(); + if (reachable_node_names->find(node_name) != reachable_node_names->end()) { + continue; + } + reachable_node_names->insert(node_name); + NodeDef* node = name_to_node_map[node_name]; + if (kVariableTypes->find(node->op()) != kVariableTypes->end()) { + variable_node_names->insert(node->name()); + } + for (const string& input : node->input()) { + nodes_to_visit.push(input); + } + } +} + +// Gets a map from variable name to variable value. +Status GetVariableNameToTensorMap( + Session* session, std::unordered_set<string> variable_names_set, + std::unordered_map<string, Tensor>* variable_name_to_value_map) { + if (variable_names_set.empty()) { + return Status::OK(); + } + std::vector<string> variable_names; + std::vector<string> tensor_names; + for (const string& node_name : variable_names_set) { + variable_names.push_back(node_name); + // We need to run tensors, so append ":0". + tensor_names.push_back(node_name + ":0"); + } + std::vector<Tensor> outputs; + TF_RETURN_IF_ERROR( + session->Run(/* inputs */ {}, tensor_names, /* targets */ {}, &outputs)); + for (size_t i = 0; i < variable_names.size(); i++) { + (*variable_name_to_value_map)[variable_names[i]] = outputs[i]; + } + return Status::OK(); +} + +// Converts a Variable NodeDef into a Constant NodeDef. +void ConvertVariableToConstant(const NodeDef& variable_node, + const Tensor& variable_value, + NodeDef* const_node) { + const_node->set_name(variable_node.name()); + const_node->set_op("Const"); + (*const_node->mutable_attr())["dtype"] = variable_node.attr().at("dtype"); + variable_value.AsProtoTensorContent( + (*const_node->mutable_attr())["value"].mutable_tensor()); +} + +// Freezes the subgraph of all nodes needed by `outputs`. +Status FreezeGraphDef(const SavedModelBundle& saved_model_bundle, + const std::unordered_set<string>& outputs, + GraphDef* frozen_graph_def) { + GraphDef graph_def = saved_model_bundle.meta_graph_def.graph_def(); + // Copy versions and library as-is from original graph. + *frozen_graph_def->mutable_versions() = graph_def.versions(); + *frozen_graph_def->mutable_library() = graph_def.library(); + // If the graph is empty there is nothing left to do. + if (graph_def.node_size() == 0) { + return Status::OK(); + } + std::unordered_set<string> reachable_node_names; + std::unordered_set<string> variable_node_names; + GetReachableNodesAndVariables(&graph_def, outputs, &reachable_node_names, + &variable_node_names); + std::unordered_map<string, Tensor> variable_to_value_map; + TF_RETURN_IF_ERROR( + GetVariableNameToTensorMap(saved_model_bundle.session.get(), + variable_node_names, &variable_to_value_map)); + // We copy the nodes in the same order they were in the original graph_def. + for (const NodeDef& node : graph_def.node()) { + if (reachable_node_names.find(node.name()) == reachable_node_names.end()) { + continue; + } + if (variable_node_names.find(node.name()) != variable_node_names.end()) { + ConvertVariableToConstant(node, variable_to_value_map[node.name()], + frozen_graph_def->add_node()); + } else { + // If the node isn't a variable, just copy the node as-is. + *frozen_graph_def->add_node() = node; + } + } + return Status::OK(); +} + +} // namespace + +Status FreezeSavedModel(const SavedModelBundle& saved_model_bundle, + GraphDef* frozen_graph_def, + std::unordered_set<string>* inputs, + std::unordered_set<string>* outputs) { + GetSignatureDefsInputsAndOutputs(saved_model_bundle, inputs, outputs); + TF_RETURN_IF_ERROR( + FreezeGraphDef(saved_model_bundle, *outputs, frozen_graph_def)); + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/cc/tools/freeze_saved_model.h b/tensorflow/cc/tools/freeze_saved_model.h new file mode 100644 index 0000000000..bd5e0516c8 --- /dev/null +++ b/tensorflow/cc/tools/freeze_saved_model.h @@ -0,0 +1,43 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CC_TOOLS_FREEZE_SAVED_MODEL_H_ +#define THIRD_PARTY_TENSORFLOW_CC_TOOLS_FREEZE_SAVED_MODEL_H_ + +#include <unordered_set> + +#include "tensorflow/cc/saved_model/loader.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +// Returns a frozen GraphDef, input tensors, and output tensors from the loaded +// SavedModelBundle. +// `inputs` and `outputs` consist of the union of all inputs and outputs in the +// SignatureDefs in the SavedModelBundle. +// FreezeSavedModel sets `frozen_graph_def` to a GraphDef of all nodes needed by +// `outputs`. All variables in the supplied SavedModelBundle are converted to +// constants, set to the value of the variables, by running the restored Session +// in the SavedModelBundle. +// WARNING: Only the variable checkpoints will be reflected in the frozen +// graph_def. All saved_model assets will be ignored. +Status FreezeSavedModel(const SavedModelBundle& saved_model_bundle, + GraphDef* frozen_graph_def, + std::unordered_set<string>* inputs, + std::unordered_set<string>* outputs); + +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CC_TOOLS_FREEZE_SAVED_MODEL_H_ diff --git a/tensorflow/cc/tools/freeze_saved_model_test.cc b/tensorflow/cc/tools/freeze_saved_model_test.cc new file mode 100644 index 0000000000..57244a4f0a --- /dev/null +++ b/tensorflow/cc/tools/freeze_saved_model_test.cc @@ -0,0 +1,307 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/cc/tools/freeze_saved_model.h" + +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/framework/function_testlib.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/framework/versions.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow/core/public/session_options.h" + +namespace tensorflow { +namespace { + +class FreezeTest : public ::testing::Test { + protected: + void GraphDefEqual(const GraphDef& actual, const GraphDef& expected) { + EXPECT_EQ(actual.ShortDebugString(), expected.ShortDebugString()); + } + + // Builds a SignatureDef with the provided `inputs` and `outputs`. + SignatureDef BuildSignatureDef(const std::unordered_set<string>& inputs, + const std::unordered_set<string>& outputs) { + SignatureDef signature_def; + for (const string& input : inputs) { + (*signature_def.mutable_inputs())[input].set_name(input); + } + for (const string& output : outputs) { + (*signature_def.mutable_outputs())[output].set_name(output); + } + return signature_def; + } + + // Adds `signature_def` to `saved_model_bundle` under `key`. + void AddSignatureDefToSavedModelBundle(const SignatureDef& signature_def, + const string& key, + SavedModelBundle* saved_model_bundle) { + MetaGraphDef* meta_graph_def = &saved_model_bundle->meta_graph_def; + (*meta_graph_def->mutable_signature_def())[key] = signature_def; + } + + // Adds an initialized session to `saved_model_bundle` using `graph_def` and + // initializing with `init_node`. + Status InitializeSavedModelBundleSession( + const GraphDef& graph_def, const string& init_node, + SavedModelBundle* saved_model_bundle) { + SessionOptions session_options; + saved_model_bundle->session.reset(NewSession(session_options)); + TF_RETURN_IF_ERROR(saved_model_bundle->session->Create(graph_def)); + if (!init_node.empty()) { + std::vector<Tensor> outputs; + return saved_model_bundle->session->Run( + /* inputs */ {}, /* output_tensors */ {}, {init_node}, &outputs); + } + return Status::OK(); + } + + // Adds `graph_def` to `saved_model_bundle` and intializes a session with + // `init_node`. + Status AddGraphDefToSavedModelBundle(const GraphDef& graph_def, + const string& init_node, + SavedModelBundle* saved_model_bundle) { + MetaGraphDef* meta_graph_def = &saved_model_bundle->meta_graph_def; + *meta_graph_def->mutable_graph_def() = graph_def; + return InitializeSavedModelBundleSession(graph_def, init_node, + saved_model_bundle); + } + + // Adds `graph_def` and `outputs` as the GraphDef and SignatureDef in + // `saved_model_bundle` and initializes a session with `init_node`. + Status AddGraphDefWithOutputsToSavedModelBundle( + const GraphDef& graph_def, const std::unordered_set<string>& outputs, + const string& init_node, SavedModelBundle* saved_model_bundle) { + SignatureDef signature_def = + BuildSignatureDef(std::unordered_set<string>(), outputs); + AddSignatureDefToSavedModelBundle(signature_def, "signature_def", + saved_model_bundle); + return AddGraphDefToSavedModelBundle(graph_def, init_node, + saved_model_bundle); + } + + // Runs and compares the outputs of `tensor_name` on both the + // `unfrozen_session` and the `frozen_graph_def. + void RunAndCompareFrozenAndUnfrozenGraphs(Session* unfrozen_session, + const GraphDef& frozen_graph_def, + const string& tensor_name) { + std::vector<Tensor> unfrozen_outputs; + TF_ASSERT_OK(unfrozen_session->Run(/* inputs */ {}, {tensor_name}, + /* targets */ {}, &unfrozen_outputs)); + + SessionOptions session_options; + std::unique_ptr<Session> frozen_session(NewSession(session_options)); + TF_ASSERT_OK(frozen_session->Create(frozen_graph_def)); + std::vector<Tensor> frozen_outputs; + TF_ASSERT_OK(frozen_session->Run(/* inputs */ {}, {tensor_name}, + /* targets */ {}, &frozen_outputs)); + + test::ExpectTensorEqual<float>(unfrozen_outputs[0], frozen_outputs[0]); + } +}; + +TEST_F(FreezeTest, InputsAndOutputsSingleSignatureDef) { + // Test that inputs and outputs get correctly populated for a single + // SignatureDef. + SavedModelBundle saved_model_bundle; + std::unordered_set<string> expected_inputs = {"input0:0", "input1:0"}; + std::unordered_set<string> expected_outputs = {"output0:0", "output1:0"}; + SignatureDef signature_def = + BuildSignatureDef(expected_inputs, expected_outputs); + AddSignatureDefToSavedModelBundle(signature_def, "signature_def", + &saved_model_bundle); + GraphDef frozen_graph_def; + std::unordered_set<string> inputs; + std::unordered_set<string> outputs; + TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs, + &outputs)); + EXPECT_EQ(expected_inputs, inputs); + EXPECT_EQ(expected_outputs, outputs); +} + +TEST_F(FreezeTest, InputsAndOutputsMultipleSignatureDefs) { + // Test that inputs and outputs get correctly merged and populated when + // multiple SignatureDefs are provided. + SavedModelBundle saved_model_bundle; + SignatureDef signature_def_0 = BuildSignatureDef({"input0:0"}, {"output0:0"}); + SignatureDef signature_def_1 = BuildSignatureDef({"input1:0"}, {"output1:0"}); + AddSignatureDefToSavedModelBundle(signature_def_0, "signature_def_0", + &saved_model_bundle); + AddSignatureDefToSavedModelBundle(signature_def_1, "signature_def_1", + &saved_model_bundle); + GraphDef frozen_graph_def; + std::unordered_set<string> inputs; + std::unordered_set<string> outputs; + TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs, + &outputs)); + std::unordered_set<string> expected_inputs = {"input0:0", "input1:0"}; + std::unordered_set<string> expected_outputs = {"output0:0", "output1:0"}; + EXPECT_EQ(expected_inputs, inputs); + EXPECT_EQ(expected_outputs, outputs); +} + +TEST_F(FreezeTest, GraphDefVersionsAndLibrary) { + // Test that GraphDef versions and library are copied correctly into the + // frozen graph. + SavedModelBundle saved_model_bundle; + GraphDef graph_def; + graph_def.mutable_versions()->set_producer(1234); + graph_def.mutable_versions()->set_min_consumer(1234); + *graph_def.mutable_library()->add_function() = test::function::NonZero(); + TF_ASSERT_OK( + AddGraphDefToSavedModelBundle(graph_def, "", &saved_model_bundle)); + + GraphDef frozen_graph_def; + std::unordered_set<string> inputs; + std::unordered_set<string> outputs; + TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs, + &outputs)); + + GraphDefEqual(frozen_graph_def, graph_def); +} + +TEST_F(FreezeTest, GraphDefWithNoVariables) { + // Test freezing a graph with no variables. + SavedModelBundle saved_model_bundle; + GraphDef graph_def; + Scope scope = Scope::NewRootScope(); + Output a = ops::Const(scope.WithOpName("a"), 10.0f, {}); + Output b = ops::Const(scope.WithOpName("b"), 10.0f, {}); + Output c = ops::Mul(scope.WithOpName("c"), a, b); + TF_ASSERT_OK(scope.ToGraphDef(&graph_def)); + TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle(graph_def, {"c:0"}, "", + &saved_model_bundle)); + + GraphDef frozen_graph_def; + std::unordered_set<string> inputs; + std::unordered_set<string> outputs; + TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs, + &outputs)); + + GraphDefEqual(frozen_graph_def, graph_def); +} + +TEST_F(FreezeTest, GraphDefWithVariablesNotNeededByOutputs) { + // Test freezing a graph with variables that are not needed by the outputs in + // the SignatureDef. The resulting graph shouldn't be frozen, but + // non-dependent nodes should be pruned. + SavedModelBundle saved_model_bundle; + GraphDef graph_def; + Scope scope = Scope::NewRootScope(); + Output a = ops::Const(scope.WithOpName("a"), 10.0f, {}); + Output b = ops::Const(scope.WithOpName("b"), 10.0f, {}); + Output c = ops::Mul(scope.WithOpName("c"), a, b); + Output var = ops::Variable(scope.WithOpName("var"), {}, DataType::DT_FLOAT); + Output assign = ops::Assign(scope.WithOpName("assign"), var, a); + TF_ASSERT_OK(scope.ToGraphDef(&graph_def)); + // "c" isnt dependent on the variable, so nothing should be frozen. + TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle( + graph_def, {"c:0"}, assign.name(), &saved_model_bundle)); + + GraphDef frozen_graph_def; + std::unordered_set<string> inputs; + std::unordered_set<string> outputs; + TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs, + &outputs)); + + GraphDef expected_graph_def; + Scope expected_scope = Scope::NewRootScope(); + Output expected_a = ops::Const(expected_scope.WithOpName("a"), 10.0f, {}); + Output expected_b = ops::Const(expected_scope.WithOpName("b"), 10.0f, {}); + Output expected_c = + ops::Mul(expected_scope.WithOpName("c"), expected_a, expected_b); + TF_ASSERT_OK(expected_scope.ToGraphDef(&expected_graph_def)); + + GraphDefEqual(frozen_graph_def, expected_graph_def); + + RunAndCompareFrozenAndUnfrozenGraphs(saved_model_bundle.session.get(), + frozen_graph_def, "c:0"); +} + +TEST_F(FreezeTest, GraphDefWithVariablesNeededByOutputs) { + // Test freezing a graph with variables that are needed by outputs in the + // SignatureDef. The variables should be frozen. + SavedModelBundle saved_model_bundle; + GraphDef graph_def; + Scope scope = Scope::NewRootScope(); + Output a = ops::Const(scope.WithOpName("a"), 10.0f, {}); + Output var = ops::Variable(scope.WithOpName("var"), {}, DataType::DT_FLOAT); + Output c = ops::Mul(scope.WithOpName("c"), a, var); + Output assign = ops::Assign(scope.WithOpName("assign"), var, a); + TF_ASSERT_OK(scope.ToGraphDef(&graph_def)); + // "c" isnt dependent on the variable, so nothing should be frozen. + TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle( + graph_def, {"c:0"}, assign.name(), &saved_model_bundle)); + + GraphDef frozen_graph_def; + std::unordered_set<string> inputs; + std::unordered_set<string> outputs; + TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs, + &outputs)); + + // There should be 3 nodes in the resulting graph_def, and none should be + // variables. + EXPECT_EQ(frozen_graph_def.node_size(), 3); + for (const NodeDef& node : frozen_graph_def.node()) { + EXPECT_NE(node.op(), "Variable") << node.name(); + EXPECT_NE(node.op(), "VariableV2") << node.name(); + } + + RunAndCompareFrozenAndUnfrozenGraphs(saved_model_bundle.session.get(), + frozen_graph_def, "c:0"); +} + +TEST_F(FreezeTest, GraphDefWithVariablesNeededAndNotNeededByOutputs) { + // Test freezing a graph with some variables that are needed and not needed by + // the outputs in the SignatureDef. The resulting graph should only freeze + // dependent variables. + SavedModelBundle saved_model_bundle; + GraphDef graph_def; + Scope scope = Scope::NewRootScope(); + Output a = ops::Const(scope.WithOpName("a"), 10.0f, {}); + Output var = ops::Variable(scope.WithOpName("var"), {}, DataType::DT_FLOAT); + Output c = ops::Mul(scope.WithOpName("c"), a, var); + Output assign = ops::Assign(scope.WithOpName("assign"), var, a); + Output var_1 = + ops::Variable(scope.WithOpName("var_1"), {}, DataType::DT_FLOAT); + Output assign_1 = ops::Assign(scope.WithOpName("assign_1"), var, a); + TF_ASSERT_OK(scope.ToGraphDef(&graph_def)); + // "c" isnt dependent on the variable, so nothing should be frozen. + TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle( + graph_def, {"c:0"}, assign.name(), &saved_model_bundle)); + + GraphDef frozen_graph_def; + std::unordered_set<string> inputs; + std::unordered_set<string> outputs; + TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs, + &outputs)); + + // There should be 3 nodes in the resulting graph_def, and none should be + // variables. + EXPECT_EQ(frozen_graph_def.node_size(), 3); + for (const NodeDef& node : frozen_graph_def.node()) { + EXPECT_NE(node.op(), "Variable") << node.name(); + EXPECT_NE(node.op(), "VariableV2") << node.name(); + } + + RunAndCompareFrozenAndUnfrozenGraphs(saved_model_bundle.session.get(), + frozen_graph_def, "c:0"); +} + +} // namespace +} // namespace tensorflow |