aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/aot
diff options
context:
space:
mode:
authorGravatar Justin Lebar <jlebar@google.com>2018-08-23 19:17:40 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-23 19:21:32 -0700
commit9d51d5914084f4cd65ea18c4da45dfb64f2945b6 (patch)
treee7b1cef2df28d83db67f9734d2ad477c89a5fcca /tensorflow/compiler/aot
parent52acfd08111acc6668aa23498688a3346d1c5334 (diff)
Use absl functions instead of str_util within tf2xla.
PiperOrigin-RevId: 210040583
Diffstat (limited to 'tensorflow/compiler/aot')
-rw-r--r--tensorflow/compiler/aot/BUILD2
-rw-r--r--tensorflow/compiler/aot/codegen.cc13
-rw-r--r--tensorflow/compiler/aot/codegen_test.cc6
-rw-r--r--tensorflow/compiler/aot/tests/tfcompile_test.cc1
-rw-r--r--tensorflow/compiler/aot/tfcompile_main.cc7
5 files changed, 16 insertions, 13 deletions
diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD
index 58dc91f74d..59b961cdd9 100644
--- a/tensorflow/compiler/aot/BUILD
+++ b/tensorflow/compiler/aot/BUILD
@@ -72,6 +72,7 @@ tf_cc_test(
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
+ "@com_google_absl//absl/strings",
"@llvm//:support", # fixdeps: keep
"@llvm//:x86_code_gen", # fixdeps: keep
],
@@ -100,6 +101,7 @@ cc_library(
"//tensorflow/core:graph",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
+ "@com_google_absl//absl/strings",
],
)
diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc
index d9edbb2849..e77a8fecf0 100644
--- a/tensorflow/compiler/aot/codegen.cc
+++ b/tensorflow/compiler/aot/codegen.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include <vector>
#include "absl/memory/memory.h"
+#include "absl/strings/str_join.h"
#include "absl/strings/str_replace.h"
#include "tensorflow/compiler/aot/embedded_protocol_buffers.h"
#include "tensorflow/compiler/tf2xla/cpu_function_runtime.h"
@@ -30,7 +31,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
-#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
namespace tensorflow {
@@ -142,7 +142,7 @@ Status AddRewritesForShape(int i, const xla::Shape& shape,
}
rewrites->push_back({"{{I}}", strings::StrCat(i)});
rewrites->push_back({"{{TYPE}}", type});
- rewrites->push_back({"{{DIM_VARS}}", str_util::Join(dim_vars, ", ")});
+ rewrites->push_back({"{{DIM_VARS}}", absl::StrJoin(dim_vars, ", ")});
rewrites->push_back({"{{DIM_SIZES}}", dim_sizes});
rewrites->push_back({"{{INDICES}}", indices});
return Status::OK();
@@ -159,7 +159,8 @@ Status AddRewritesForShape(int i, const xla::Shape& shape,
string RewriteWithName(const string& name, string code,
const std::vector<std::pair<string, string>>& rewrites) {
absl::StrReplaceAll(rewrites, &code);
- return str_util::StringReplace(code, "{{NAME}}", name, /*replace_all=*/true);
+ absl::StrReplaceAll({{"{{NAME}}", name}}, &code);
+ return code;
}
// Generate methods for args (inputs).
@@ -571,11 +572,11 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction {
{"{{ARG_BYTES_TOTAL}}", strings::StrCat(arg_bytes_total)},
{"{{ARG_NAMES_CODE}}", arg_names_code},
{"{{ARG_NUM}}", strings::StrCat(arg_index_table.size())},
- {"{{ARG_INDEX_TABLE}}", str_util::Join(arg_index_table, ", ")},
+ {"{{ARG_INDEX_TABLE}}", absl::StrJoin(arg_index_table, ", ")},
{"{{ASSIGN_PROFILE_COUNTERS_SIZE}}", assign_profile_counters_size},
{"{{CLASS}}", opts.class_name},
{"{{DECLS_FROM_OBJ_FILE}}",
- str_util::Join(metadata_result.header_variable_decls, "\n")},
+ absl::StrJoin(metadata_result.header_variable_decls, "\n")},
{"{{ENTRY}}", compile_result.entry_point},
{"{{HLO_PROFILE_PRINTER_DATA_SHIM_EXPRESSION}}",
metadata_result.hlo_profile_printer_data_access_shim},
@@ -595,7 +596,7 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction {
{"{{TEMP_BYTES_TOTAL}}", strings::StrCat(temp_bytes_total)},
{"{{NUM_BUFFERS}}", strings::StrCat(buffer_infos.size())},
{"{{BUFFER_INFOS_AS_STRING}}",
- str_util::Join(buffer_infos_as_strings, ",\n")}};
+ absl::StrJoin(buffer_infos_as_strings, ",\n")}};
absl::StrReplaceAll(rewrites, header);
return Status::OK();
}
diff --git a/tensorflow/compiler/aot/codegen_test.cc b/tensorflow/compiler/aot/codegen_test.cc
index 60d59ae996..e3a53edb73 100644
--- a/tensorflow/compiler/aot/codegen_test.cc
+++ b/tensorflow/compiler/aot/codegen_test.cc
@@ -18,13 +18,13 @@ limitations under the License.
#include <string>
#include <vector>
+#include "absl/strings/match.h"
#include "llvm/Support/TargetSelect.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/io/path.h"
-#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
@@ -34,9 +34,9 @@ namespace {
using ::tensorflow::cpu_function_runtime::BufferInfo;
-void ExpectErrorContains(const Status& status, StringPiece str) {
+void ExpectErrorContains(const Status& status, absl::string_view str) {
EXPECT_NE(Status::OK(), status);
- EXPECT_TRUE(str_util::StrContains(status.error_message(), str))
+ EXPECT_TRUE(absl::StrContains(status.error_message(), str))
<< "expected error: " << status.error_message() << " to contain: " << str;
}
diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc
index a2b824da17..dd2b151098 100644
--- a/tensorflow/compiler/aot/tests/tfcompile_test.cc
+++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc
@@ -33,7 +33,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
diff --git a/tensorflow/compiler/aot/tfcompile_main.cc b/tensorflow/compiler/aot/tfcompile_main.cc
index 839e1588b7..f3c44e9dda 100644
--- a/tensorflow/compiler/aot/tfcompile_main.cc
+++ b/tensorflow/compiler/aot/tfcompile_main.cc
@@ -18,6 +18,8 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/strings/match.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/aot/codegen.h"
#include "tensorflow/compiler/aot/compile.h"
#include "tensorflow/compiler/aot/flags.h"
@@ -34,7 +36,6 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/numbers.h"
-#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
@@ -55,7 +56,7 @@ const char kUsageHeader[] =
"\n";
Status ReadProtoFile(const string& fname, protobuf::Message* proto) {
- if (str_util::EndsWith(fname, ".pbtxt")) {
+ if (absl::EndsWith(fname, ".pbtxt")) {
return ReadTextProto(Env::Default(), fname, proto);
} else {
return ReadBinaryProto(Env::Default(), fname, proto);
@@ -75,7 +76,7 @@ Status Main(const MainFlags& flags) {
for (const tf2xla::Fetch& fetch : config.fetch()) {
nodes.insert(fetch.id().node_name());
}
- std::cout << str_util::Join(nodes, ",");
+ std::cout << absl::StrJoin(nodes, ",");
return Status::OK();
}