diff options
Diffstat (limited to 'tensorflow/tools/graph_transforms/sparsify_gather.cc')
-rw-r--r-- | tensorflow/tools/graph_transforms/sparsify_gather.cc | 7 |
1 files changed, 4 insertions, 3 deletions
diff --git a/tensorflow/tools/graph_transforms/sparsify_gather.cc b/tensorflow/tools/graph_transforms/sparsify_gather.cc index 701e350fc3..cc82100148 100644 --- a/tensorflow/tools/graph_transforms/sparsify_gather.cc +++ b/tensorflow/tools/graph_transforms/sparsify_gather.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/graph/subgraph.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/public/session.h" #include "tensorflow/core/util/command_line_flags.h" @@ -88,7 +89,7 @@ void CreateConstNode(const Tensor& tensor, const string& name, string GetMonolithicTensorKey(const string& tensor_slice_name) { std::vector<string> names = Split(tensor_slice_name, "/"); - if (StringPiece(names[names.size() - 1]).starts_with("part_")) { + if (str_util::StartsWith(names[names.size() - 1], "part_")) { CHECK_GE(names.size(), 2); names.pop_back(); } @@ -102,8 +103,8 @@ Status ObtainTensorSlice(const GraphDef& input_graph_def, for (const auto& node : input_graph_def.node()) { std::vector<string> node_name_parts = Split(node.name(), "/"); if (node_name_parts.size() == 2 && - StringPiece(node_name_parts[0]).starts_with("save") && - StringPiece(node_name_parts[1]).starts_with("Assign") && + str_util::StartsWith(node_name_parts[0], "save") && + str_util::StartsWith(node_name_parts[1], "Assign") && node.input(0) == target_name) { restore_node_name = node.input(1); break; |