aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/cc
diff options
context:
space:
mode:
authorGravatar Jonathan Hseu <jhseu@google.com>2018-03-03 22:57:05 -0800
committerGravatar Jonathan Hseu <jhseu@google.com>2018-03-03 22:57:05 -0800
commit80e33658f2b9461fe5b12ba8043e5b5177d85d17 (patch)
tree4ff0fba0dff8901287182d642a6ea9f44ac5e6af /tensorflow/cc
parentb79ce0029dce3264266ced739590bc238b17096c (diff)
parent0c92f574d18cd01134bb9f7a5a679866a0f92f7e (diff)
Merge commit for internal changes
Diffstat (limited to 'tensorflow/cc')
-rw-r--r--tensorflow/cc/tools/BUILD1
-rw-r--r--tensorflow/cc/tools/freeze_saved_model.cc55
-rw-r--r--tensorflow/cc/tools/freeze_saved_model_test.cc268
3 files changed, 211 insertions, 113 deletions
diff --git a/tensorflow/cc/tools/BUILD b/tensorflow/cc/tools/BUILD
index 97f66e79b8..f413a5cc52 100644
--- a/tensorflow/cc/tools/BUILD
+++ b/tensorflow/cc/tools/BUILD
@@ -32,6 +32,7 @@ tf_cc_test(
deps = [
":freeze_saved_model",
"//tensorflow/cc:cc_ops",
+ "//tensorflow/cc:resource_variable_ops",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework_internal",
"//tensorflow/core:protos_all_cc",
diff --git a/tensorflow/cc/tools/freeze_saved_model.cc b/tensorflow/cc/tools/freeze_saved_model.cc
index ddf372cdef..4ddddcb586 100644
--- a/tensorflow/cc/tools/freeze_saved_model.cc
+++ b/tensorflow/cc/tools/freeze_saved_model.cc
@@ -75,16 +75,13 @@ void GetNodeNameToNodeDefMap(
// variable nodes to convert.
void GetReachableNodesAndVariables(
GraphDef* graph_def, const std::unordered_set<string>& outputs,
+ const std::unordered_map<string, NodeDef*>& name_to_node_map,
std::unordered_set<string>* reachable_node_names,
std::unordered_set<string>* variable_node_names) {
// TODO(suharshs): Add support for ResourceVariables.
static const std::unordered_set<string>* kVariableTypes =
- new std::unordered_set<string>({"Variable", "VariableV2"});
- // name_to_node_map is needed to get the inputs from the NodeDef corresponding
- // the a string node name. These inputs are used when doing our backwards
- // traversal.
- std::unordered_map<string, NodeDef*> name_to_node_map;
- GetNodeNameToNodeDefMap(graph_def, &name_to_node_map);
+ new std::unordered_set<string>({"Variable", "VariableV2", "VarHandleOp"});
+
std::queue<string> nodes_to_visit;
for (const string& tensor_name : outputs) {
// We need to strip off the tensor part to get the node name.
@@ -99,7 +96,7 @@ void GetReachableNodesAndVariables(
continue;
}
reachable_node_names->insert(node_name);
- NodeDef* node = name_to_node_map[node_name];
+ NodeDef* node = name_to_node_map.at(node_name);
if (kVariableTypes->find(node->op()) != kVariableTypes->end()) {
variable_node_names->insert(node->name());
}
@@ -111,7 +108,9 @@ void GetReachableNodesAndVariables(
// Gets a map from variable name to variable value.
Status GetVariableNameToTensorMap(
- Session* session, std::unordered_set<string> variable_names_set,
+ Session* session,
+ const std::unordered_map<string, NodeDef*>& name_to_node_map,
+ std::unordered_set<string> variable_names_set,
std::unordered_map<string, Tensor>* variable_name_to_value_map) {
if (variable_names_set.empty()) {
return Status::OK();
@@ -120,8 +119,14 @@ Status GetVariableNameToTensorMap(
std::vector<string> tensor_names;
for (const string& node_name : variable_names_set) {
variable_names.push_back(node_name);
- // We need to run tensors, so append ":0".
- tensor_names.push_back(node_name + ":0");
+ NodeDef* node_def = name_to_node_map.at(node_name);
+ if (node_def->op() == "VarHandleOp") {
+ // If this is a resource variable, we have to run the corresponding
+ // ReadVariableOp.
+ tensor_names.push_back(node_name + "/Read/ReadVariableOp:0");
+ } else {
+ tensor_names.push_back(node_name + ":0");
+ }
}
std::vector<Tensor> outputs;
TF_RETURN_IF_ERROR(
@@ -143,6 +148,15 @@ void ConvertVariableToConstant(const NodeDef& variable_node,
(*const_node->mutable_attr())["value"].mutable_tensor());
}
+// Converts a ReadVariableOp NodeDef to an Identity NodeDef.
+void ConvertReadVariableOpToIdentity(const NodeDef& node,
+ NodeDef* identity_node) {
+ identity_node->set_name(node.name());
+ identity_node->set_op("Identity");
+ (*identity_node->mutable_attr())["T"] = node.attr().at("dtype");
+ identity_node->add_input(node.input(0));
+}
+
// Freezes the subgraph of all nodes needed by `outputs`.
Status FreezeGraphDef(const SavedModelBundle& saved_model_bundle,
const std::unordered_set<string>& outputs,
@@ -155,14 +169,19 @@ Status FreezeGraphDef(const SavedModelBundle& saved_model_bundle,
if (graph_def.node_size() == 0) {
return Status::OK();
}
+ // name_to_node_map is needed to get the inputs from the NodeDef corresponding
+ // the a string node name. These inputs are used when doing our backwards
+ // traversal.
+ std::unordered_map<string, NodeDef*> name_to_node_map;
+ GetNodeNameToNodeDefMap(&graph_def, &name_to_node_map);
std::unordered_set<string> reachable_node_names;
std::unordered_set<string> variable_node_names;
- GetReachableNodesAndVariables(&graph_def, outputs, &reachable_node_names,
- &variable_node_names);
+ GetReachableNodesAndVariables(&graph_def, outputs, name_to_node_map,
+ &reachable_node_names, &variable_node_names);
std::unordered_map<string, Tensor> variable_to_value_map;
- TF_RETURN_IF_ERROR(
- GetVariableNameToTensorMap(saved_model_bundle.session.get(),
- variable_node_names, &variable_to_value_map));
+ TF_RETURN_IF_ERROR(GetVariableNameToTensorMap(
+ saved_model_bundle.session.get(), name_to_node_map, variable_node_names,
+ &variable_to_value_map));
// We copy the nodes in the same order they were in the original graph_def.
for (const NodeDef& node : graph_def.node()) {
if (reachable_node_names.find(node.name()) == reachable_node_names.end()) {
@@ -171,6 +190,12 @@ Status FreezeGraphDef(const SavedModelBundle& saved_model_bundle,
if (variable_node_names.find(node.name()) != variable_node_names.end()) {
ConvertVariableToConstant(node, variable_to_value_map[node.name()],
frozen_graph_def->add_node());
+ } else if (node.op() == "ReadVariableOp" &&
+ variable_node_names.find(node.input(0)) !=
+ variable_node_names.end()) {
+ // If the node is a ReadVariableOp, its input VarHandleOp will be
+ // converted to a Constant, so we will need to convert it to an Identity.
+ ConvertReadVariableOpToIdentity(node, frozen_graph_def->add_node());
} else {
// If the node isn't a variable, just copy the node as-is.
*frozen_graph_def->add_node() = node;
diff --git a/tensorflow/cc/tools/freeze_saved_model_test.cc b/tensorflow/cc/tools/freeze_saved_model_test.cc
index 52a81a5028..cd35fd3b95 100644
--- a/tensorflow/cc/tools/freeze_saved_model_test.cc
+++ b/tensorflow/cc/tools/freeze_saved_model_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/cc/tools/freeze_saved_model.h"
+#include "tensorflow/cc/ops/resource_variable_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/graph.pb.h"
@@ -113,6 +114,160 @@ class FreezeTest : public ::testing::Test {
test::ExpectTensorEqual<float>(unfrozen_outputs[0], frozen_outputs[0]);
}
+
+ void TestFreezeGraphWithoutDependentVariables(bool use_resource) {
+ // Test freezing a graph with variables that are not needed by the outputs
+ // in the SignatureDef. The resulting graph shouldn't be frozen, but
+ // non-dependent nodes should be pruned.
+ SavedModelBundle saved_model_bundle;
+ GraphDef graph_def;
+ Scope scope = Scope::NewRootScope();
+ Output a = ops::Const(scope.WithOpName("a"), 10.0f, {});
+ Output b = ops::Const(scope.WithOpName("b"), 10.0f, {});
+ Output c = ops::Mul(scope.WithOpName("c"), a, b);
+ if (use_resource) {
+ Output var =
+ ops::VarHandleOp(scope.WithOpName("var"), DataType::DT_FLOAT, {});
+ Output read_var = ops::ReadVariableOp(
+ scope.WithOpName("var/Read/ReadVariableOp"), var, DataType::DT_FLOAT);
+ auto assign = ops::AssignVariableOp(scope.WithOpName("assign"), var, a);
+ } else {
+ Output var =
+ ops::Variable(scope.WithOpName("var"), {}, DataType::DT_FLOAT);
+ Output assign = ops::Assign(scope.WithOpName("assign"), var, a);
+ }
+
+ TF_ASSERT_OK(scope.ToGraphDef(&graph_def));
+ // "c" isnt dependent on the variable, so nothing should be frozen.
+ TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle(
+ graph_def, {"c:0"}, "assign", &saved_model_bundle));
+
+ GraphDef frozen_graph_def;
+ std::unordered_set<string> inputs;
+ std::unordered_set<string> outputs;
+ TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def,
+ &inputs, &outputs));
+
+ GraphDef expected_graph_def;
+ Scope expected_scope = Scope::NewRootScope();
+ Output expected_a = ops::Const(expected_scope.WithOpName("a"), 10.0f, {});
+ Output expected_b = ops::Const(expected_scope.WithOpName("b"), 10.0f, {});
+ Output expected_c =
+ ops::Mul(expected_scope.WithOpName("c"), expected_a, expected_b);
+ TF_ASSERT_OK(expected_scope.ToGraphDef(&expected_graph_def));
+
+ GraphDefEqual(frozen_graph_def, expected_graph_def);
+
+ RunAndCompareFrozenAndUnfrozenGraphs(saved_model_bundle.session.get(),
+ frozen_graph_def, "c:0");
+ }
+
+ void TestFreezeGraphWithDependentVariables(bool use_resource) {
+ // Test freezing a graph with variables that are needed by outputs in the
+ // SignatureDef. The variables should be frozen.
+ SavedModelBundle saved_model_bundle;
+ GraphDef graph_def;
+ Scope scope = Scope::NewRootScope();
+ Output a = ops::Const(scope.WithOpName("a"), 10.0f, {});
+ Output read_var;
+ if (use_resource) {
+ Output var =
+ ops::VarHandleOp(scope.WithOpName("var"), DataType::DT_FLOAT, {});
+ read_var = ops::ReadVariableOp(
+ scope.WithOpName("var/Read/ReadVariableOp"), var, DataType::DT_FLOAT);
+ auto assign = ops::AssignVariableOp(scope.WithOpName("assign"), var, a);
+ } else {
+ Output read_var =
+ ops::Variable(scope.WithOpName("var"), {}, DataType::DT_FLOAT);
+ Output assign = ops::Assign(scope.WithOpName("assign"), read_var, a);
+ }
+ Output c = ops::Mul(scope.WithOpName("c"), a, read_var);
+ TF_ASSERT_OK(scope.ToGraphDef(&graph_def));
+ // "c" isnt dependent on the variable, so nothing should be frozen.
+ TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle(
+ graph_def, {"c:0"}, "assign", &saved_model_bundle));
+
+ GraphDef frozen_graph_def;
+ std::unordered_set<string> inputs;
+ std::unordered_set<string> outputs;
+ TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def,
+ &inputs, &outputs));
+
+ // If using normal variables there should be 3 nodes in the resulting
+ // graph_def. If using resource variables there should be 4 nodes in the
+ // resulting graph_def.
+ // In both cases, none should be variables.
+ size_t expected_nodes = use_resource ? 4 : 3;
+ EXPECT_EQ(frozen_graph_def.node_size(), expected_nodes);
+ for (const NodeDef& node : frozen_graph_def.node()) {
+ EXPECT_NE(node.op(), "Variable") << node.name();
+ EXPECT_NE(node.op(), "VariableV2") << node.name();
+ EXPECT_NE(node.op(), "VarHandleOp") << node.name();
+ EXPECT_NE(node.op(), "ReadVariableOp") << node.name();
+ }
+
+ RunAndCompareFrozenAndUnfrozenGraphs(saved_model_bundle.session.get(),
+ frozen_graph_def, "c:0");
+ }
+
+ void TestFreezeGraphWithAndWithoutDependentVariables(bool use_resource) {
+ // Test freezing a graph with some variables that are needed and not needed
+ // by
+ // the outputs in the SignatureDef. The resulting graph should only freeze
+ // dependent variables.
+ SavedModelBundle saved_model_bundle;
+ GraphDef graph_def;
+ Scope scope = Scope::NewRootScope();
+ Output a = ops::Const(scope.WithOpName("a"), 10.0f, {});
+ Output read_var;
+
+ if (use_resource) {
+ Output var =
+ ops::VarHandleOp(scope.WithOpName("var"), DataType::DT_FLOAT, {});
+ read_var = ops::ReadVariableOp(
+ scope.WithOpName("var/Read/ReadVariableOp"), var, DataType::DT_FLOAT);
+ auto assign = ops::AssignVariableOp(scope.WithOpName("assign"), var, a);
+ Output var_1 =
+ ops::VarHandleOp(scope.WithOpName("var_1"), DataType::DT_FLOAT, {});
+ Output read_var_1 =
+ ops::ReadVariableOp(scope.WithOpName("var_1/Read/ReadVariableOp"),
+ var, DataType::DT_FLOAT);
+ auto assign_1 =
+ ops::AssignVariableOp(scope.WithOpName("assign_1"), var_1, a);
+ } else {
+ read_var = ops::Variable(scope.WithOpName("var"), {}, DataType::DT_FLOAT);
+ Output assign = ops::Assign(scope.WithOpName("assign"), read_var, a);
+ Output var_1 =
+ ops::Variable(scope.WithOpName("var_1"), {}, DataType::DT_FLOAT);
+ Output assign_1 = ops::Assign(scope.WithOpName("assign_1"), var_1, a);
+ }
+
+ Output c = ops::Mul(scope.WithOpName("c"), a, read_var);
+ TF_ASSERT_OK(scope.ToGraphDef(&graph_def));
+ // "c" isnt dependent on the variable, so nothing should be frozen.
+ TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle(
+ graph_def, {"c:0"}, "assign", &saved_model_bundle));
+
+ GraphDef frozen_graph_def;
+ std::unordered_set<string> inputs;
+ std::unordered_set<string> outputs;
+ TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def,
+ &inputs, &outputs));
+
+ // There should be 3 nodes in the resulting graph_def, and none should be
+ // variables.
+ size_t expected_nodes = use_resource ? 4 : 3;
+ EXPECT_EQ(frozen_graph_def.node_size(), expected_nodes);
+ for (const NodeDef& node : frozen_graph_def.node()) {
+ EXPECT_NE(node.op(), "Variable") << node.name();
+ EXPECT_NE(node.op(), "VariableV2") << node.name();
+ EXPECT_NE(node.op(), "VarHandleOp") << node.name();
+ EXPECT_NE(node.op(), "ReadVariableOp") << node.name();
+ }
+
+ RunAndCompareFrozenAndUnfrozenGraphs(saved_model_bundle.session.get(),
+ frozen_graph_def, "c:0");
+ }
};
TEST_F(FreezeTest, InputsAndOutputsSingleSignatureDef) {
@@ -196,111 +351,28 @@ TEST_F(FreezeTest, GraphDefWithNoVariables) {
GraphDefEqual(frozen_graph_def, graph_def);
}
-TEST_F(FreezeTest, GraphDefWithVariablesNotNeededByOutputs) {
- // Test freezing a graph with variables that are not needed by the outputs in
- // the SignatureDef. The resulting graph shouldn't be frozen, but
- // non-dependent nodes should be pruned.
- SavedModelBundle saved_model_bundle;
- GraphDef graph_def;
- Scope scope = Scope::NewRootScope();
- Output a = ops::Const(scope.WithOpName("a"), 10.0f, {});
- Output b = ops::Const(scope.WithOpName("b"), 10.0f, {});
- Output c = ops::Mul(scope.WithOpName("c"), a, b);
- Output var = ops::Variable(scope.WithOpName("var"), {}, DataType::DT_FLOAT);
- Output assign = ops::Assign(scope.WithOpName("assign"), var, a);
- TF_ASSERT_OK(scope.ToGraphDef(&graph_def));
- // "c" isnt dependent on the variable, so nothing should be frozen.
- TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle(
- graph_def, {"c:0"}, assign.name(), &saved_model_bundle));
-
- GraphDef frozen_graph_def;
- std::unordered_set<string> inputs;
- std::unordered_set<string> outputs;
- TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs,
- &outputs));
-
- GraphDef expected_graph_def;
- Scope expected_scope = Scope::NewRootScope();
- Output expected_a = ops::Const(expected_scope.WithOpName("a"), 10.0f, {});
- Output expected_b = ops::Const(expected_scope.WithOpName("b"), 10.0f, {});
- Output expected_c =
- ops::Mul(expected_scope.WithOpName("c"), expected_a, expected_b);
- TF_ASSERT_OK(expected_scope.ToGraphDef(&expected_graph_def));
-
- GraphDefEqual(frozen_graph_def, expected_graph_def);
-
- RunAndCompareFrozenAndUnfrozenGraphs(saved_model_bundle.session.get(),
- frozen_graph_def, "c:0");
+TEST_F(FreezeTest, GraphDefWithoutDependentVariables) {
+ TestFreezeGraphWithoutDependentVariables(false);
}
-TEST_F(FreezeTest, GraphDefWithVariablesNeededByOutputs) {
- // Test freezing a graph with variables that are needed by outputs in the
- // SignatureDef. The variables should be frozen.
- SavedModelBundle saved_model_bundle;
- GraphDef graph_def;
- Scope scope = Scope::NewRootScope();
- Output a = ops::Const(scope.WithOpName("a"), 10.0f, {});
- Output var = ops::Variable(scope.WithOpName("var"), {}, DataType::DT_FLOAT);
- Output c = ops::Mul(scope.WithOpName("c"), a, var);
- Output assign = ops::Assign(scope.WithOpName("assign"), var, a);
- TF_ASSERT_OK(scope.ToGraphDef(&graph_def));
- // "c" isnt dependent on the variable, so nothing should be frozen.
- TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle(
- graph_def, {"c:0"}, assign.name(), &saved_model_bundle));
-
- GraphDef frozen_graph_def;
- std::unordered_set<string> inputs;
- std::unordered_set<string> outputs;
- TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs,
- &outputs));
-
- // There should be 3 nodes in the resulting graph_def, and none should be
- // variables.
- EXPECT_EQ(frozen_graph_def.node_size(), 3);
- for (const NodeDef& node : frozen_graph_def.node()) {
- EXPECT_NE(node.op(), "Variable") << node.name();
- EXPECT_NE(node.op(), "VariableV2") << node.name();
- }
-
- RunAndCompareFrozenAndUnfrozenGraphs(saved_model_bundle.session.get(),
- frozen_graph_def, "c:0");
+TEST_F(FreezeTest, GraphDefWithoutDependentResourceVariables) {
+ TestFreezeGraphWithoutDependentVariables(true);
}
-TEST_F(FreezeTest, GraphDefWithVariablesNeededAndNotNeededByOutputs) {
- // Test freezing a graph with some variables that are needed and not needed by
- // the outputs in the SignatureDef. The resulting graph should only freeze
- // dependent variables.
- SavedModelBundle saved_model_bundle;
- GraphDef graph_def;
- Scope scope = Scope::NewRootScope();
- Output a = ops::Const(scope.WithOpName("a"), 10.0f, {});
- Output var = ops::Variable(scope.WithOpName("var"), {}, DataType::DT_FLOAT);
- Output c = ops::Mul(scope.WithOpName("c"), a, var);
- Output assign = ops::Assign(scope.WithOpName("assign"), var, a);
- Output var_1 =
- ops::Variable(scope.WithOpName("var_1"), {}, DataType::DT_FLOAT);
- Output assign_1 = ops::Assign(scope.WithOpName("assign_1"), var, a);
- TF_ASSERT_OK(scope.ToGraphDef(&graph_def));
- // "c" isnt dependent on the variable, so nothing should be frozen.
- TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle(
- graph_def, {"c:0"}, assign.name(), &saved_model_bundle));
+TEST_F(FreezeTest, GraphDefWithDependentVariables) {
+ TestFreezeGraphWithDependentVariables(false);
+}
- GraphDef frozen_graph_def;
- std::unordered_set<string> inputs;
- std::unordered_set<string> outputs;
- TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs,
- &outputs));
+TEST_F(FreezeTest, GraphDefWithDependentResourceVariables) {
+ TestFreezeGraphWithDependentVariables(true);
+}
- // There should be 3 nodes in the resulting graph_def, and none should be
- // variables.
- EXPECT_EQ(frozen_graph_def.node_size(), 3);
- for (const NodeDef& node : frozen_graph_def.node()) {
- EXPECT_NE(node.op(), "Variable") << node.name();
- EXPECT_NE(node.op(), "VariableV2") << node.name();
- }
+TEST_F(FreezeTest, GraphDefWithAndWithoutDependentVariables) {
+ TestFreezeGraphWithAndWithoutDependentVariables(false);
+}
- RunAndCompareFrozenAndUnfrozenGraphs(saved_model_bundle.session.get(),
- frozen_graph_def, "c:0");
+TEST_F(FreezeTest, GraphDefWithAndWithoutDependentResourceVariables) {
+ TestFreezeGraphWithAndWithoutDependentVariables(true);
}
} // namespace