path: root/tensorflow/tools/graph_transforms
author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-02-01 19:50:34 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>   2018-02-01 21:44:13 -0800
commit    c39d141948174b94213848d9b95541ee09af5e53 (patch)
tree      e6490295482231b961e69d06717cb03a9c90a470 /tensorflow/tools/graph_transforms
parent    930b347273c1b9ac32b7c96a8d35e92704c4e8b7 (diff)
Supporting new saving op structure.
PiperOrigin-RevId: 184233513
Diffstat (limited to 'tensorflow/tools/graph_transforms')
-rw-r--r-- tensorflow/tools/graph_transforms/sparsify_gather.cc      | 53
-rw-r--r-- tensorflow/tools/graph_transforms/sparsify_gather_test.cc | 63
2 files changed, 68 insertions, 48 deletions
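With the new saving op structure, a partitioned variable slice such as w1/part_0 is restored through a single shared save/RestoreV2 node keyed by its monolithic name (w1), which is why the patch moves GetMonolithicTensorKey ahead of ObtainTensorSlice and relaxes it to pass non-partitioned names through unchanged. Below is a minimal standalone sketch of that key derivation, using plain std::string helpers instead of the TensorFlow Split/Join/StringPiece utilities; the function name MonolithicTensorKey and the main() driver are illustrative only, not part of the patch.

    #include <iostream>
    #include <sstream>
    #include <string>
    #include <vector>

    // Standalone sketch of the GetMonolithicTensorKey() logic: split on '/',
    // drop a trailing "part_N" component if present, and re-join the pieces.
    std::string MonolithicTensorKey(const std::string& tensor_slice_name) {
      std::vector<std::string> names;
      std::stringstream ss(tensor_slice_name);
      for (std::string piece; std::getline(ss, piece, '/');) {
        names.push_back(piece);
      }
      if (!names.empty() && names.back().rfind("part_", 0) == 0) {
        names.pop_back();  // e.g. "w1/part_0" -> {"w1"}
      }
      std::string joined;
      for (size_t i = 0; i < names.size(); ++i) {
        if (i > 0) joined += "/";
        joined += names[i];
      }
      return joined;
    }

    int main() {
      std::cout << MonolithicTensorKey("w1/part_0") << "\n";  // prints "w1"
      std::cout << MonolithicTensorKey("w2") << "\n";         // prints "w2" (unchanged)
      return 0;
    }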
diff --git a/tensorflow/tools/graph_transforms/sparsify_gather.cc b/tensorflow/tools/graph_transforms/sparsify_gather.cc
index 9c583d83ca..214ec721e2 100644
--- a/tensorflow/tools/graph_transforms/sparsify_gather.cc
+++ b/tensorflow/tools/graph_transforms/sparsify_gather.cc
@@ -86,8 +86,17 @@ void CreateConstNode(const Tensor& tensor, const string& name,
SetNodeTensorAttr<float>("value", tensor, node_def);
}
+string GetMonolithicTensorKey(const string& tensor_slice_name) {
+ std::vector<string> names = Split(tensor_slice_name, "/");
+ if (StringPiece(names[names.size() - 1]).starts_with("part_")) {
+ CHECK_GE(names.size(), 2);
+ names.pop_back();
+ }
+ return Join(names, "/");
+}
+
Status ObtainTensorSlice(const GraphDef& input_graph_def,
- const string& tensor_name,
+ const string& target_name,
string* shape_slice_string) {
string restore_node_name;
for (const auto& node : input_graph_def.node()) {
@@ -95,39 +104,53 @@ Status ObtainTensorSlice(const GraphDef& input_graph_def,
if (node_name_parts.size() == 2 &&
StringPiece(node_name_parts[0]).starts_with("save") &&
StringPiece(node_name_parts[1]).starts_with("Assign") &&
- node.input(0) == tensor_name) {
+ node.input(0) == target_name) {
restore_node_name = node.input(1);
break;
}
}
+
+ std::vector<string> restore_node_parts = Split(restore_node_name, ":");
+ CHECK_LE(restore_node_parts.size(), 2);
+ string tensor_names_node;
string shape_and_slices_node;
for (const auto& node : input_graph_def.node()) {
- if ((node.name() == restore_node_name) && (node.op() == "RestoreV2")) {
+ if ((node.name() == restore_node_parts[0]) && (node.op() == "RestoreV2")) {
+ tensor_names_node = node.input(1);
shape_and_slices_node = node.input(2);
break;
}
}
+
+ int offset = -1;
+ for (const auto& node : input_graph_def.node()) {
+ if (node.name() == tensor_names_node) {
+ Tensor tensor_names_tensor;
+ TF_RETURN_IF_ERROR(GetNodeAttr(node, "value", &tensor_names_tensor));
+ const auto& tensor_names_value = tensor_names_tensor.flat<string>();
+ for (int i = 0; i < tensor_names_value.size(); i++) {
+ if (tensor_names_value(i) == GetMonolithicTensorKey(target_name)) {
+ offset = i;
+ break;
+ }
+ }
+ }
+ }
+ if (offset == -1) {
+ return errors::Internal("Unable to find RestoreV2 entry for variable: ",
+ target_name);
+ }
for (const auto& node : input_graph_def.node()) {
if (node.name() == shape_and_slices_node) {
Tensor shape_and_slices_tensor;
TF_RETURN_IF_ERROR(GetNodeAttr(node, "value", &shape_and_slices_tensor));
const auto& shape_and_slices_value =
shape_and_slices_tensor.flat<string>();
- *shape_slice_string = shape_and_slices_value(0);
+ *shape_slice_string = shape_and_slices_value(offset);
return Status::OK();
}
}
- return errors::Internal("Unable to find slice for variable: ", tensor_name);
-}
-
-string GetMonolithicTensorKey(const string& tensor_slice_name) {
- std::vector<string> names = Split(tensor_slice_name, "/");
- CHECK_GE(names.size(), 2);
- CHECK(StringPiece(names[names.size() - 1]).starts_with("part_"));
-
- // Remove the "part_x" suffix
- names.pop_back();
- return Join(names, "/");
+ return errors::Internal("Unable to find slice for variable: ", target_name);
}
Status ReadTensorFromCheckpoint(
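For readers following the hunks above: once the Assign node tied to target_name is traced to a shared save/RestoreV2 node, the reworked ObtainTensorSlice no longer reads element 0 of the shape-and-slices constant. It looks up the monolithic key in the RestoreV2 tensor_names constant and uses that offset to select the matching slice spec. The sketch below is a standalone illustration of that lookup with hypothetical toy data standing in for the graph's Const tensors; it is not TensorFlow code.

    #include <iostream>
    #include <string>
    #include <vector>

    int main() {
      // Hypothetical contents of the shared RestoreV2 inputs: one monolithic
      // name per variable, with a parallel shape-and-slices entry.
      std::vector<std::string> tensor_names = {"w1", "w2"};
      std::vector<std::string> shape_and_slices = {"4 1 0,4:0,1", "4 1 0,4:0,1"};

      // Target slice reduced to its monolithic key: "w2/part_1" -> "w2".
      const std::string monolithic_key = "w2";

      // Find the key's offset in tensor_names, mirroring the loop the patch
      // adds, then index the slice spec with it instead of hard-coding 0.
      int offset = -1;
      for (int i = 0; i < static_cast<int>(tensor_names.size()); ++i) {
        if (tensor_names[i] == monolithic_key) {
          offset = i;
          break;
        }
      }
      if (offset == -1) {
        std::cerr << "Unable to find RestoreV2 entry for variable\n";
        return 1;
      }
      std::cout << shape_and_slices[offset] << "\n";  // prints "4 1 0,4:0,1"
      return 0;
    }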
diff --git a/tensorflow/tools/graph_transforms/sparsify_gather_test.cc b/tensorflow/tools/graph_transforms/sparsify_gather_test.cc
index 203ed3e0f9..d41321c9a6 100644
--- a/tensorflow/tools/graph_transforms/sparsify_gather_test.cc
+++ b/tensorflow/tools/graph_transforms/sparsify_gather_test.cc
@@ -106,11 +106,15 @@ class SparsifyGatherTest : public ::testing::Test {
NodeDef* save_const_node =
CreateNode("save/Const", "Const", {}, &graph_def);
+ Tensor tensor_names_values(DT_STRING, TensorShape({1}));
+ test::FillValues<string>(&tensor_names_values, {"w"});
NodeDef* tensor_names_node =
CreateNode("save/RestoreV2/tensor_names", "Const", {}, &graph_def);
+ SetNodeTensorAttr<string>("value", tensor_names_values,
+ tensor_names_node);
+
NodeDef* tensor_shapes_slices_node = CreateNode(
"save/RestoreV2/shape_and_slices", "Const", {}, &graph_def);
-
Tensor shapes_slices_val(DT_STRING, TensorShape({1}));
shapes_slices_val.flat<string>()(0) = "4 1 0,4:0,1";
SetNodeTensorAttr<string>("value", shapes_slices_val,
@@ -310,6 +314,29 @@ class SparsifyGatherTest : public ::testing::Test {
SetNodeTensorAttr<float>("value", weights, w_node1);
SetNodeTensorAttr<float>("value", weights, w_node2);
} else {
+ NodeDef* save_const_node =
+ CreateNode("save/Const", "Const", {}, &graph_def);
+
+ NodeDef* tensor_names_node =
+ CreateNode("save/RestoreV2/tensor_names", "Const", {}, &graph_def);
+ Tensor tensor_names_values(DT_STRING, TensorShape({2}));
+ test::FillValues<string>(&tensor_names_values, {"w1", "w2"});
+ SetNodeTensorAttr<string>("value", tensor_names_values,
+ tensor_names_node);
+
+ NodeDef* tensor_shapes_slices_node = CreateNode(
+ "save/RestoreV2/shape_and_slices", "Const", {}, &graph_def);
+ Tensor shapes_slices_val(DT_STRING, TensorShape({2}));
+ shapes_slices_val.flat<string>()(0) = "4 1 0,4:0,1";
+ shapes_slices_val.flat<string>()(1) = "4 1 0,4:0,1";
+ SetNodeTensorAttr<string>("value", shapes_slices_val,
+ tensor_shapes_slices_node);
+
+ NodeDef* restore_node = CreateNode(
+ "save/RestoreV2", "RestoreV2",
+ {save_const_node, tensor_names_node, tensor_shapes_slices_node},
+ &graph_def);
+
w_node1 = CreateNode("w1/part_1", "VariableV2", {}, &graph_def);
zeros_shape1 = CreateNode("w1/part_1/Initializer/zeros/shape_as_tensor",
@@ -321,23 +348,7 @@ class SparsifyGatherTest : public ::testing::Test {
assign_node1 = CreateNode("w1/part_1/Assign", "Assign",
{w_node1, zeros_node1}, &graph_def);
- NodeDef* save_const_node =
- CreateNode("save/Const", "Const", {}, &graph_def);
- NodeDef* tensor_names_node1 =
- CreateNode("save/RestoreV2/tensor_names", "Const", {}, &graph_def);
- NodeDef* tensor_shapes_slices_node1 = CreateNode(
- "save/RestoreV2/shape_and_slices", "Const", {}, &graph_def);
-
- Tensor shapes_slices_val1(DT_STRING, TensorShape({1}));
- shapes_slices_val1.flat<string>()(0) = "4 1 0,4:0,1";
- SetNodeTensorAttr<string>("value", shapes_slices_val1,
- tensor_shapes_slices_node1);
-
- NodeDef* restore_node1 = CreateNode(
- "save/RestoreV2", "RestoreV2",
- {save_const_node, tensor_names_node1, tensor_shapes_slices_node1},
- &graph_def);
- CreateNode("save/Assign", "Assign", {w_node1, restore_node1}, &graph_def);
+ CreateNode("save/Assign", "Assign", {w_node1, restore_node}, &graph_def);
w_node2 = CreateNode("w2/part_1", "VariableV2", {}, &graph_def);
zeros_shape2 = CreateNode("w2/part_1/Initializer/zeros/shape_as_tensor",
@@ -349,21 +360,7 @@ class SparsifyGatherTest : public ::testing::Test {
assign_node2 = CreateNode("w2/part_1/Assign", "Assign",
{w_node2, zeros_node2}, &graph_def);
- NodeDef* tensor_names_node2 =
- CreateNode("save/RestoreV2_1/tensor_names", "Const", {}, &graph_def);
- NodeDef* tensor_shapes_slices_node2 = CreateNode(
- "save/RestoreV2_1/shape_and_slices", "Const", {}, &graph_def);
-
- Tensor shapes_slices_val2(DT_STRING, TensorShape({1}));
- shapes_slices_val2.flat<string>()(0) = "4 1 0,4:0,1";
- SetNodeTensorAttr<string>("value", shapes_slices_val2,
- tensor_shapes_slices_node2);
-
- NodeDef* restore_node2 = CreateNode(
- "save/RestoreV2_1", "RestoreV2",
- {save_const_node, tensor_names_node2, tensor_shapes_slices_node2},
- &graph_def);
- CreateNode("save/Assign_1", "Assign", {w_node2, restore_node2},
+ CreateNode("save/Assign_1", "Assign", {w_node2, restore_node},
&graph_def);
BundleWriter writer(Env::Default(), checkpoint_path);