aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/tools/graph_transforms/sparsify_gather.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/tools/graph_transforms/sparsify_gather.cc')
-rw-r--r--tensorflow/tools/graph_transforms/sparsify_gather.cc7
1 files changed, 4 insertions, 3 deletions
diff --git a/tensorflow/tools/graph_transforms/sparsify_gather.cc b/tensorflow/tools/graph_transforms/sparsify_gather.cc
index 701e350fc3..cc82100148 100644
--- a/tensorflow/tools/graph_transforms/sparsify_gather.cc
+++ b/tensorflow/tools/graph_transforms/sparsify_gather.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/subgraph.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/util/command_line_flags.h"
@@ -88,7 +89,7 @@ void CreateConstNode(const Tensor& tensor, const string& name,
string GetMonolithicTensorKey(const string& tensor_slice_name) {
std::vector<string> names = Split(tensor_slice_name, "/");
- if (StringPiece(names[names.size() - 1]).starts_with("part_")) {
+ if (str_util::StartsWith(names[names.size() - 1], "part_")) {
CHECK_GE(names.size(), 2);
names.pop_back();
}
@@ -102,8 +103,8 @@ Status ObtainTensorSlice(const GraphDef& input_graph_def,
for (const auto& node : input_graph_def.node()) {
std::vector<string> node_name_parts = Split(node.name(), "/");
if (node_name_parts.size() == 2 &&
- StringPiece(node_name_parts[0]).starts_with("save") &&
- StringPiece(node_name_parts[1]).starts_with("Assign") &&
+ str_util::StartsWith(node_name_parts[0], "save") &&
+ str_util::StartsWith(node_name_parts[1], "Assign") &&
node.input(0) == target_name) {
restore_node_name = node.input(1);
break;