aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/tools/graph_transforms
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-12-16 12:24:36 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-18 10:39:07 -0800
commit2ea90304ffb2cd338b1dfc5a3e26a3373ce1fe98 (patch)
tree8b07d3e571d7ebcf9135a08c5206abb32d9feb64 /tensorflow/tools/graph_transforms
parent04a14ef4f98ffa921095590d7b86490b3d2b19c6 (diff)
Updating sparsify gather to work with core estimators.
PiperOrigin-RevId: 179306398
Diffstat (limited to 'tensorflow/tools/graph_transforms')
-rw-r--r--tensorflow/tools/graph_transforms/sparsify_gather.cc5
1 files changed, 4 insertions, 1 deletions
diff --git a/tensorflow/tools/graph_transforms/sparsify_gather.cc b/tensorflow/tools/graph_transforms/sparsify_gather.cc
index 20d443c7e9..96324d0dea 100644
--- a/tensorflow/tools/graph_transforms/sparsify_gather.cc
+++ b/tensorflow/tools/graph_transforms/sparsify_gather.cc
@@ -89,7 +89,10 @@ Status ObtainTensorSlice(const GraphDef& input_graph_def,
string* shape_slice_string) {
string restore_node_name;
for (const auto& node : input_graph_def.node()) {
- if (StringPiece(node.name()).starts_with("save/Assign") &&
+ std::vector<string> node_name_parts = str_util::Split(node.name(), "/");
+ if (node_name_parts.size() == 2 &&
+ StringPiece(node_name_parts[0]).starts_with("save") &&
+ StringPiece(node_name_parts[1]).starts_with("Assign") &&
node.input(0) == tensor_name) {
restore_node_name = node.input(1);
break;