diff options
author | 2018-08-23 19:17:40 -0700 | |
---|---|---|
committer | 2018-08-23 19:21:32 -0700 | |
commit | 9d51d5914084f4cd65ea18c4da45dfb64f2945b6 (patch) | |
tree | e7b1cef2df28d83db67f9734d2ad477c89a5fcca /tensorflow/compiler/aot | |
parent | 52acfd08111acc6668aa23498688a3346d1c5334 (diff) |
Use absl functions instead of str_util within tf2xla.
PiperOrigin-RevId: 210040583
Diffstat (limited to 'tensorflow/compiler/aot')
-rw-r--r-- | tensorflow/compiler/aot/BUILD | 2 | ||||
-rw-r--r-- | tensorflow/compiler/aot/codegen.cc | 13 | ||||
-rw-r--r-- | tensorflow/compiler/aot/codegen_test.cc | 6 | ||||
-rw-r--r-- | tensorflow/compiler/aot/tests/tfcompile_test.cc | 1 | ||||
-rw-r--r-- | tensorflow/compiler/aot/tfcompile_main.cc | 7 |
5 files changed, 16 insertions, 13 deletions
diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD index 58dc91f74d..59b961cdd9 100644 --- a/tensorflow/compiler/aot/BUILD +++ b/tensorflow/compiler/aot/BUILD @@ -72,6 +72,7 @@ tf_cc_test( "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/strings", "@llvm//:support", # fixdeps: keep "@llvm//:x86_code_gen", # fixdeps: keep ], @@ -100,6 +101,7 @@ cc_library( "//tensorflow/core:graph", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc index d9edbb2849..e77a8fecf0 100644 --- a/tensorflow/compiler/aot/codegen.cc +++ b/tensorflow/compiler/aot/codegen.cc @@ -20,6 +20,7 @@ limitations under the License. #include <vector> #include "absl/memory/memory.h" +#include "absl/strings/str_join.h" #include "absl/strings/str_replace.h" #include "tensorflow/compiler/aot/embedded_protocol_buffers.h" #include "tensorflow/compiler/tf2xla/cpu_function_runtime.h" @@ -30,7 +31,6 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" namespace tensorflow { @@ -142,7 +142,7 @@ Status AddRewritesForShape(int i, const xla::Shape& shape, } rewrites->push_back({"{{I}}", strings::StrCat(i)}); rewrites->push_back({"{{TYPE}}", type}); - rewrites->push_back({"{{DIM_VARS}}", str_util::Join(dim_vars, ", ")}); + rewrites->push_back({"{{DIM_VARS}}", absl::StrJoin(dim_vars, ", ")}); rewrites->push_back({"{{DIM_SIZES}}", dim_sizes}); rewrites->push_back({"{{INDICES}}", indices}); return Status::OK(); @@ -159,7 +159,8 @@ Status AddRewritesForShape(int i, const xla::Shape& shape, string RewriteWithName(const string& name, string code, const std::vector<std::pair<string, string>>& rewrites) { absl::StrReplaceAll(rewrites, &code); - return str_util::StringReplace(code, "{{NAME}}", name, /*replace_all=*/true); + absl::StrReplaceAll({{"{{NAME}}", name}}, &code); + return code; } // Generate methods for args (inputs). @@ -571,11 +572,11 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { {"{{ARG_BYTES_TOTAL}}", strings::StrCat(arg_bytes_total)}, {"{{ARG_NAMES_CODE}}", arg_names_code}, {"{{ARG_NUM}}", strings::StrCat(arg_index_table.size())}, - {"{{ARG_INDEX_TABLE}}", str_util::Join(arg_index_table, ", ")}, + {"{{ARG_INDEX_TABLE}}", absl::StrJoin(arg_index_table, ", ")}, {"{{ASSIGN_PROFILE_COUNTERS_SIZE}}", assign_profile_counters_size}, {"{{CLASS}}", opts.class_name}, {"{{DECLS_FROM_OBJ_FILE}}", - str_util::Join(metadata_result.header_variable_decls, "\n")}, + absl::StrJoin(metadata_result.header_variable_decls, "\n")}, {"{{ENTRY}}", compile_result.entry_point}, {"{{HLO_PROFILE_PRINTER_DATA_SHIM_EXPRESSION}}", metadata_result.hlo_profile_printer_data_access_shim}, @@ -595,7 +596,7 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { {"{{TEMP_BYTES_TOTAL}}", strings::StrCat(temp_bytes_total)}, {"{{NUM_BUFFERS}}", strings::StrCat(buffer_infos.size())}, {"{{BUFFER_INFOS_AS_STRING}}", - str_util::Join(buffer_infos_as_strings, ",\n")}}; + absl::StrJoin(buffer_infos_as_strings, ",\n")}}; absl::StrReplaceAll(rewrites, header); return Status::OK(); } diff --git a/tensorflow/compiler/aot/codegen_test.cc b/tensorflow/compiler/aot/codegen_test.cc index 60d59ae996..e3a53edb73 100644 --- a/tensorflow/compiler/aot/codegen_test.cc +++ b/tensorflow/compiler/aot/codegen_test.cc @@ -18,13 +18,13 @@ limitations under the License. #include <string> #include <vector> +#include "absl/strings/match.h" #include "llvm/Support/TargetSelect.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/test.h" @@ -34,9 +34,9 @@ namespace { using ::tensorflow::cpu_function_runtime::BufferInfo; -void ExpectErrorContains(const Status& status, StringPiece str) { +void ExpectErrorContains(const Status& status, absl::string_view str) { EXPECT_NE(Status::OK(), status); - EXPECT_TRUE(str_util::StrContains(status.error_message(), str)) + EXPECT_TRUE(absl::StrContains(status.error_message(), str)) << "expected error: " << status.error_message() << " to contain: " << str; } diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc index a2b824da17..dd2b151098 100644 --- a/tensorflow/compiler/aot/tests/tfcompile_test.cc +++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc @@ -33,7 +33,6 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { diff --git a/tensorflow/compiler/aot/tfcompile_main.cc b/tensorflow/compiler/aot/tfcompile_main.cc index 839e1588b7..f3c44e9dda 100644 --- a/tensorflow/compiler/aot/tfcompile_main.cc +++ b/tensorflow/compiler/aot/tfcompile_main.cc @@ -18,6 +18,8 @@ limitations under the License. #include <utility> #include <vector> +#include "absl/strings/match.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/aot/codegen.h" #include "tensorflow/compiler/aot/compile.h" #include "tensorflow/compiler/aot/flags.h" @@ -34,7 +36,6 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/strings/numbers.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" @@ -55,7 +56,7 @@ const char kUsageHeader[] = "\n"; Status ReadProtoFile(const string& fname, protobuf::Message* proto) { - if (str_util::EndsWith(fname, ".pbtxt")) { + if (absl::EndsWith(fname, ".pbtxt")) { return ReadTextProto(Env::Default(), fname, proto); } else { return ReadBinaryProto(Env::Default(), fname, proto); @@ -75,7 +76,7 @@ Status Main(const MainFlags& flags) { for (const tf2xla::Fetch& fetch : config.fetch()) { nodes.insert(fetch.id().node_name()); } - std::cout << str_util::Join(nodes, ","); + std::cout << absl::StrJoin(nodes, ","); return Status::OK(); } |