diff options
author | Benjamin Kramer <kramerb@google.com> | 2018-09-05 13:50:20 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-05 13:54:16 -0700 |
commit | 11caab3c138d06390344c88a4149f1897e3d780d (patch) | |
tree | 33aac05bfa4fdf6cd81998232268c79000f0d4d0 /tensorflow/compiler/jit | |
parent | c9c8de440213355ea4a4d3577fd068d418678d38 (diff) |
[XLA] Make tensorflow/compiler use absl::{StrCat,string_view,InlinedVector} consistently
StringPiece is an alias for absl::string_view, InlinedVector is aliased to absl::InlinedVector. StrCat is compatible, so swapping it out is safe.
PiperOrigin-RevId: 211691840
Diffstat (limited to 'tensorflow/compiler/jit')
18 files changed, 148 insertions, 145 deletions
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index df81f3c23e..de7cd26d1d 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -410,6 +410,7 @@ cc_library( "//tensorflow/core:graph", "//tensorflow/core:protos_all_cc", "//tensorflow/core/kernels:bounds_check", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", ], ) @@ -566,6 +567,7 @@ cc_library( "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler/optimizers:custom_graph_optimizer", "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", + "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/compiler/jit/deadness_analysis.cc b/tensorflow/compiler/jit/deadness_analysis.cc index 82aa03810b..9128b48da3 100644 --- a/tensorflow/compiler/jit/deadness_analysis.cc +++ b/tensorflow/compiler/jit/deadness_analysis.cc @@ -154,7 +154,7 @@ class AndPredicate : public Predicate { std::back_inserter(operands_str), [](Predicate* pred) { return pred->ToString(); }); - return strings::StrCat("(", absl::StrJoin(operands_str, " & "), ")"); + return absl::StrCat("(", absl::StrJoin(operands_str, " & "), ")"); } Kind kind() const override { return Kind::kAnd; } @@ -185,7 +185,7 @@ class OrPredicate : public Predicate { std::back_inserter(operands_str), [](Predicate* pred) { return pred->ToString(); }); - return strings::StrCat("(", absl::StrJoin(operands_str, " | "), ")"); + return absl::StrCat("(", absl::StrJoin(operands_str, " | "), ")"); } Kind kind() const override { return Kind::kOr; } @@ -206,7 +206,7 @@ class NotPredicate : public Predicate { operands_({operand}) {} string ToString() const override { - return strings::StrCat("~", operand()->ToString()); + return absl::StrCat("~", operand()->ToString()); } Kind kind() const override { return Kind::kNot; } @@ -240,8 +240,8 @@ class AndRecurrencePredicate : public Predicate { Predicate* step() const { return operands_[1]; } string ToString() const override { - return strings::StrCat("{", start()->ToString(), ",&,", step()->ToString(), - "}"); + return absl::StrCat("{", start()->ToString(), ",&,", step()->ToString(), + "}"); } Kind kind() const override { return Kind::kAndRecurrence; } @@ -267,7 +267,7 @@ class SymbolPredicate : public Predicate { must_be_true_(must_be_true) {} string ToString() const override { - return must_be_true() ? strings::StrCat("*", tensor_id_.ToString()) + return must_be_true() ? absl::StrCat("*", tensor_id_.ToString()) : tensor_id_.ToString(); } diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index 2788102620..ae7a22f451 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -22,6 +22,7 @@ limitations under the License. #include <unordered_map> #include <vector> +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" #include "tensorflow/compiler/jit/mark_for_compilation_pass.h" #include "tensorflow/compiler/jit/shape_inference_helpers.h" @@ -45,7 +46,6 @@ limitations under the License. #include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/hash/hash.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/public/session_options.h" #include "tensorflow/core/public/version.h" #include "tensorflow/core/util/device_name_utils.h" @@ -755,7 +755,7 @@ Status Encapsulator::Subgraph::RecordArg( if (inserted) { NodeDef arg_def; NodeDefBuilder builder( - strings::StrCat(src_node->name(), "_", src_slot, "_arg"), kArgOp); + absl::StrCat(src_node->name(), "_", src_slot, "_arg"), kArgOp); DataType dtype = edge->dst()->input_type(edge->dst_input()); builder.Attr("T", dtype); builder.Attr("index", arg_index); @@ -790,7 +790,7 @@ Status Encapsulator::Subgraph::RecordResult( if (inserted) { NodeDef ret_def; NodeDefBuilder builder( - strings::StrCat(src_node->name(), "_", src_slot, "_retval"), kRetValOp); + absl::StrCat(src_node->name(), "_", src_slot, "_retval"), kRetValOp); DataType dtype = src_node->output_type(src_slot); builder.Attr("T", dtype); builder.Attr("index", ret_index); @@ -950,16 +950,15 @@ Status Encapsulator::Subgraph::AddHostComputes( } NodeDef host_compute_def; - NodeDefBuilder builder(strings::StrCat("outside_compilation_", - oc_subgraph_name, "_host_compute"), + NodeDefBuilder builder(absl::StrCat("outside_compilation_", + oc_subgraph_name, "_host_compute"), kHostComputeOp); builder.Input(inputs); builder.Attr("Tinputs", input_dtypes); builder.Attr("Toutputs", output_dtypes); builder.Attr("ancestors", host_compute_ancestors); - builder.Attr("key", - strings::StrCat("host_compute_channel_", subgraph_name, "_", - oc_subgraph_name)); + builder.Attr("key", absl::StrCat("host_compute_channel_", subgraph_name, + "_", oc_subgraph_name)); builder.Attr("_outside_compilation_subgraph", oc_subgraph_name); Status s = builder.Finalize(&host_compute_def); if (!s.ok()) return s; @@ -1017,8 +1016,7 @@ Status Encapsulator::Subgraph::MakeSequencingNode(const string& subgraph_name, Graph* graph_out) { if (sequencer_ == nullptr) { NodeDef seq_def; - NodeDefBuilder builder(strings::StrCat(subgraph_name, "_sequencer"), - "NoOp"); + NodeDefBuilder builder(absl::StrCat(subgraph_name, "_sequencer"), "NoOp"); builder.Attr(kXlaHostTransferSequencerAttr, subgraph_name); builder.Device(device_); Status s = builder.Finalize(&seq_def); @@ -1091,10 +1089,10 @@ Status Encapsulator::Subgraph::BuildFunctionDef( if (VLOG_IS_ON(1)) { VLOG(2) << "Build function def " << name; - dump_graph::DumpGraphToFile( - strings::StrCat("encapsulate_fdef_graph_", name), *graph_, library); - dump_graph::DumpFunctionDefToFile( - strings::StrCat("encapsulate_fdef_", name), fdef); + dump_graph::DumpGraphToFile(absl::StrCat("encapsulate_fdef_graph_", name), + *graph_, library); + dump_graph::DumpFunctionDefToFile(absl::StrCat("encapsulate_fdef_", name), + fdef); } if (!reuse_existing_functions || library->Find(name) == nullptr) { @@ -1130,8 +1128,8 @@ Status Encapsulator::Subgraph::AddShapeInferenceInfo( host_compute->AddAttr("shapes", shapes); } else { string inference_graph_name = - strings::StrCat("_outside_compilation_shape_inference_", subgraph_name, - "_", outside_compilation_subgraph_name); + absl::StrCat("_outside_compilation_shape_inference_", subgraph_name, + "_", outside_compilation_subgraph_name); FunctionDef fdef; TF_RETURN_IF_ERROR( GraphToFunctionDef(*inference_graph, inference_graph_name, &fdef)); @@ -1155,10 +1153,10 @@ Status Encapsulator::Subgraph::ReplaceFunctionDef( if (VLOG_IS_ON(1)) { VLOG(2) << "Replace function def " << name; dump_graph::DumpGraphToFile( - strings::StrCat("replace_encapsulate_fdef_graph_", name), *graph_, + absl::StrCat("replace_encapsulate_fdef_graph_", name), *graph_, library); dump_graph::DumpFunctionDefToFile( - strings::StrCat("replace_encapsulate_fdef_", name), fdef); + absl::StrCat("replace_encapsulate_fdef_", name), fdef); } TF_RETURN_IF_ERROR(library->ReplaceFunction(name, fdef)); @@ -1186,8 +1184,7 @@ Status Encapsulator::Subgraph::AddHostComputeKeyPlaceholder( GraphDefBuilder::Options options(graph_out, /*status=*/nullptr); NodeDef key_def; NodeDefBuilder builder( - strings::StrCat(call_node_def_.name(), "_key_placeholder"), - "Placeholder"); + absl::StrCat(call_node_def_.name(), "_key_placeholder"), "Placeholder"); builder.Attr("dtype", DT_STRING); builder.Attr("shape", shape_proto); builder.Attr("_host_compute_call_node", call_node_def_.name()); @@ -1221,16 +1218,16 @@ Status Encapsulator::Subgraph::AddRecvAtHostNode( } NodeDef recv_def; - NodeDefBuilder builder(strings::StrCat("outside_compilation_", subgraph_name, - "_", oc_subgraph_name, "_recv"), + NodeDefBuilder builder(absl::StrCat("outside_compilation_", subgraph_name, + "_", oc_subgraph_name, "_recv"), kRecvAtHostOp); builder.Device(device_); builder.Attr("Toutputs", dtypes); // The correct device_ordinal will be inserted during replication in a // subsequent rewrite. builder.Attr("device_ordinal", 0); - builder.Attr("key", strings::StrCat("host_compute_channel_", subgraph_name, - "_", oc_subgraph_name)); + builder.Attr("key", absl::StrCat("host_compute_channel_", subgraph_name, "_", + oc_subgraph_name)); builder.Attr(group_attribute, subgraph_name); builder.Attr(outside_compilation_attribute, oc_subgraph_name); builder.Input(host_compute_key_placeholder_->name(), 0, DT_STRING); @@ -1276,13 +1273,13 @@ Status Encapsulator::Subgraph::AddSendFromHostNode( } NodeDef send_def; - NodeDefBuilder builder(strings::StrCat("outside_compilation_", subgraph_name, - "_", oc_subgraph_name, "_send"), + NodeDefBuilder builder(absl::StrCat("outside_compilation_", subgraph_name, + "_", oc_subgraph_name, "_send"), kSendFromHostOp); builder.Device(device_); builder.Attr("Tinputs", dtypes); - builder.Attr("key", strings::StrCat("host_compute_channel_", subgraph_name, - "_", oc_subgraph_name)); + builder.Attr("key", absl::StrCat("host_compute_channel_", subgraph_name, "_", + oc_subgraph_name)); // The correct device_ordinal will be inserted during replication in a // subsequent rewrite. builder.Attr("device_ordinal", 0); @@ -1516,7 +1513,7 @@ Status Encapsulator::SplitIntoSubgraphs(FunctionLibraryDefinition* library) { // Dump subgraphs. for (auto& entry : subgraphs_) { dump_graph::DumpGraphToFile( - strings::StrCat("encapsulate_subgraphs_subgraph_", entry.first), + absl::StrCat("encapsulate_subgraphs_subgraph_", entry.first), *entry.second.GetGraph(), library); } } @@ -2052,7 +2049,7 @@ struct PathDetails { struct SubgraphAndClusterHash { inline std::size_t operator()(const SubgraphAndCluster& v) const { return hash<string>()( - strings::StrCat(v.subgraph, v.outside_compilation_cluster)); + absl::StrCat(v.subgraph, v.outside_compilation_cluster)); } }; diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc index 7bc0ef0303..49958093b8 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include <memory> #include <utility> +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" #include "absl/strings/match.h" @@ -48,7 +49,7 @@ Status AddGraphDefToFunctionLibrary(const GraphDefBuilder& graphdef_builder, FunctionDef* fdef = library->add_function(); TF_RETURN_IF_ERROR(GraphToFunctionDef( *graph, - strings::StrCat("_outside_compilation_shape_inference_", name_suffix), + absl::StrCat("_outside_compilation_shape_inference_", name_suffix), fdef)); return Status::OK(); } @@ -65,18 +66,18 @@ bool EqualProtoMap(const ::tensorflow::protobuf::Map<Tkey, Tvalue>& a, const auto iter = b.find(elt_a.first); if (iter == b.end()) { if (diff) { - *diff = strings::StrCat( - map_name, " expected: contains element with key '", - key_to_string(elt_a.first), "' got: map has no such element"); + *diff = absl::StrCat(map_name, " expected: contains element with key '", + key_to_string(elt_a.first), + "' got: map has no such element"); } return false; } if (!compare(elt_a.first, elt_a.second, iter->second)) { if (diff) { - *diff = strings::StrCat(map_name, " expected: element with key '", - key_to_string(elt_a.first), "' has value '", - value_to_string(elt_a.second), "' got: '", - value_to_string(iter->second), "'"); + *diff = absl::StrCat(map_name, " expected: element with key '", + key_to_string(elt_a.first), "' has value '", + value_to_string(elt_a.second), "' got: '", + value_to_string(iter->second), "'"); } return false; } @@ -85,9 +86,9 @@ bool EqualProtoMap(const ::tensorflow::protobuf::Map<Tkey, Tvalue>& a, const auto iter = a.find(elt_b.first); if (iter == a.end()) { if (diff) { - *diff = strings::StrCat(map_name, " got: contains element with key '", - key_to_string(elt_b.first), - "' expected: map has no such element"); + *diff = absl::StrCat(map_name, " got: contains element with key '", + key_to_string(elt_b.first), + "' expected: map has no such element"); } return false; } @@ -99,25 +100,25 @@ bool EqualFunctionNodeDef(const NodeDef& a, const NodeDef& b, const string& diff_preamble, string* diff) { if (a.op() != b.op()) { if (diff) { - *diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(), - ", expected op '", a.op(), "' got '", b.op()); + *diff = absl::StrCat(diff_preamble, " mismatch for node ", a.name(), + ", expected op '", a.op(), "' got '", b.op()); } return false; } if (a.device() != b.device()) { if (diff) { - *diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(), - ", expected device '", a.device(), "' got '", - b.device()); + *diff = absl::StrCat(diff_preamble, " mismatch for node ", a.name(), + ", expected device '", a.device(), "' got '", + b.device()); } return false; } if (a.input_size() != b.input_size()) { if (diff) { - *diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(), - ", expected ", a.input_size(), " inputs got ", - b.input_size(), " expected:\n", a.DebugString(), - "\ngot:\n", b.DebugString()); + *diff = absl::StrCat(diff_preamble, " mismatch for node ", a.name(), + ", expected ", a.input_size(), " inputs got ", + b.input_size(), " expected:\n", a.DebugString(), + "\ngot:\n", b.DebugString()); } return false; } @@ -127,10 +128,10 @@ bool EqualFunctionNodeDef(const NodeDef& a, const NodeDef& b, if (absl::StartsWith(a.input(i), "^")) { if (!absl::StartsWith(b.input(i), "^")) { if (diff) { - *diff = strings::StrCat( - diff_preamble, " mismatch for node ", a.name(), " input ", i, - ", expected control input ", a.input(i), " got ", b.input(i), - " expected:\n", a.DebugString(), "\ngot:\n", b.DebugString()); + *diff = absl::StrCat(diff_preamble, " mismatch for node ", a.name(), + " input ", i, ", expected control input ", + a.input(i), " got ", b.input(i), " expected:\n", + a.DebugString(), "\ngot:\n", b.DebugString()); } return false; } @@ -138,19 +139,19 @@ bool EqualFunctionNodeDef(const NodeDef& a, const NodeDef& b, control_input_b.insert(b.input(i)); } else if (a.input(i) != b.input(i)) { if (diff) { - *diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(), - " input ", i, ", expected ", a.input(i), - " got ", b.input(i), " expected:\n", - a.DebugString(), "\ngot:\n", b.DebugString()); + *diff = absl::StrCat(diff_preamble, " mismatch for node ", a.name(), + " input ", i, ", expected ", a.input(i), " got ", + b.input(i), " expected:\n", a.DebugString(), + "\ngot:\n", b.DebugString()); } return false; } } if (control_input_a != control_input_b) { if (diff) { - *diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(), - " control inputs differ expected:\n", - a.DebugString(), "\ngot:\n", b.DebugString()); + *diff = absl::StrCat(diff_preamble, " mismatch for node ", a.name(), + " control inputs differ expected:\n", + a.DebugString(), "\ngot:\n", b.DebugString()); } return false; } @@ -170,18 +171,17 @@ bool EqualFunctionNodeDef(const NodeDef& a, const NodeDef& b, return av.DebugString() == bv.DebugString(); } }, - strings::StrCat(diff_preamble, " attr mismatch for node ", a.name()), - diff); + absl::StrCat(diff_preamble, " attr mismatch for node ", a.name()), diff); } bool EqualFunctionDef(const FunctionDef& a, const FunctionDef& b, string* diff) { if (a.signature().DebugString() != b.signature().DebugString()) { if (diff) { - *diff = strings::StrCat("Signature mismatch for function ", - a.signature().name(), ", expected:\n", - a.signature().DebugString(), "\ngot:\n", - b.signature().DebugString()); + *diff = + absl::StrCat("Signature mismatch for function ", a.signature().name(), + ", expected:\n", a.signature().DebugString(), "\ngot:\n", + b.signature().DebugString()); } return false; } @@ -191,7 +191,7 @@ bool EqualFunctionDef(const FunctionDef& a, const FunctionDef& b, [](const string& key, const AttrValue& av, const AttrValue& bv) { return av.DebugString() == bv.DebugString(); }, - strings::StrCat("attr mismatch for function ", a.signature().name()), + absl::StrCat("attr mismatch for function ", a.signature().name()), diff)) { return false; } @@ -201,7 +201,7 @@ bool EqualFunctionDef(const FunctionDef& a, const FunctionDef& b, [](const string& key, const string& av, const string& bv) { return av == bv; }, - strings::StrCat("ret mismatch for function ", a.signature().name()), + absl::StrCat("ret mismatch for function ", a.signature().name()), diff)) { return false; } @@ -211,7 +211,7 @@ bool EqualFunctionDef(const FunctionDef& a, const FunctionDef& b, if (a.node_def(i).name() == b.node_def(j).name()) { if (!EqualFunctionNodeDef( a.node_def(i), b.node_def(j), - strings::StrCat("Function ", a.signature().name()), diff)) { + absl::StrCat("Function ", a.signature().name()), diff)) { return false; } found = true; @@ -220,9 +220,9 @@ bool EqualFunctionDef(const FunctionDef& a, const FunctionDef& b, } if (!found) { if (diff) { - *diff = strings::StrCat("Function ", a.signature().name(), - ", expected: has node '", a.node_def(i).name(), - "' got: no node of that name"); + *diff = absl::StrCat("Function ", a.signature().name(), + ", expected: has node '", a.node_def(i).name(), + "' got: no node of that name"); } return false; } @@ -237,9 +237,9 @@ bool EqualFunctionDef(const FunctionDef& a, const FunctionDef& b, } if (!found) { if (diff) { - *diff = strings::StrCat("Function ", a.signature().name(), - ", got: has node '", b.node_def(i).name(), - "' expected: no node of that name"); + *diff = absl::StrCat("Function ", a.signature().name(), + ", got: has node '", b.node_def(i).name(), + "' expected: no node of that name"); } return false; } @@ -258,8 +258,8 @@ bool EqualFunctionDefLibrary(const FunctionDefLibrary& expected, auto it = actual_index.find(expected_function.signature().name()); if (it == actual_index.end()) { if (diff) { - *diff = strings::StrCat("Did not find expected function '", - expected_function.signature().name(), "'"); + *diff = absl::StrCat("Did not find expected function '", + expected_function.signature().name(), "'"); } return false; } @@ -269,9 +269,9 @@ bool EqualFunctionDefLibrary(const FunctionDefLibrary& expected, if (!actual_index.empty()) { if (diff != nullptr) { - *diff = strings::StrCat("Found unexpected function '", - actual_index.begin()->second->signature().name(), - "'"); + *diff = + absl::StrCat("Found unexpected function '", + actual_index.begin()->second->signature().name(), "'"); } return false; } @@ -420,10 +420,9 @@ Node* RecvAtHost(ops::NodeOut key_input, const string& cluster, const string& oc_cluster, absl::Span<const DataType> dtypes, const GraphDefBuilder::Options& opts) { if (opts.HaveError()) return nullptr; - string key = - strings::StrCat("host_compute_channel_", cluster, "_", oc_cluster); - string name = strings::StrCat("outside_compilation_", cluster, "_", - oc_cluster, "_recv"); + string key = absl::StrCat("host_compute_channel_", cluster, "_", oc_cluster); + string name = + absl::StrCat("outside_compilation_", cluster, "_", oc_cluster, "_recv"); NodeBuilder node_builder(opts.WithName(name).GetNameForOp("_XlaRecvAtHost"), "_XlaRecvAtHost", opts.op_registry()); node_builder.Input(std::move(key_input)); @@ -440,10 +439,9 @@ Node* SendFromHost(ops::NodeOut key_input, const string& cluster, const std::vector<ops::NodeOut>& inputs, const GraphDefBuilder::Options& opts) { if (opts.HaveError()) return nullptr; - string key = - strings::StrCat("host_compute_channel_", cluster, "_", oc_cluster); - string name = strings::StrCat("outside_compilation_", cluster, "_", - oc_cluster, "_send"); + string key = absl::StrCat("host_compute_channel_", cluster, "_", oc_cluster); + string name = + absl::StrCat("outside_compilation_", cluster, "_", oc_cluster, "_send"); NodeBuilder node_builder(opts.WithName(name).GetNameForOp("_XlaSendFromHost"), "_XlaSendFromHost", opts.op_registry()); node_builder.Input(inputs); @@ -682,8 +680,8 @@ std::vector<std::pair<string, string>> GraphEdges(const Graph& graph) { for (const Edge* edge : graph.edges()) { if (edge->src()->IsSource() || edge->dst()->IsSink()) continue; edges.emplace_back( - strings::StrCat(edge->src()->name(), ":", edge->src_output()), - strings::StrCat(edge->dst()->name(), ":", edge->dst_input())); + absl::StrCat(edge->src()->name(), ":", edge->src_output()), + absl::StrCat(edge->dst()->name(), ":", edge->dst_input())); } std::sort(edges.begin(), edges.end()); return edges; diff --git a/tensorflow/compiler/jit/graphcycles/BUILD b/tensorflow/compiler/jit/graphcycles/BUILD index 676f71a75a..8212956adf 100644 --- a/tensorflow/compiler/jit/graphcycles/BUILD +++ b/tensorflow/compiler/jit/graphcycles/BUILD @@ -14,6 +14,7 @@ cc_library( hdrs = ["graphcycles.h"], deps = [ "//tensorflow/core:lib", + "@com_google_absl//absl/container:inlined_vector", ], ) diff --git a/tensorflow/compiler/jit/graphcycles/graphcycles.cc b/tensorflow/compiler/jit/graphcycles/graphcycles.cc index 805bbc62c1..756377bd95 100644 --- a/tensorflow/compiler/jit/graphcycles/graphcycles.cc +++ b/tensorflow/compiler/jit/graphcycles/graphcycles.cc @@ -34,7 +34,7 @@ limitations under the License. #include <algorithm> #include <unordered_set> -#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "absl/container/inlined_vector.h" #include "tensorflow/core/platform/logging.h" namespace tensorflow { @@ -44,7 +44,7 @@ namespace { typedef std::unordered_set<int32> NodeSet; template <typename T> struct VecStruct { - typedef gtl::InlinedVector<T, 4> type; + typedef absl::InlinedVector<T, 4> type; }; template <typename T> using Vec = typename VecStruct<T>::type; diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 4e4abade32..44caf0be52 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -43,7 +43,6 @@ limitations under the License. #include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/gtl/flatset.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/public/version.h" @@ -617,7 +616,7 @@ Status MarkForCompilationPass::Run( } static string RatioToString(int numerator, int denominator) { - return strings::Printf("%d / %d (%.2f%%)", numerator, denominator, + return absl::StrFormat("%d / %d (%.2f%%)", numerator, denominator, (100.0 * numerator) / denominator); } @@ -626,14 +625,14 @@ static void VLogClusteringSummary(const Graph& g) { return; } - std::map<StringPiece, int> cluster_name_to_size; - std::map<StringPiece, std::map<StringPiece, int>> + std::map<absl::string_view, int> cluster_name_to_size; + std::map<absl::string_view, std::map<absl::string_view, int>> cluster_name_to_op_histogram; - std::map<StringPiece, int> unclustered_op_histogram; + std::map<absl::string_view, int> unclustered_op_histogram; int clustered_node_count = 0; for (Node* n : g.nodes()) { - absl::optional<StringPiece> cluster_name = GetXlaClusterForNode(*n); + absl::optional<absl::string_view> cluster_name = GetXlaClusterForNode(*n); if (cluster_name) { clustered_node_count++; cluster_name_to_size[*cluster_name]++; @@ -650,7 +649,7 @@ static void VLogClusteringSummary(const Graph& g) { << RatioToString(clustered_node_count, g.num_nodes()); for (const auto& cluster_name_size_pair : cluster_name_to_size) { - StringPiece cluster_name = cluster_name_size_pair.first; + absl::string_view cluster_name = cluster_name_size_pair.first; int size = cluster_name_size_pair.second; VLOG(2) << " " << cluster_name << " " << RatioToString(size, g.num_nodes()); @@ -670,14 +669,15 @@ static void VLogClusteringSummary(const Graph& g) { } struct EdgeInfo { - StringPiece node_name; - absl::optional<StringPiece> cluster_name; + absl::string_view node_name; + absl::optional<absl::string_view> cluster_name; - StringPiece GetClusterName() const { + absl::string_view GetClusterName() const { return cluster_name ? *cluster_name : "[none]"; } - std::pair<StringPiece, absl::optional<StringPiece>> AsPair() const { + std::pair<absl::string_view, absl::optional<absl::string_view>> AsPair() + const { return {node_name, cluster_name}; } @@ -686,19 +686,21 @@ static void VLogClusteringSummary(const Graph& g) { } }; - using EdgeInfoMap = std::map<StringPiece, std::map<EdgeInfo, int64>>; + using EdgeInfoMap = std::map<absl::string_view, std::map<EdgeInfo, int64>>; EdgeInfoMap incoming_edge_infos; EdgeInfoMap outgoing_edge_infos; - std::set<StringPiece> cluster_names_to_print; + std::set<absl::string_view> cluster_names_to_print; for (const Edge* e : g.edges()) { const Node* from = e->src(); - absl::optional<StringPiece> from_cluster_name = GetXlaClusterForNode(*from); + absl::optional<absl::string_view> from_cluster_name = + GetXlaClusterForNode(*from); const Node* to = e->dst(); - absl::optional<StringPiece> to_cluster_name = GetXlaClusterForNode(*to); + absl::optional<absl::string_view> to_cluster_name = + GetXlaClusterForNode(*to); if (to_cluster_name == from_cluster_name) { continue; @@ -721,9 +723,9 @@ static void VLogClusteringSummary(const Graph& g) { VLOG(2) << " [none]"; } - auto print_edge_info_set_for_cluster = [&](StringPiece cluster_name, + auto print_edge_info_set_for_cluster = [&](absl::string_view cluster_name, const EdgeInfoMap& edge_info_map, - StringPiece desc) { + absl::string_view desc) { auto it = edge_info_map.find(cluster_name); if (it != edge_info_map.end()) { VLOG(2) << " " << it->second.size() << " " << desc << " edges"; @@ -737,7 +739,7 @@ static void VLogClusteringSummary(const Graph& g) { } }; - for (StringPiece cluster_name : cluster_names_to_print) { + for (absl::string_view cluster_name : cluster_names_to_print) { VLOG(2) << " ** Cluster " << cluster_name; print_edge_info_set_for_cluster(cluster_name, incoming_edge_infos, "incoming"); @@ -966,7 +968,7 @@ Status MarkForCompilationPass::RunImpl( string& name = cluster_names[cluster]; if (name.empty()) { - name = strings::StrCat("cluster_", cluster_sequence_num++); + name = absl::StrCat("cluster_", cluster_sequence_num++); } n->AddAttr(kXlaClusterAttr, name); VLOG(3) << "Assigning node " << n->name() << " to cluster " << name; diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index 807ab51fd3..9473ac0a4c 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -633,7 +633,7 @@ TEST(XlaCompilationTest, IllegalCycle_UsefulErrorMessage) { std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); Scope root = Scope::NewRootScope().ExitOnError(); { - auto BuildNoopNode = [](StringPiece name, Graph* graph) { + auto BuildNoopNode = [](absl::string_view name, Graph* graph) { NodeDefBuilder builder(name, "NoOp"); NodeDef def; TF_CHECK_OK(builder.Finalize(&def)); diff --git a/tensorflow/compiler/jit/partially_decluster_pass.cc b/tensorflow/compiler/jit/partially_decluster_pass.cc index a8f09bfa50..584c963f71 100644 --- a/tensorflow/compiler/jit/partially_decluster_pass.cc +++ b/tensorflow/compiler/jit/partially_decluster_pass.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/jit/partially_decluster_pass.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/core/framework/memory_types.h" #include "tensorflow/core/framework/node_def.pb.h" @@ -30,7 +31,7 @@ Status FindNodesToDecluster(const Graph& graph, gtl::FlatSet<Node*>* result, MemoryTypeVector input_mtypes, output_mtypes; for (Node* n : post_order) { - absl::optional<StringPiece> from_cluster = GetXlaClusterForNode(*n); + absl::optional<absl::string_view> from_cluster = GetXlaClusterForNode(*n); if (!from_cluster) { continue; } @@ -79,7 +80,7 @@ Status FindNodesToDecluster(const Graph& graph, gtl::FlatSet<Node*>* result, // Check if `dst` is in a different cluster, unclustered, or about to be // partially declustered (here we rely on the post-order traversal order). // If yes, decluster `n` to avoid the device-to-host memcpy. - absl::optional<StringPiece> dst_cluster = + absl::optional<absl::string_view> dst_cluster = result->count(dst) ? absl::nullopt : GetXlaClusterForNode(*dst); if (from_cluster != dst_cluster) { CHECK(result->insert(n).second); @@ -91,15 +92,16 @@ Status FindNodesToDecluster(const Graph& graph, gtl::FlatSet<Node*>* result, } Status PartiallyDeclusterNode(Graph* graph, Node* n) { - StringPiece cluster_name = *GetXlaClusterForNode(*n); - gtl::InlinedVector<const Edge*, 6> out_edges_to_clone; + absl::string_view cluster_name = *GetXlaClusterForNode(*n); + absl::InlinedVector<const Edge*, 6> out_edges_to_clone; for (const Edge* out_edge : n->out_edges()) { if (out_edge->IsControlEdge()) { continue; } Node* dst = out_edge->dst(); - absl::optional<StringPiece> dst_cluster_name = GetXlaClusterForNode(*dst); + absl::optional<absl::string_view> dst_cluster_name = + GetXlaClusterForNode(*dst); if (dst_cluster_name != cluster_name) { out_edges_to_clone.push_back(out_edge); } @@ -108,7 +110,7 @@ Status PartiallyDeclusterNode(Graph* graph, Node* n) { CHECK(!out_edges_to_clone.empty()) << n->DebugString(); NodeDef ndef = n->def(); - ndef.set_name(strings::StrCat(n->name(), "/declustered")); + ndef.set_name(absl::StrCat(n->name(), "/declustered")); RemoveFromXlaCluster(&ndef); Status s; Node* cloned_node = graph->AddNode(ndef, &s); diff --git a/tensorflow/compiler/jit/resource_operation_safety_analysis.cc b/tensorflow/compiler/jit/resource_operation_safety_analysis.cc index 1ba4a5ef73..56e35c0059 100644 --- a/tensorflow/compiler/jit/resource_operation_safety_analysis.cc +++ b/tensorflow/compiler/jit/resource_operation_safety_analysis.cc @@ -165,7 +165,7 @@ bool IsEdgeSafe(XlaResourceOpKind from, XlaResourceOpKind to) { using ResourceOp = std::pair<int, XlaResourceOpKind>; string ResourceOpToString(const ResourceOp& resource_op) { - return strings::StrCat( + return absl::StrCat( resource_op.first, ": ", XlaResourceOpInfo::XlaResourceOpKindToString(resource_op.second)); } @@ -257,11 +257,11 @@ string ResourceOpSetToString(const ResourceOpSet& resource_op_set) { std::vector<string> elements_debug_string; std::transform(resource_op_set.begin(), resource_op_set.end(), std::back_inserter(elements_debug_string), ResourceOpToString); - return strings::StrCat("{", absl::StrJoin(elements_debug_string, ","), "}"); + return absl::StrCat("{", absl::StrJoin(elements_debug_string, ","), "}"); } string NodeToString(const Node& n, XlaResourceOpKind resource_op_kind) { - return strings::StrCat( + return absl::StrCat( "[", n.name(), ": ", n.type_string(), "(", XlaResourceOpInfo::XlaResourceOpKindToString(resource_op_kind), ")", "]"); } diff --git a/tensorflow/compiler/jit/xla_cluster_util.cc b/tensorflow/compiler/jit/xla_cluster_util.cc index 4f2fabd658..03380e9406 100644 --- a/tensorflow/compiler/jit/xla_cluster_util.cc +++ b/tensorflow/compiler/jit/xla_cluster_util.cc @@ -17,6 +17,7 @@ limitations under the License. #include <unordered_map> +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/jit/resource_operation_safety_analysis.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/graph/control_flow.h" @@ -52,8 +53,8 @@ string DescribeCycle(const GraphCycles* cycles, const Graph& graph, int src, }; string description; - strings::StrAppend(&description, "Edge from ", node_name(src), " to ", - node_name(dst), " would create a cycle.\n"); + absl::StrAppend(&description, "Edge from ", node_name(src), " to ", + node_name(dst), " would create a cycle.\n"); path.resize(path_size); for (int32 node_id : path) { string ascii_art; @@ -64,7 +65,7 @@ string DescribeCycle(const GraphCycles* cycles, const Graph& graph, int src, } else { ascii_art = "+-- "; } - strings::StrAppend(&description, ascii_art, node_name(node_id), "\n"); + absl::StrAppend(&description, ascii_art, node_name(node_id), "\n"); } return description; } @@ -186,7 +187,7 @@ Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles) { return Status::OK(); } -absl::optional<StringPiece> GetXlaClusterForNode(const Node& node) { +absl::optional<absl::string_view> GetXlaClusterForNode(const Node& node) { const AttrValue* attr_value = node.attrs().Find(kXlaClusterAttr); if (attr_value == nullptr) { return absl::nullopt; diff --git a/tensorflow/compiler/jit/xla_cluster_util.h b/tensorflow/compiler/jit/xla_cluster_util.h index b0439a63ca..17ae510a0e 100644 --- a/tensorflow/compiler/jit/xla_cluster_util.h +++ b/tensorflow/compiler/jit/xla_cluster_util.h @@ -47,7 +47,7 @@ Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles); // Returns the XLA cluster in which `node` is placed if it is in an XLA cluster, // otherwise returns nullopt. -absl::optional<StringPiece> GetXlaClusterForNode(const Node& node); +absl::optional<absl::string_view> GetXlaClusterForNode(const Node& node); // Removes `node_def` its XLA cluster (by clearing its _XlaCluster attribute). void RemoveFromXlaCluster(NodeDef* node_def); diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index dcb0b3240a..3aa9e9c7ed 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -67,12 +67,12 @@ string XlaCompilationCache::DebugString() { string XlaCompilationCache::SignatureDebugString(const Signature& sig) { string result = sig.name; for (const auto& a : sig.arg_types) { - strings::StrAppend(&result, ",", DataTypeString(a.first), - a.second.DebugString()); + absl::StrAppend(&result, ",", DataTypeString(a.first), + a.second.DebugString()); } for (const auto& v : sig.arg_values) { - strings::StrAppend(&result, "; ", v.DebugString()); + absl::StrAppend(&result, "; ", v.DebugString()); } return result; } diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc index f31879a2bc..51797def04 100644 --- a/tensorflow/compiler/jit/xla_device.cc +++ b/tensorflow/compiler/jit/xla_device.cc @@ -148,10 +148,9 @@ Status DefaultPaddedShapeFn(const Tensor& tensor, xla::Shape* shape) { } const DeviceAttributes attrs = Device::BuildDeviceAttributes( - strings::StrCat(name_prefix, "/device:", device_name, ":", - device_ordinal), + absl::StrCat(name_prefix, "/device:", device_name, ":", device_ordinal), DeviceType(device_name), Bytes(16ULL << 30), DeviceLocality(), - strings::StrCat("device: ", device_name, " device")); + absl::StrCat("device: ", device_name, " device")); device->reset( new XlaDevice(options, attrs, device_ordinal, DeviceType(jit_device_name), diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc index ee07c5c964..af83c792e5 100644 --- a/tensorflow/compiler/jit/xla_device_context.cc +++ b/tensorflow/compiler/jit/xla_device_context.cc @@ -203,7 +203,7 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor, } void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor, - StringPiece tensor_name, + absl::string_view tensor_name, Device* device, Tensor* cpu_tensor, StatusCallback done) { @@ -339,7 +339,7 @@ void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor, } void XlaDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor, - StringPiece tensor_name, + absl::string_view tensor_name, Device* device, Tensor* cpu_tensor, StatusCallback done) { manager_.CopyDeviceTensorToCPU(device_tensor, tensor_name, device, cpu_tensor, diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h index 2e7445340c..df82421294 100644 --- a/tensorflow/compiler/jit/xla_device_context.h +++ b/tensorflow/compiler/jit/xla_device_context.h @@ -57,7 +57,7 @@ class XlaTransferManager { void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, Tensor* device_tensor, StatusCallback done) const; void CopyDeviceTensorToCPU(const Tensor* device_tensor, - StringPiece tensor_name, Device* device, + absl::string_view tensor_name, Device* device, Tensor* cpu_tensor, StatusCallback done); void CopyDeviceTensorToDevice(const Tensor& src_tensor, Tensor* dst_tensor, @@ -111,7 +111,7 @@ class XlaDeviceContext : public DeviceContext { Tensor* device_tensor, StatusCallback done) const override; void CopyDeviceTensorToCPU(const Tensor* device_tensor, - StringPiece tensor_name, Device* device, + absl::string_view tensor_name, Device* device, Tensor* cpu_tensor, StatusCallback done) override; void CopyDeviceTensorToDevice(const Tensor& src_tensor, Tensor* dst_tensor, const StatusCallback& done); diff --git a/tensorflow/compiler/jit/xla_fusion_optimizer.cc b/tensorflow/compiler/jit/xla_fusion_optimizer.cc index 07cfab6151..bc0db558d8 100644 --- a/tensorflow/compiler/jit/xla_fusion_optimizer.cc +++ b/tensorflow/compiler/jit/xla_fusion_optimizer.cc @@ -20,6 +20,7 @@ limitations under the License. #include <unordered_map> #include <unordered_set> +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/jit/deadness_analysis.h" #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" @@ -326,7 +327,7 @@ Status XlaFusionOptimizer::Optimize(grappler::Cluster* cluster, string& name = cluster_names[cluster]; if (name.empty()) { - name = strings::StrCat("cluster_", cluster_sequence_num++); + name = absl::StrCat("cluster_", cluster_sequence_num++); } n->AddAttr(kXlaClusterAttr, name); VLOG(3) << "Assigning node " << n->name() << " to cluster " << name; diff --git a/tensorflow/compiler/jit/xla_tensor.h b/tensorflow/compiler/jit/xla_tensor.h index 4c9bb2e27b..d95da63405 100644 --- a/tensorflow/compiler/jit/xla_tensor.h +++ b/tensorflow/compiler/jit/xla_tensor.h @@ -122,7 +122,7 @@ class XlaTensor { std::shared_ptr<se::Event> definition_event_; // A list of all streams for which the tensor's content is defined for any // newly enqueued command. - gtl::InlinedVector<se::Stream*, 2> streams_defined_on_ GUARDED_BY(mu_); + absl::InlinedVector<se::Stream*, 2> streams_defined_on_ GUARDED_BY(mu_); mutex mu_; }; |