diff options
616 files changed, 12251 insertions, 4605 deletions
@@ -100,7 +100,7 @@ The TensorFlow project strives to abide by generally accepted best practices in | **IBM ppc64le CPU** | [![Build Status](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_CPU/badge/icon)](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_CPU/) | TBA | | **IBM ppc64le GPU** | [![Build Status](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_PPC64LE_GPU/badge/icon)](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_PPC64LE_GPU/) | TBA | | **Linux CPU with Intel® MKL-DNN** Nightly | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/) | [Nightly](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/) | -| **Linux CPU with Intel® MKL-DNN** Python 2.7<br> **Linux CPU with Intel® MKL-DNN** Python 3.5<br> **Linux CPU with Intel® MKL-DNN** Python 3.6 | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/lastStableBuild)|[1.9.0 py2.7](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.9.0-cp27-cp27mu-linux_x86_64.whl)<br>[1.9.0 py3.5](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.9.0-cp35-cp35m-linux_x86_64.whl)<br>[1.9.0 py3.6](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.9.0-cp36-cp36m-linux_x86_64.whl) | +| **Linux CPU with Intel® MKL-DNN** Python 2.7<br> **Linux CPU with Intel® MKL-DNN** Python 3.5<br> **Linux CPU with Intel® MKL-DNN** Python 3.6 | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/lastStableBuild)|[1.10.0 py2.7](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.10.0-cp27-cp27mu-linux_x86_64.whl)<br>[1.10.0 
py3.5](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.10.0-cp35-cp35m-linux_x86_64.whl)<br>[1.10.0 py3.6](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.10.0-cp36-cp36m-linux_x86_64.whl) | ## For more information diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD index 2220d0786d..59b961cdd9 100644 --- a/tensorflow/compiler/aot/BUILD +++ b/tensorflow/compiler/aot/BUILD @@ -32,7 +32,6 @@ cc_library( deps = [ ":embedded_protocol_buffers", "//tensorflow/compiler/tf2xla", - "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:cpu_function_runtime", "//tensorflow/compiler/tf2xla:tf2xla_proto", "//tensorflow/compiler/tf2xla:tf2xla_util", @@ -56,6 +55,7 @@ cc_library( "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -72,6 +72,7 @@ tf_cc_test( "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/strings", "@llvm//:support", # fixdeps: keep "@llvm//:x86_code_gen", # fixdeps: keep ], @@ -100,6 +101,7 @@ cc_library( "//tensorflow/core:graph", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/strings", ], ) @@ -195,6 +197,7 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", "@llvm//:core", "@llvm//:support", "@llvm//:target", diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc index 44291d977f..e77a8fecf0 100644 --- a/tensorflow/compiler/aot/codegen.cc +++ b/tensorflow/compiler/aot/codegen.cc @@ -20,9 +20,10 @@ limitations under the License. 
#include <vector> #include "absl/memory/memory.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_replace.h" #include "tensorflow/compiler/aot/embedded_protocol_buffers.h" #include "tensorflow/compiler/tf2xla/cpu_function_runtime.h" -#include "tensorflow/compiler/tf2xla/str_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/cpu/buffer_info_util.h" @@ -30,7 +31,6 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" namespace tensorflow { @@ -142,7 +142,7 @@ Status AddRewritesForShape(int i, const xla::Shape& shape, } rewrites->push_back({"{{I}}", strings::StrCat(i)}); rewrites->push_back({"{{TYPE}}", type}); - rewrites->push_back({"{{DIM_VARS}}", str_util::Join(dim_vars, ", ")}); + rewrites->push_back({"{{DIM_VARS}}", absl::StrJoin(dim_vars, ", ")}); rewrites->push_back({"{{DIM_SIZES}}", dim_sizes}); rewrites->push_back({"{{INDICES}}", indices}); return Status::OK(); @@ -158,8 +158,9 @@ Status AddRewritesForShape(int i, const xla::Shape& shape, // text-templating mechanism. string RewriteWithName(const string& name, string code, const std::vector<std::pair<string, string>>& rewrites) { - str_util::ReplaceAllPairs(&code, rewrites); - return str_util::StringReplace(code, "{{NAME}}", name, /*replace_all=*/true); + absl::StrReplaceAll(rewrites, &code); + absl::StrReplaceAll({{"{{NAME}}", name}}, &code); + return code; } // Generate methods for args (inputs). 
@@ -571,11 +572,11 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { {"{{ARG_BYTES_TOTAL}}", strings::StrCat(arg_bytes_total)}, {"{{ARG_NAMES_CODE}}", arg_names_code}, {"{{ARG_NUM}}", strings::StrCat(arg_index_table.size())}, - {"{{ARG_INDEX_TABLE}}", str_util::Join(arg_index_table, ", ")}, + {"{{ARG_INDEX_TABLE}}", absl::StrJoin(arg_index_table, ", ")}, {"{{ASSIGN_PROFILE_COUNTERS_SIZE}}", assign_profile_counters_size}, {"{{CLASS}}", opts.class_name}, {"{{DECLS_FROM_OBJ_FILE}}", - str_util::Join(metadata_result.header_variable_decls, "\n")}, + absl::StrJoin(metadata_result.header_variable_decls, "\n")}, {"{{ENTRY}}", compile_result.entry_point}, {"{{HLO_PROFILE_PRINTER_DATA_SHIM_EXPRESSION}}", metadata_result.hlo_profile_printer_data_access_shim}, @@ -595,8 +596,8 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { {"{{TEMP_BYTES_TOTAL}}", strings::StrCat(temp_bytes_total)}, {"{{NUM_BUFFERS}}", strings::StrCat(buffer_infos.size())}, {"{{BUFFER_INFOS_AS_STRING}}", - str_util::Join(buffer_infos_as_strings, ",\n")}}; - str_util::ReplaceAllPairs(header, rewrites); + absl::StrJoin(buffer_infos_as_strings, ",\n")}}; + absl::StrReplaceAll(rewrites, header); return Status::OK(); } diff --git a/tensorflow/compiler/aot/codegen_test.cc b/tensorflow/compiler/aot/codegen_test.cc index 60d59ae996..e3a53edb73 100644 --- a/tensorflow/compiler/aot/codegen_test.cc +++ b/tensorflow/compiler/aot/codegen_test.cc @@ -18,13 +18,13 @@ limitations under the License. 
#include <string> #include <vector> +#include "absl/strings/match.h" #include "llvm/Support/TargetSelect.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/test.h" @@ -34,9 +34,9 @@ namespace { using ::tensorflow::cpu_function_runtime::BufferInfo; -void ExpectErrorContains(const Status& status, StringPiece str) { +void ExpectErrorContains(const Status& status, absl::string_view str) { EXPECT_NE(Status::OK(), status); - EXPECT_TRUE(str_util::StrContains(status.error_message(), str)) + EXPECT_TRUE(absl::StrContains(status.error_message(), str)) << "expected error: " << status.error_message() << " to contain: " << str; } diff --git a/tensorflow/compiler/aot/embedded_protocol_buffers.cc b/tensorflow/compiler/aot/embedded_protocol_buffers.cc index 8fb2fad31c..1401aae758 100644 --- a/tensorflow/compiler/aot/embedded_protocol_buffers.cc +++ b/tensorflow/compiler/aot/embedded_protocol_buffers.cc @@ -19,6 +19,7 @@ limitations under the License. #include <string> #include "absl/memory/memory.h" +#include "absl/strings/str_replace.h" #include "llvm/ADT/Triple.h" #include "llvm/IR/GlobalVariable.h" #include "llvm/IR/LLVMContext.h" @@ -27,7 +28,6 @@ limitations under the License. 
#include "llvm/Support/TargetRegistry.h" #include "llvm/Target/TargetMachine.h" #include "llvm/Target/TargetOptions.h" -#include "tensorflow/compiler/tf2xla/str_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/util.h" @@ -65,14 +65,13 @@ static string CreateCPPShimExpression(StringPiece qualified_cpp_protobuf_name, " return proto;\n" " }()"; - str_util::ReplaceAllPairs( - &code, + return absl::StrReplaceAll( + code, { {"{{ARRAY_SYMBOL}}", strings::StrCat(protobuf_array_symbol_name)}, {"{{ARRAY_SIZE}}", strings::StrCat(protobuf_array_size)}, {"{{PROTOBUF_NAME}}", strings::StrCat(qualified_cpp_protobuf_name)}, }); - return code; } static StatusOr<string> CodegenModule(llvm::TargetMachine* target_machine, @@ -97,7 +96,7 @@ static StatusOr<std::unique_ptr<llvm::TargetMachine>> GetTargetMachineFromTriple(StringPiece target_triple) { std::string error; std::string normalized_triple = - llvm::Triple::normalize(AsStringRef(target_triple)); + llvm::Triple::normalize(AsStringRef(absl::string_view(target_triple))); const llvm::Target* target = llvm::TargetRegistry::lookupTarget(normalized_triple, error); if (target == nullptr) { diff --git a/tensorflow/compiler/aot/tests/BUILD b/tensorflow/compiler/aot/tests/BUILD index 0ecc3feeb6..7364d63b53 100644 --- a/tensorflow/compiler/aot/tests/BUILD +++ b/tensorflow/compiler/aot/tests/BUILD @@ -226,5 +226,6 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//third_party/eigen3", + "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc index 0c0c676ece..dd2b151098 100644 --- a/tensorflow/compiler/aot/tests/tfcompile_test.cc +++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc @@ -16,6 +16,7 @@ limitations under the License. 
#define EIGEN_USE_THREADS #define EIGEN_USE_CUSTOM_THREAD_POOL +#include "absl/strings/str_split.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/aot/tests/test_graph_tfadd.h" #include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.h" @@ -32,7 +33,6 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { @@ -546,7 +546,7 @@ TEST(TFCompileTest, HloProfiling) { VLOG(1) << "HLO profile string:\n" << hlo_profile_as_string; std::vector<string> hlo_profile_lines = - tensorflow::str_util::Split(hlo_profile_as_string, '\n'); + absl::StrSplit(hlo_profile_as_string, '\n'); auto header = HasSubstr("Execution profile for"); auto total_cycles_profile_line = HasSubstr("[total]"); diff --git a/tensorflow/compiler/aot/tfcompile_main.cc b/tensorflow/compiler/aot/tfcompile_main.cc index 839e1588b7..f3c44e9dda 100644 --- a/tensorflow/compiler/aot/tfcompile_main.cc +++ b/tensorflow/compiler/aot/tfcompile_main.cc @@ -18,6 +18,8 @@ limitations under the License. #include <utility> #include <vector> +#include "absl/strings/match.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/aot/codegen.h" #include "tensorflow/compiler/aot/compile.h" #include "tensorflow/compiler/aot/flags.h" @@ -34,7 +36,6 @@ limitations under the License. 
#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/strings/numbers.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" @@ -55,7 +56,7 @@ const char kUsageHeader[] = "\n"; Status ReadProtoFile(const string& fname, protobuf::Message* proto) { - if (str_util::EndsWith(fname, ".pbtxt")) { + if (absl::EndsWith(fname, ".pbtxt")) { return ReadTextProto(Env::Default(), fname, proto); } else { return ReadBinaryProto(Env::Default(), fname, proto); @@ -75,7 +76,7 @@ Status Main(const MainFlags& flags) { for (const tf2xla::Fetch& fetch : config.fetch()) { nodes.insert(fetch.id().node_name()); } - std::cout << str_util::Join(nodes, ","); + std::cout << absl::StrJoin(nodes, ","); return Status::OK(); } diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 2466c218c8..df81f3c23e 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -311,6 +311,51 @@ tf_cc_test( ) cc_library( + name = "resource_operation_safety_analysis", + srcs = ["resource_operation_safety_analysis.cc"], + hdrs = ["resource_operation_safety_analysis.h"], + deps = [ + "//tensorflow/compiler/jit/graphcycles", + "//tensorflow/compiler/tf2xla:resource_operation_table", + "//tensorflow/core:framework", + "//tensorflow/core:graph", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + ], +) + +tf_cc_test( + name = "resource_operation_safety_analysis_test", + srcs = ["resource_operation_safety_analysis_test.cc"], + deps = [ + ":common", + ":resource_operation_safety_analysis", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:cc_ops_internal", + "//tensorflow/cc:function_ops", + "//tensorflow/cc:functional_ops", + "//tensorflow/cc:ops", 
+ "//tensorflow/cc:resource_variable_ops", + "//tensorflow/cc:sendrecv_ops", + "//tensorflow/compiler/jit/kernels:xla_launch_op", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla/kernels:xla_ops", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:graph", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "@com_google_absl//absl/strings", + ], +) + +cc_library( name = "compilation_passes", srcs = [ "build_xla_launch_ops_pass.cc", @@ -335,11 +380,10 @@ cc_library( ":union_find", ":xla_cluster_util", "//tensorflow/compiler/jit/graphcycles", - "//tensorflow/compiler/jit/kernels:parallel_check_op", "//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags", - "//tensorflow/compiler/jit/ops:parallel_check_op", "//tensorflow/compiler/jit/ops:xla_ops", "//tensorflow/compiler/tf2xla:dump_graph", + "//tensorflow/compiler/tf2xla:resource_operation_table", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", @@ -351,6 +395,7 @@ cc_library( "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", "//tensorflow/core/kernels:bounds_check", + "@com_google_absl//absl/strings", ], ) @@ -359,6 +404,7 @@ cc_library( srcs = ["xla_cluster_util.cc"], hdrs = ["xla_cluster_util.h"], deps = [ + ":resource_operation_safety_analysis", "//tensorflow/compiler/jit/graphcycles", "//tensorflow/core:framework", "//tensorflow/core:graph", @@ -437,6 +483,7 @@ tf_cc_test( "//tensorflow/cc:cc_ops_internal", "//tensorflow/cc:function_ops", "//tensorflow/cc:ops", + "//tensorflow/cc:resource_variable_ops", "//tensorflow/cc:sendrecv_ops", "//tensorflow/compiler/jit/kernels:xla_launch_op", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -448,6 +495,7 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", 
"//tensorflow/core:testlib", + "@com_google_absl//absl/strings", ], ) @@ -528,6 +576,9 @@ tf_cuda_cc_test( ":common", ":xla_cluster_util", ":xla_fusion_optimizer", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:ops", + "//tensorflow/cc:resource_variable_ops", "//tensorflow/core:graph", "//tensorflow/core:test", "//tensorflow/core:test_main", diff --git a/tensorflow/compiler/jit/create_xla_launch_op.cc b/tensorflow/compiler/jit/create_xla_launch_op.cc index 1b1ce78ed2..a7f8a5613c 100644 --- a/tensorflow/compiler/jit/create_xla_launch_op.cc +++ b/tensorflow/compiler/jit/create_xla_launch_op.cc @@ -126,7 +126,8 @@ Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr, const DataTypeVector& arg_types = (*fbody)->arg_types; std::vector<bool> const_args(arg_types.size()); // If we can't analyze the const args. Bail out. - TF_RETURN_IF_ERROR(BackwardsConstAnalysis(*((*fbody)->graph), &const_args)); + TF_RETURN_IF_ERROR(BackwardsConstAnalysis( + *((*fbody)->graph), &const_args, /*compile_time_const_nodes=*/nullptr)); for (int i = 0; i < const_args.size(); ++i) { if (const_args[i]) { diff --git a/tensorflow/compiler/jit/deadness_analysis.cc b/tensorflow/compiler/jit/deadness_analysis.cc index 0ca0f949dc..fe28502f69 100644 --- a/tensorflow/compiler/jit/deadness_analysis.cc +++ b/tensorflow/compiler/jit/deadness_analysis.cc @@ -14,6 +14,7 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/jit/deadness_analysis.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/jit/deadness_analysis_internal.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/tensor_id.h" @@ -153,7 +154,7 @@ class AndPredicate : public Predicate { std::back_inserter(operands_str), [](Predicate* pred) { return pred->ToString(); }); - return strings::StrCat("(", str_util::Join(operands_str, " & "), ")"); + return strings::StrCat("(", absl::StrJoin(operands_str, " & "), ")"); } Kind kind() const override { return Kind::kAnd; } @@ -182,7 +183,7 @@ class OrPredicate : public Predicate { std::back_inserter(operands_str), [](Predicate* pred) { return pred->ToString(); }); - return strings::StrCat("(", str_util::Join(operands_str, " | "), ")"); + return strings::StrCat("(", absl::StrJoin(operands_str, " | "), ")"); } Kind kind() const override { return Kind::kOr; } diff --git a/tensorflow/compiler/jit/deadness_analysis_test.cc b/tensorflow/compiler/jit/deadness_analysis_test.cc index cc9f102398..28a56044d5 100644 --- a/tensorflow/compiler/jit/deadness_analysis_test.cc +++ b/tensorflow/compiler/jit/deadness_analysis_test.cc @@ -32,7 +32,6 @@ limitations under the License. #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/graph/graph_def_builder_util.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index f150bf1819..2788102620 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -36,6 +36,7 @@ limitations under the License. 
#include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/control_flow.h" #include "tensorflow/core/graph/graph.h" @@ -44,7 +45,6 @@ limitations under the License. #include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/hash/hash.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/public/session_options.h" #include "tensorflow/core/public/version.h" @@ -2504,7 +2504,8 @@ Status EncapsulateSubgraphsPass::Run( const int num_args = input_permutation->size(); std::vector<bool> const_args(num_args); - TF_RETURN_IF_ERROR(BackwardsConstAnalysis(**subgraph, &const_args)); + TF_RETURN_IF_ERROR(BackwardsConstAnalysis( + **subgraph, &const_args, /*compile_time_const_nodes=*/nullptr)); DataTypeVector arg_types(num_args); TF_RETURN_IF_ERROR(GetArgTypes(**subgraph, &arg_types)); diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc index c0543a0079..b3600fc48b 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" +#include "absl/strings/match.h" #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/core/framework/function_testlib.h" @@ -25,7 +26,6 @@ limitations under the License. 
#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/util/equal_graph_def.h" @@ -124,8 +124,8 @@ bool EqualFunctionNodeDef(const NodeDef& a, const NodeDef& b, std::unordered_set<string> control_input_a; std::unordered_set<string> control_input_b; for (int i = 0; i < a.input_size(); ++i) { - if (str_util::StartsWith(a.input(i), "^")) { - if (!str_util::StartsWith(b.input(i), "^")) { + if (absl::StartsWith(a.input(i), "^")) { + if (!absl::StartsWith(b.input(i), "^")) { if (diff) { *diff = strings::StrCat( diff_preamble, " mismatch for node ", a.name(), " input ", i, @@ -768,7 +768,7 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Simple) { Graph* graph = graph_ptr->get(); for (const Node* n : graph->nodes()) { if (n->type_string() == "_Arg" && - str_util::StartsWith(n->name(), "const")) { + absl::StartsWith(n->name(), "const")) { ++guaranteed_consts; EXPECT_TRUE(HasGuaranteeConstAttr(*n)); } else { @@ -813,7 +813,7 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Add) { Graph* graph = graph_ptr->get(); for (const Node* n : graph->nodes()) { if (n->type_string() == "_Arg" && - str_util::StartsWith(n->name(), "const")) { + absl::StartsWith(n->name(), "const")) { ++guaranteed_consts; EXPECT_TRUE(HasGuaranteeConstAttr(*n)); } else { diff --git a/tensorflow/compiler/jit/kernels/BUILD b/tensorflow/compiler/jit/kernels/BUILD index 8f78c110cb..253a5d2547 100644 --- a/tensorflow/compiler/jit/kernels/BUILD +++ b/tensorflow/compiler/jit/kernels/BUILD @@ -29,16 +29,3 @@ cc_library( ], alwayslink = 1, ) - -cc_library( - name = "parallel_check_op", - srcs = ["parallel_check_op.cc"], - visibility = ["//tensorflow/compiler/jit:friends"], - deps = [ - "//tensorflow/compiler/jit/legacy_flags:parallel_check_op_flags", - "//tensorflow/core:core_cpu", - 
"//tensorflow/core:framework", - "//tensorflow/core:lib", - ], - alwayslink = 1, -) diff --git a/tensorflow/compiler/jit/kernels/parallel_check_op.cc b/tensorflow/compiler/jit/kernels/parallel_check_op.cc deleted file mode 100644 index bd4eefbc0b..0000000000 --- a/tensorflow/compiler/jit/kernels/parallel_check_op.cc +++ /dev/null @@ -1,144 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.h" -#include "tensorflow/core/common_runtime/device.h" -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/types.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/macros.h" - -namespace tensorflow { -namespace { - -// Inputs 2*N tensors, outputs the first N inputs. -// Logs errors if input tensor i and i + N are not (near) identical -// in any position. 
-class ParallelCheckOp : public OpKernel { - public: - explicit ParallelCheckOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} - - template <typename T> - int CompareTensors(DataType dtype, const char* v0, const char* v1, - int64 num_elts, int input_idx) { - int failed = 0; - const T* p0 = reinterpret_cast<const T*>(v0); - const T* p1 = reinterpret_cast<const T*>(v1); - double rtol; - legacy_flags::ParallelCheckOpFlags* flags = - legacy_flags::GetParallelCheckOpFlags(); - if (!tensorflow::strings::safe_strtod(flags->parallel_check_rtol.c_str(), - &rtol)) { - LOG(ERROR) << "can't convert parallel_check_rtol " - << flags->parallel_check_rtol << " to double"; - } - double atol; - if (!tensorflow::strings::safe_strtod(flags->parallel_check_atol.c_str(), - &atol)) { - LOG(ERROR) << "can't convert parallel_check_atol " - << flags->parallel_check_atol << " to double"; - } - for (int i = 0; i < num_elts; ++i) { - bool ok = (p0[i] == p1[i]); - VLOG(2) << "output " << input_idx << " element " << i << ": " << p0[i]; - if (!ok) { - if (std::is_same<T, float>::value || std::is_same<T, double>::value) { - float tolerance = - std::max(atol, std::max(fabs(rtol * p0[i]), fabs(rtol * p1[i]))); - T diff = p0[i] - p1[i]; - if (diff < 0) diff = 0 - diff; - ok = (diff <= tolerance); - } - if (ok) continue; - LOG(ERROR) << "Op " << name() << " fails equality at output " - << input_idx << " type " << DataTypeString(dtype) - << " element " << i << ": std_val=" << p0[i] - << " test_val=" << p1[i] << " diff=" << (p0[i] - p1[i]); - if (++failed > 10) break; - } - } - return failed; - } - - void Compute(OpKernelContext* ctx) override { - VLOG(1) << "Compute " << name(); - const int num_pairs = ctx->num_inputs() / 2; - for (int i = 0; i < num_pairs; ++i) { - CHECK_EQ(ctx->input_dtype(i), ctx->input_dtype(i + num_pairs)); - Tensor t0 = ctx->input(i); - Tensor t1 = ctx->input(i + num_pairs); - int64 num_elts = t0.NumElements(); - CHECK_EQ(num_elts, t1.NumElements()); - - // Compare inputs 
elementwise for near-exact equality. - const char* v0 = t0.tensor_data().data(); - const char* v1 = t1.tensor_data().data(); - int failed = 0; - switch (ctx->input_dtype(i)) { - case DT_INT32: - failed = - CompareTensors<int32>(ctx->input_dtype(i), v0, v1, num_elts, i); - break; - case DT_INT64: - failed = - CompareTensors<int64>(ctx->input_dtype(i), v0, v1, num_elts, i); - break; - case DT_FLOAT: - failed = - CompareTensors<float>(ctx->input_dtype(i), v0, v1, num_elts, i); - break; - case DT_DOUBLE: - failed = - CompareTensors<double>(ctx->input_dtype(i), v0, v1, num_elts, i); - break; - case DT_BOOL: - failed = - CompareTensors<bool>(ctx->input_dtype(i), v0, v1, num_elts, i); - break; - default: - LOG(FATAL) << "unimpl: " << ctx->input_dtype(i); - } - if (failed > 0) { - LOG(ERROR) << "check failed for " << name() << " output " << i - << " num_elts: " << num_elts; - legacy_flags::ParallelCheckOpFlags* flags = - legacy_flags::GetParallelCheckOpFlags(); - if (flags->parallel_check_failfast) { - LOG(QFATAL) << "failfast on first parallel-check failure"; - } - } else { - VLOG(1) << "check passed for " << name() << " output " << i - << " num_elts: " << num_elts; - } - - // Propagate the std value. 
- if (IsRefType(ctx->input_dtype(i))) { - ctx->forward_ref_input_to_ref_output(i, i); - } else { - ctx->set_output(i, ctx->input(i)); - } - } - } - - TF_DISALLOW_COPY_AND_ASSIGN(ParallelCheckOp); -}; - -REGISTER_KERNEL_BUILDER(Name("ParallelCheck").Device(DEVICE_CPU), - ParallelCheckOp); - -} // namespace -} // namespace tensorflow diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc index ddb27a38ae..fde4135bf7 100644 --- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc +++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc @@ -187,7 +187,7 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { OP_REQUIRES_OK( ctx, cache->Compile(options, function_, constant_args, variables, ctx, - &kernel, &executable, &compile_options)); + &kernel, &executable, compile_options)); VLOG(1) << "Executing XLA Computation..."; diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 11bd5eec23..518c39ec15 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -27,7 +27,9 @@ limitations under the License. #include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h" #include "tensorflow/compiler/jit/union_find.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" +#include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "tensorflow/compiler/tf2xla/resource_operation_table.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/common_runtime/function.h" @@ -40,6 +42,7 @@ limitations under the License. 
#include "tensorflow/core/graph/control_flow.h" #include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/lib/gtl/cleanup.h" +#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/public/version.h" @@ -74,18 +77,40 @@ bool HasXLAKernel(const Node& node, const DeviceType& jit_device_type) { return FindKernelDef(jit_device_type, node.def(), nullptr, nullptr).ok(); } +bool HasResourceOutput(const Node& node) { + return std::find(node.output_types().begin(), node.output_types().end(), + DT_RESOURCE) != node.output_types().end(); +} + +bool HasResourceInput(const Node& node) { + return std::find(node.input_types().begin(), node.input_types().end(), + DT_RESOURCE) != node.input_types().end(); +} + +// Returns true if `node` is a resource operation recognized by tf2xla that +// operates on something other than resource variables. +bool IsNonResourceVarResourceOp(const Node& node) { + // TODO(b/112837194): We can't cluster these because we only support + // snapshotting resource variables (and we can't e.g. snapshot stacks). This + // limitation may be fixable with some work. + const XlaResourceOpInfo* op_info = GetResourceOpInfoForOp(node.type_string()); + return op_info && op_info->resource_kind() != XlaResourceKind::kVariable; +} + // Make sure we don't recurse infinitely on recursive functions. const int kMaxRecursionDepth = 10; bool IsCompilableCall(const NodeDef& call_def, - const DeviceType& jit_device_type, int depth, + const DeviceType& jit_device_type, + bool allow_resource_ops, int depth, FunctionLibraryRuntime* lib_runtime); // Tests whether 'while_node' is a completely compilable loop. // Every operator in the condition and body functions must be compilable for a // while loop to be compilable. 
bool IsCompilableWhile(const Node& while_node, - const DeviceType& jit_device_type, int depth, + const DeviceType& jit_device_type, + bool allow_resource_ops, int depth, FunctionLibraryRuntime* lib_runtime) { const NameAttrList* name_attr; NodeDef call; @@ -100,7 +125,8 @@ bool IsCompilableWhile(const Node& while_node, call.set_name("while_cond"); call.set_op(cond_func); *call.mutable_attr() = name_attr->attr(); - if (!IsCompilableCall(call, jit_device_type, depth + 1, lib_runtime)) { + if (!IsCompilableCall(call, jit_device_type, allow_resource_ops, depth + 1, + lib_runtime)) { VLOG(2) << "Rejecting While " << while_node.name() << ": can't compile loop condition: " << cond_func; return false; @@ -115,7 +141,8 @@ bool IsCompilableWhile(const Node& while_node, call.set_name("while_body"); call.set_op(body_func); *call.mutable_attr() = name_attr->attr(); - if (!IsCompilableCall(call, jit_device_type, depth + 1, lib_runtime)) { + if (!IsCompilableCall(call, jit_device_type, allow_resource_ops, depth + 1, + lib_runtime)) { VLOG(2) << "Rejecting While " << while_node.name() << ": can't compile loop body: " << body_func; return false; @@ -127,7 +154,8 @@ bool IsCompilableWhile(const Node& while_node, // Every operator in the function must be compilable for a function to be // compilable. bool IsCompilableCall(const NodeDef& call_def, - const DeviceType& jit_device_type, int depth, + const DeviceType& jit_device_type, + bool allow_resource_ops, int depth, FunctionLibraryRuntime* lib_runtime) { if (depth > kMaxRecursionDepth) { VLOG(2) << "Rejecting " << call_def.op() @@ -167,12 +195,17 @@ bool IsCompilableCall(const NodeDef& call_def, if (node->type_string() == "_Arg" || node->type_string() == "_Retval") continue; if (node->type_string() == "While") { - // Handle functional While loop (not in open source build). - return IsCompilableWhile(*node, jit_device_type, depth + 1, lib_runtime); + // Handle functional While loop. 
+ return IsCompilableWhile(*node, jit_device_type, allow_resource_ops, + depth + 1, lib_runtime); + } + if (!allow_resource_ops && + (HasResourceInput(*node) || HasResourceOutput(*node))) { + return false; } if (!HasXLAKernel(*node, jit_device_type) && - !IsCompilableCall(node->def(), jit_device_type, depth + 1, - lib_runtime)) { + !IsCompilableCall(node->def(), jit_device_type, allow_resource_ops, + depth + 1, lib_runtime)) { VLOG(2) << "Rejecting " << call_def.op() << ": unsupported op " << node->name() << ": " << node->def().ShortDebugString(); return false; @@ -343,6 +376,10 @@ Status FindCompilationCandidates( flib_def, opts)); FunctionLibraryRuntime* lib_runtime = pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice); + std::vector<bool> compile_time_const_nodes(graph.num_node_ids(), false); + TF_RETURN_IF_ERROR( + BackwardsConstAnalysis(graph, /*compile_time_const_arg_indices=*/nullptr, + &compile_time_const_nodes)); int64& fuel = legacy_flags::GetMarkForCompilationPassFlags()->tf_xla_clustering_fuel; @@ -386,19 +423,46 @@ Status FindCompilationCandidates( XlaOpRegistry::GetCompilationDevice(device_type.type(), ®istration)); DeviceType jit_device_type(registration->compilation_device_name); if (!HasXLAKernel(*node, jit_device_type) && - !IsCompilableCall(node->def(), jit_device_type, 0, lib_runtime)) { + !IsCompilableCall(node->def(), jit_device_type, + registration->compile_resource_ops, 0, lib_runtime)) { VLOG(2) << "Rejecting " << node->name() << ": unsupported op " << node->type_string(); continue; } if (!registration->compile_resource_ops && - HasResourceInputOrOutput(*node)) { - VLOG(2) << "Rejecting: " << node->name() << ": resource input/output " + (HasResourceOutput(*node) || IsNonResourceVarResourceOp(*node))) { + // We don't have a way of returning values of type DT_RESOURCE from XLA + // computations so we avoid auto-clustering nodes producing DT_RESOURCE. 
+ // XlaLaunchOp also cannot snapshot resources that are not resource + // variables so we avoid clustering resource operations that operate on + // non-resource variables. + VLOG(2) << "Rejecting: " << node->name() << ": resource output " << node->type_string(); continue; } + if (compile_time_const_nodes[node->id()] && + !registration->requires_compilation) { + const OpDef* op_def; + TF_RETURN_IF_ERROR( + OpRegistry::Global()->LookUpOpDef(node->type_string(), &op_def)); + if (op_def->is_stateful()) { + // We need to be able to constant fold the nodes in + // compile_time_const_nodes given constant inputs (required by XLA) and + // therefore can't auto-cluster stateful ops since these can never be + // constant folded. + VLOG(2) << "Rejecting " << node->name() + << ": must-be-constant stateful op"; + continue; + } + } + // We don't auto-cluster functional control flow nodes containing resource + // operations because safety checks are trickier in this case. + // registration->compile_resource_ops is true for XLA_CPU/XLA_GPU but not + // for CPU/GPU. if (node->type_string() == "While" && - !IsCompilableWhile(*node, jit_device_type, 0, lib_runtime)) { + !IsCompilableWhile(*node, jit_device_type, + registration->compile_resource_ops, 0, + lib_runtime)) { continue; } // _Arg nodes in a top-level function represent feeds. @@ -457,7 +521,11 @@ bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef) { CHECK(XlaOpRegistry::GetCompilationDevice(device->device_type(), ®istration)); DeviceType jit_device_type(registration->compilation_device_name); - return IsCompilableCall(ndef, jit_device_type, 0, flr); + + // We can always *compile* resource operations, even if we are sometimes + // unable to auto-cluster them. 
+ const bool compile_resource_ops = true; + return IsCompilableCall(ndef, jit_device_type, compile_resource_ops, 0, flr); } Status MarkForCompilationPass::Run( @@ -609,6 +677,43 @@ static bool IsShapeConsumerOp(const Node& node) { node.type_string() == "Size"; } +static Status IgnoreResourceOpForSafetyAnalysis(const Node& n, bool* ignore) { + // If a resource operation is assigned to XLA_CPU or XLA_GPU explicitly then + // ignore it during resource operation safety analysis. We need this hack + // because of two reasons: + // + // 1. Operations assigned to XLA_CPU and XLA_GPU have to always be compiled. + // 2. We don't support live-out values of type DT_RESOURCE and live-in values + // of type DT_RESOURCE that are not resource variables. + // + // Together these imply we cannot let resource variable safety analysis + // constrain e.g. a TensorArrayV3->TensorArrayAssignV3 edge to be in different + // clusters: both of them will have to be clustered because of (1) and we + // won't be able to keep the edge between the two as neither the input to the + // second XLA cluster nor the output from the first XLA cluster are supported + // because of (2). + // + // TODO(b/113100872): This can be fixed if the TensorFlow representation for + // TensorArray and Stack on the XLA_{C|G}PU devices were the same in XLA; then + // (2) would no longer hold. + + if (n.assigned_device_name().empty()) { + *ignore = false; + return Status::OK(); + } + DeviceType device_type(""); + TF_RETURN_IF_ERROR( + DeviceToDeviceType(n.assigned_device_name(), &device_type)); + + const XlaOpRegistry::DeviceRegistration* registration; + if (!XlaOpRegistry::GetCompilationDevice(device_type.type(), ®istration)) { + *ignore = true; + } else { + *ignore = registration->compile_resource_ops; + } + return Status::OK(); +} + // Sequence number generator to ensure clusters have unique names. 
static std::atomic<int64> cluster_sequence_num; @@ -637,6 +742,8 @@ Status MarkForCompilationPass::RunImpl( GraphCycles cycles; TF_RETURN_IF_ERROR(CreateCycleDetectionGraph(graph, &cycles)); + TF_RETURN_IF_ERROR(AdjustCycleDetectionGraphForResourceOps( + graph, options.flib_def, IgnoreResourceOpForSafetyAnalysis, &cycles)); // Each compilation candidate belongs to a cluster. The cluster's // representative @@ -675,7 +782,7 @@ Status MarkForCompilationPass::RunImpl( string to_scope; for (int to : cycles.Successors(from)) { if (to >= graph->num_node_ids()) { - // Node is a "frame" node that is present only in the cycle detection + // Node is a fictitious node that is present only in the cycle detection // graph. No clustering is possible. continue; } diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index 9d7ac0d609..807ab51fd3 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -15,10 +15,12 @@ limitations under the License. #include "tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h" +#include "absl/strings/match.h" #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/array_ops.h" #include "tensorflow/cc/ops/control_flow_ops_internal.h" #include "tensorflow/cc/ops/function_ops.h" +#include "tensorflow/cc/ops/resource_variable_ops.h" #include "tensorflow/cc/ops/sendrecv_ops.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/compiler/jit/defs.h" @@ -26,11 +28,11 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op.h" +#include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/graph/graph_def_builder_util.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { @@ -48,9 +50,35 @@ std::unordered_map<string, string> GetClusters(const Graph& graph) { ids[node->name()] = cluster; } } + + if (VLOG_IS_ON(2)) { + VLOG(2) << "Clusters:"; + for (const auto& p : ids) { + VLOG(2) << " " << p.first << " -> " << p.second; + } + } return ids; } +gtl::FlatMap<string, std::vector<string>> GetClusterSets( + const Graph& g, std::vector<string>* cluster_names = nullptr) { + CHECK(cluster_names == nullptr || cluster_names->empty()); + gtl::FlatMap<string, std::vector<string>> cluster_sets; + for (const auto& p : GetClusters(g)) { + cluster_sets[p.second].push_back(p.first); + } + for (auto& p : cluster_sets) { + if (cluster_names != nullptr) { + cluster_names->push_back(p.first); + } + std::sort(p.second.begin(), p.second.end()); + } + if (cluster_names != nullptr) { + std::sort(cluster_names->begin(), cluster_names->end()); + } + return cluster_sets; +} + TEST(XlaCompilationTest, Chains) { std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); GraphDef graphdef; @@ -501,38 +529,104 @@ TEST(XlaCompilationTest, CyclesWithDifferentScopesAndBridge) { EXPECT_EQ(clusters["B"], clusters["C"]); } -REGISTER_OP("ResourceInput").Input("a: resource").Output("o: float"); -REGISTER_OP("ResourceOutput").Input("a: float").Output("o: resource"); - namespace { +Node* MakeRead(const Scope& scope, const string& id) { + Output var_handle = + ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({})); + Output read = 
+ ops::ReadVariableOp(scope.WithOpName("Read" + id), var_handle, DT_FLOAT); + return read.node(); +} -class DummyOp : public XlaOpKernel { - using XlaOpKernel::XlaOpKernel; - void Compile(XlaOpKernelContext* ctx) override {} -}; - -REGISTER_XLA_OP(Name("ResourceInput"), DummyOp); -REGISTER_XLA_OP(Name("ResourceOutput"), DummyOp); +Node* MakeWrite(const Scope& scope, const string& id) { + Output var_handle = + ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({})); + Output value_to_write = + ops::Const(scope.WithOpName("ValueToAssign" + id), 1.0f); + ops::AssignVariableOp assign_op(scope.WithOpName("Assignment" + id), + var_handle, value_to_write); + return assign_op.operation.node(); +} +Node* MakeNeutral(const Scope& scope, const string& id) { + return ops::Const(scope.WithOpName("Const" + id), 42.0f).node(); +} } // namespace -TEST(XlaCompilationTest, Resources) { +TEST(XlaCompilationTest, ResourcesClusteringAllowed) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Node* read = MakeRead(root, "R"); + Node* write = MakeWrite(root, "W"); + + root.graph()->AddControlEdge(read, write); + + FixupSourceAndSinkEdges(root.graph()); std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); - GraphDef graphdef; - { - GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); - Node* a = - ops::SourceOp("UncompilableNullary", builder.opts().WithName("A")); - Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B")); - // We should not form clusters with resource ops by default. 
- Node* c = ops::UnaryOp("ResourceOutput", b, builder.opts().WithName("C")); - Node* d = ops::UnaryOp("ResourceInput", c, builder.opts().WithName("D")); - ops::UnaryOp("Relu", d, builder.opts().WithName("E")); - TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); - } + TF_EXPECT_OK(root.ToGraph(graph.get())); TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); - auto clusters = GetClusters(*graph); - EXPECT_EQ(0, clusters.size()); // Nothing should be compiled. + gtl::FlatMap<string, std::vector<string>> cluster_sets = + GetClusterSets(*graph); + ASSERT_EQ(cluster_sets.size(), 1); + std::vector<string> expected_clustered_nodes = {"AssignmentW", "ReadR", + "ValueToAssignW"}; + ASSERT_EQ(cluster_sets.begin()->second, expected_clustered_nodes); +} + +TEST(XlaCompilationTest, ResourcesClusteringDisallowed) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Node* read = MakeRead(root, "R"); + Node* write = MakeWrite(root, "W"); + + root.graph()->AddControlEdge(write, read); + + FixupSourceAndSinkEdges(root.graph()); + std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); + TF_EXPECT_OK(root.ToGraph(graph.get())); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); + gtl::FlatMap<string, std::vector<string>> cluster_sets = + GetClusterSets(*graph); + ASSERT_EQ(cluster_sets.size(), 1); + std::vector<string> expected_clustered_nodes = {"AssignmentW", + "ValueToAssignW"}; + ASSERT_EQ(cluster_sets.begin()->second, expected_clustered_nodes); +} + +TEST(XlaCompilationTest, ChainOfOps) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Node* write_0 = MakeWrite(root, "W0"); + Node* neutral_0 = MakeNeutral(root, "N0"); + Node* read_0 = MakeRead(root, "R0"); + Node* write_1 = MakeWrite(root, "W1"); + Node* neutral_1 = MakeNeutral(root, "N1"); + Node* read_1 = MakeRead(root, "R1"); + + root.graph()->AddControlEdge(write_0, neutral_0); + root.graph()->AddControlEdge(neutral_0, read_0); + 
root.graph()->AddControlEdge(read_0, write_1); + root.graph()->AddControlEdge(write_1, neutral_1); + root.graph()->AddControlEdge(neutral_1, read_1); + + FixupSourceAndSinkEdges(root.graph()); + std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); + TF_EXPECT_OK(root.ToGraph(graph.get())); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); + + std::vector<string> cluster_names; + gtl::FlatMap<string, std::vector<string>> cluster_sets = + GetClusterSets(*graph, &cluster_names); + + ASSERT_EQ(cluster_sets.size(), 2); + + std::vector<string> expected_clustered_nodes_a = {"AssignmentW0", "ConstN0", + "ValueToAssignW0"}; + ASSERT_EQ(cluster_sets[cluster_names[0]], expected_clustered_nodes_a); + + std::vector<string> expected_clustered_nodes_b = { + "AssignmentW1", "ConstN1", "ReadR0", "ValueToAssignW1"}; + ASSERT_EQ(cluster_sets[cluster_names[1]], expected_clustered_nodes_b); } TEST(XlaCompilationTest, IllegalCycle_UsefulErrorMessage) { @@ -562,11 +656,11 @@ TEST(XlaCompilationTest, IllegalCycle_UsefulErrorMessage) { Status status = MarkForCompilationPassTestHelper::MarkForCompilation(&graph); EXPECT_FALSE(status.ok()); - EXPECT_TRUE(str_util::StrContains(status.ToString(), - "Edge from c to a would create a cycle.\n" - "+-> a\n" - "| b\n" - "+-- c\n")); + EXPECT_TRUE(absl::StrContains(status.ToString(), + "Edge from c to a would create a cycle.\n" + "+-> a\n" + "| b\n" + "+-- c\n")); } TEST(XlaCompilationTest, Retval) { @@ -731,5 +825,27 @@ TEST(XlaCompilationTest, ClusterControlTrigger) { EXPECT_EQ(clusters, expected_clusters); } +TEST(XlaCompilationTest, RandomShape) { + Scope root = Scope::NewRootScope().ExitOnError(); + Output shape_shape = ops::Const(root.WithOpName("shape_shape"), {2}, {1}); + Output shape = + ops::RandomUniformInt(root.WithOpName("shape"), shape_shape, + ops::Const(root.WithOpName("minval"), 1), + ops::Const(root.WithOpName("maxval"), 20)); + Output reshape_input = + 
ops::Placeholder(root.WithOpName("reshape_input"), DT_FLOAT, + ops::Placeholder::Shape(TensorShape({500, 500}))); + Output reshape = + ops::Reshape(root.WithOpName("reshape"), reshape_input, shape); + + std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); + + TF_ASSERT_OK(root.ToGraph(graph.get())); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); + + std::unordered_map<string, string> clusters = GetClusters(*graph); + EXPECT_EQ(clusters["shape"], ""); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/ops/BUILD b/tensorflow/compiler/jit/ops/BUILD index c9e46bc147..13804c6a05 100644 --- a/tensorflow/compiler/jit/ops/BUILD +++ b/tensorflow/compiler/jit/ops/BUILD @@ -10,10 +10,3 @@ cc_library( deps = ["//tensorflow/core:framework"], alwayslink = 1, ) - -cc_library( - name = "parallel_check_op", - srcs = ["parallel_check_op.cc"], - deps = ["//tensorflow/core:framework"], - alwayslink = 1, -) diff --git a/tensorflow/compiler/jit/ops/parallel_check_op.cc b/tensorflow/compiler/jit/ops/parallel_check_op.cc deleted file mode 100644 index db5c195578..0000000000 --- a/tensorflow/compiler/jit/ops/parallel_check_op.cc +++ /dev/null @@ -1,30 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/core/framework/op.h" - -namespace tensorflow { - -REGISTER_OP("ParallelCheck") - .Attr("T: list(type) >= 0") - .Input("expected: T") - .Input("actual: T") - .Output("result: T") - .Doc(R"doc( -Op that compares two sets of inputs for near-identity, and propagates the first. -Inequality is logged to ERROR log. -)doc"); - -} // namespace tensorflow diff --git a/tensorflow/compiler/jit/partially_decluster_pass_test.cc b/tensorflow/compiler/jit/partially_decluster_pass_test.cc index 08a956e4c6..f61a955c22 100644 --- a/tensorflow/compiler/jit/partially_decluster_pass_test.cc +++ b/tensorflow/compiler/jit/partially_decluster_pass_test.cc @@ -32,7 +32,6 @@ limitations under the License. #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/graph/graph_def_builder_util.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { diff --git a/tensorflow/compiler/jit/resource_operation_safety_analysis.cc b/tensorflow/compiler/jit/resource_operation_safety_analysis.cc new file mode 100644 index 0000000000..1ba4a5ef73 --- /dev/null +++ b/tensorflow/compiler/jit/resource_operation_safety_analysis.cc @@ -0,0 +1,336 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +// ALGORITHM OVERVIEW +// ================== +// +// An XLA cluster hoists all resource reads to be beginning of the cluster +// execution and all the resource writes to the end. This means it cannot +// enforce arbitrary ordering dependencies (via control or data edges) between +// resource operations. Since all resource reads happen before all resource +// writes, edges constraining resource reads to happen before resource writes +// are fine, but all other kinds of edges are problematic. This analysis +// computes the set of pairs of resource operations that cannot be put in the +// same cluster because XLA cannot respect the dependencies between them in the +// TensorFlow program. +// +// TODO(b/112856632): We can, in theory, support Read->Read and Write->Write +// dependencies. +// +// Specifically the result computed by this analysis contains the edge {W, R} +// iff all of these hold true: +// +// - In the graph (g - {edges from NextIteration to Merge}) there is a path +// from W to R. +// - IsEdgeSafe(W, R) == False [defined below] +// - W != R (note: some resource operations both read from and write to +// resource variables). +// +// The result is incorrect around loops because we ignore edges from +// NextIteration to Merge, but that should be fine because we don't cluster +// these edges. For instance, in: +// +// Init -----> Merge <-------+ +// | | +// v | +// Read | +// | | +// v | +// Write | +// | | +// v | +// NextIteration --+ +// +// we won't put (Read, Write) in the returned set. This is fine if +// auto-clustering can only cluster the Read->Write edge, but it is a problem if +// it clusters the Write->NextIteration->Merge->Read edges instead. The same +// problem is present for the functional version of the loop above. We rely on +// auto-clustering to not cluster control flow edges like NextIteration->Merge. 
+// This is enough to avoid the explicit-control-flow problem shown above. One +// way to think about this is that we only care about cases where two nodes, A +// and B, would normally have been put in the same cluster but cannot legally be +// in the same cluster because of resourcevar-dependencies. If A and B would +// normally have been put in the same cluster then all paths between A and B +// would have to be clusterable (otherwise we'd have introduced a cycle). Ergo +// there could not have been a NextIteration->Merge edge between A and B since +// we don't cluster these edges. +// +// We also rely on auto-clustering to not cluster functional control flow nodes +// that contain resource operations. +// +// IMPLEMENTATION +// -------------- +// +// We traverse the graph minus backedges in reverse post order, mapping each +// node to the set of resource operation reaching that node. Since we visit +// producers before consumers, we can construct the set of reaching operations +// by taking the union of the operations reaching the input nodes. These +// "reaching resource operations" can then be used to create the pairs of +// incompatible nodes using `IsEdgeSafe`. + +#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h" + +#include "absl/memory/memory.h" +#include "absl/strings/str_join.h" +#include "absl/types/optional.h" +#include "tensorflow/compiler/tf2xla/resource_operation_table.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/tensor_id.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/gtl/flatset.h" +#include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/util/ptr_util.h" + +namespace tensorflow { +namespace { +// Returns true if `n` may call a function. 
+Status MayCallFunction(const Node& n, const FunctionLibraryDefinition* flib_def, + bool* out_result) { + if (flib_def->Contains(n.type_string())) { + *out_result = true; + } else { + *out_result = + std::any_of(n.def().attr().begin(), n.def().attr().end(), + [](const std::pair<string, AttrValue>& name_attr_pair) { + return name_attr_pair.second.has_func(); + }); + } + + return Status::OK(); +} + +// Maps `n` to the XlaResourceOpKind corresponding to its operation. If `n` is +// not a resource operation recognized by XLA then sets `out_resource_op_kind` +// to nullopt. +Status XlaResourceOpKindForNode( + const Node& n, const FunctionLibraryDefinition* flib_def, + const std::function<Status(const Node&, bool*)>& resource_ops_to_ignore, + absl::optional<XlaResourceOpKind>* out_resource_op_kind) { + bool should_ignore = false; + if (resource_ops_to_ignore) { + TF_RETURN_IF_ERROR(resource_ops_to_ignore(n, &should_ignore)); + } + if (should_ignore) { + *out_resource_op_kind = absl::nullopt; + return Status::OK(); + } + + const XlaResourceOpInfo* op_info = GetResourceOpInfoForOp(n.type_string()); + if (op_info) { + *out_resource_op_kind = op_info->kind(); + return Status::OK(); + } + + // We conservatively assume that functions will both read and write resource + // variables. In the future we may consider doing some form of + // inter-procedural analysis. + bool may_call_function; + TF_RETURN_IF_ERROR(MayCallFunction(n, flib_def, &may_call_function)); + if (may_call_function) { + *out_resource_op_kind = XlaResourceOpKind::kReadWrite; + } else { + *out_resource_op_kind = absl::nullopt; + } + + return Status::OK(); +} + +// Returns true if a control or data dependence from a TensorFlow operation of +// resource op kind `from` to a TensorFlow operation of resource op kind `to` +// can be represented by an XLA cluster and needs no special handling around +// auto-jit. 
+bool IsEdgeSafe(XlaResourceOpKind from, XlaResourceOpKind to) { + // XLA clusters forces all reads to happen before all writes, which means the + // kinds of edges it can faithfully represent are: Read->Write, Read->Modify, + // Modify->Write, Read->Read, Write->Write. + // + // TODO(b/112856632): We can, in theory, support Read->Read and Write->Write + // dependencies. + return from == XlaResourceOpKind::kRead && to == XlaResourceOpKind::kWrite; +} + +using ResourceOp = std::pair<int, XlaResourceOpKind>; + +string ResourceOpToString(const ResourceOp& resource_op) { + return strings::StrCat( + resource_op.first, ": ", + XlaResourceOpInfo::XlaResourceOpKindToString(resource_op.second)); +} + +// A copy-on-write set used to store the set of ResourceOps reaching a node in a +// TensorFlow graph. +// +// TODO(sanjoy): It may be useful to pull this out into its own header at some +// point. +class ResourceOpSet { + private: + using Impl = gtl::FlatSet<ResourceOp>; + + public: + ResourceOpSet() = default; + + // Adds all ResourceOp s in `other` to this set. + void Add(const ResourceOpSet& other) { + CHECK(!frozen_); + if (other.impl_ == impl_) { + other.frozen_ = true; + return; + } + + if (!impl_) { + other.frozen_ = true; + impl_ = other.impl_; + return; + } + + for (ResourceOp resource_op : other) { + Add(resource_op); + } + } + + void Add(const ResourceOp& resource_op) { + CHECK(!frozen_); + if (!IsCopy() && Contains(resource_op)) { + // We can avoid the copy if the item we want to insert already exists. + return; + } + + EnsureIsCopied(); + impl_->insert(resource_op); + } + + Impl::const_iterator begin() const { + return impl_ ? impl_->begin() : GetEmptyImpl()->begin(); + } + + Impl::const_iterator end() const { + return impl_ ? 
impl_->end() : GetEmptyImpl()->end(); + } + + bool Contains(const ResourceOp& resource_op) const { + return impl_ != nullptr && impl_->count(resource_op); + } + + private: + bool IsCopy() const { return storage_ != nullptr; } + + void EnsureIsCopied() { + if (storage_ == nullptr) { + storage_ = absl::make_unique<Impl>(); + for (ResourceOp op : *this) { + storage_->insert(op); + } + impl_ = storage_.get(); + } + } + + static Impl* GetEmptyImpl() { + static Impl* empty_impl = new Impl; + return empty_impl; + } + + Impl* impl_ = nullptr; + std::unique_ptr<Impl> storage_; + + // frozen_ is true if there is another set pointing to this set's impl_. We + // can no longer add elements to this set in that case since the sets pointing + // to this set expect the contents of this set to be stable. + mutable bool frozen_ = false; + + TF_DISALLOW_COPY_AND_ASSIGN(ResourceOpSet); +}; + +string ResourceOpSetToString(const ResourceOpSet& resource_op_set) { + std::vector<string> elements_debug_string; + std::transform(resource_op_set.begin(), resource_op_set.end(), + std::back_inserter(elements_debug_string), ResourceOpToString); + return strings::StrCat("{", absl::StrJoin(elements_debug_string, ","), "}"); +} + +string NodeToString(const Node& n, XlaResourceOpKind resource_op_kind) { + return strings::StrCat( + "[", n.name(), ": ", n.type_string(), "(", + XlaResourceOpInfo::XlaResourceOpKindToString(resource_op_kind), ")", "]"); +} +} // namespace + +Status ComputeIncompatibleResourceOperationPairs( + const Graph& g, const FunctionLibraryDefinition* flib_def, + const std::function<Status(const Node&, bool*)>& resource_ops_to_ignore, + std::vector<std::pair<int, int>>* result) { + CHECK(result->empty()); + + std::vector<Node*> rpo; + GetReversePostOrder(g, &rpo, /*stable_comparator=*/NodeComparatorName(), + /*edge_filter=*/[](const Edge& edge) { + return !edge.src()->IsNextIteration(); + }); + + auto resource_op_set_for_node = + absl::make_unique<ResourceOpSet[]>(g.num_node_ids()); 
+ + const bool vlog = VLOG_IS_ON(2); + + for (Node* n : rpo) { + absl::optional<XlaResourceOpKind> op_kind; + TF_RETURN_IF_ERROR(XlaResourceOpKindForNode( + *n, flib_def, resource_ops_to_ignore, &op_kind)); + + ResourceOpSet* resource_op_set = &resource_op_set_for_node[n->id()]; + + // Merge the reaching resource operations for all the incoming edges to + // create the set of all possible resource ops reaching `n`. + for (const Edge* e : n->in_edges()) { + if (n->IsMerge() && e->src()->IsNextIteration()) { + // Ignore back-edges (see file comment). + continue; + } + + const ResourceOpSet& incoming_op_set = + resource_op_set_for_node[e->src()->id()]; + resource_op_set->Add(incoming_op_set); + } + + // Add to the "incompatible resource ops" set if necessary. + if (op_kind) { + for (ResourceOp incoming_op : *resource_op_set) { + if (IsEdgeSafe(incoming_op.second, *op_kind)) { + continue; + } + + if (vlog) { + VLOG(2) << "Unsafe edge: " + << NodeToString(*g.FindNodeId(incoming_op.first), + incoming_op.second) + << " -> " << NodeToString(*n, *op_kind); + } + result->push_back({incoming_op.first, n->id()}); + } + + resource_op_set->Add({n->id(), *op_kind}); + } + + if (vlog) { + VLOG(3) << n->name() << " -> " << ResourceOpSetToString(*resource_op_set); + } + } + + std::sort(result->begin(), result->end()); + CHECK(std::unique(result->begin(), result->end()) == result->end()); + + return Status::OK(); +} +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/resource_operation_safety_analysis.h b/tensorflow/compiler/jit/resource_operation_safety_analysis.h new file mode 100644 index 0000000000..ae8cfeecad --- /dev/null +++ b/tensorflow/compiler/jit/resource_operation_safety_analysis.h @@ -0,0 +1,73 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_RESOURCE_OPERATION_SAFETY_ANALYSIS_H_ +#define TENSORFLOW_COMPILER_JIT_RESOURCE_OPERATION_SAFETY_ANALYSIS_H_ + +#include "tensorflow/compiler/jit/graphcycles/graphcycles.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { +// An XLA cluster hoists all resource reads to be beginning of the cluster +// execution and all the resource writes to the end. This means it cannot +// enforce arbitrary ordering dependencies (via control or data edges) between +// resource operations. Since all resource reads happen before all resource +// writes, edges constraining resource reads to happen before resource writes +// are fine, but all other kinds of edges are problematic. This analysis +// returns the set of pairs of resource operations that cannot be put in the +// same cluster because XLA cannot respect the dependencies between them in the +// TensorFlow program. +// +// The restrictions are not transitive: it is fine to put A and C in the same +// cluster even if the returned set contains (A,B) and (B,C). +// +// In other words, if these pairs are seen as edges in an undirected graph of +// the nodes in `g` then auto-clustering is at least as constrained as the graph +// coloring problem on this graph. 
+//
+
+// For instance if we auto-cluster all operations in this TensorFlow graph:
+//
+//     ReadVariableOp0  ->  ReadVariableOp1
+//            |
+//            v
+//     AssignVariableOp0 -> AssignVariableOp1
+//
+// we will lose the ReadVariableOp0 -> ReadVariableOp1 and the
+// AssignVariableOp0 -> AssignVariableOp1 dependencies. I.e. it is possible for
+// XlaLaunchOp to issue ReadVariableOp1 before ReadVariableOp0 since it reads
+// all the resource variables when the cluster starts executing without any
+// particular ordering between them; same holds for the AssignVariableOp0 ->
+// AssignVariableOp1 edge. The ReadVariableOp1 -> AssignVariableOp0 edge will
+// be respected by XlaLaunchOp though because all reads happen before all
+// writes.
+//
+//
+// NB! The result computed by this analysis assumes that we don't auto-cluster
+// back-edges (i.e. the edges from NextIteration to Merge).
+//
+// NB! The result computed by this analysis assumes that we don't auto-cluster
+// functional control flow nodes containing resource operations.
+//
+// If `resource_ops_to_ignore` is set then nodes for which it returns true are
+// ignored (we pretend these nodes are not resource operations).
+Status ComputeIncompatibleResourceOperationPairs(
+    const Graph& g, const FunctionLibraryDefinition* flib_def,
+    const std::function<Status(const Node&, bool*)>& resource_ops_to_ignore,
+    std::vector<std::pair<int, int>>* result);
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_COMPILER_JIT_RESOURCE_OPERATION_SAFETY_ANALYSIS_H_
diff --git a/tensorflow/compiler/jit/resource_operation_safety_analysis_test.cc b/tensorflow/compiler/jit/resource_operation_safety_analysis_test.cc
new file mode 100644
index 0000000000..e54b547abc
--- /dev/null
+++ b/tensorflow/compiler/jit/resource_operation_safety_analysis_test.cc
@@ -0,0 +1,540 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h" + +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/ops/array_ops.h" +#include "tensorflow/cc/ops/control_flow_ops_internal.h" +#include "tensorflow/cc/ops/function_ops.h" +#include "tensorflow/cc/ops/functional_ops.h" +#include "tensorflow/cc/ops/resource_variable_ops.h" +#include "tensorflow/cc/ops/sendrecv_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/jit/defs.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/graph/graph_def_builder_util.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +Node* MakeRead(const Scope& scope, const string& id) { + Output var_handle = + ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({})); + Output read = + ops::ReadVariableOp(scope.WithOpName("Read" + id), var_handle, DT_FLOAT); + return read.node(); +} 
+ +Node* MakeWrite(const Scope& scope, const string& id) { + Output var_handle = + ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({})); + Output value_to_write = + ops::Const(scope.WithOpName("ValueToAssign" + id), 1.0f); + ops::AssignVariableOp assign_op(scope.WithOpName("Assignee" + id), var_handle, + value_to_write); + return assign_op.operation.node(); +} + +Node* MakeModify(const Scope& scope, const string& id) { + Output var_handle = + ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({})); + Output value_to_write = ops::Const(scope.WithOpName("Increment" + id), 1.0f); + ops::AssignAddVariableOp assign_add_op(scope.WithOpName("Increment" + id), + var_handle, value_to_write); + return assign_add_op.operation.node(); +} + +Node* MakeNeutral(const Scope& scope, const string& id) { + return ops::Const(scope.WithOpName("Const" + id), 42.0f).node(); +} + +Status ComputeIncompatiblePairs(Graph* g, + std::vector<std::pair<int, int>>* result) { + FixupSourceAndSinkEdges(g); + return ComputeIncompatibleResourceOperationPairs(*g, &g->flib_def(), {}, + result); +} + +TEST(ResourceOperationSafetyAnalysisTest, WriteRead) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Node* read = MakeRead(root, "R"); + Node* write = MakeWrite(root, "W"); + + root.graph()->AddControlEdge(write, read); + + std::vector<std::pair<int, int>> incompatible_pairs; + TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); + + ASSERT_EQ(incompatible_pairs.size(), 1); + std::pair<int, int> write_read_pair = {write->id(), read->id()}; + EXPECT_EQ(incompatible_pairs[0], write_read_pair); +} + +TEST(ResourceOperationSafetyAnalysisTest, ReadWrite) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Node* read = MakeRead(root, "R"); + Node* write = MakeWrite(root, "W"); + + root.graph()->AddControlEdge(read, write); + + std::vector<std::pair<int, int>> incompatible_pairs; + TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), 
&incompatible_pairs)); + + EXPECT_EQ(incompatible_pairs.size(), 0); +} + +TEST(ResourceOperationSafetyAnalysisTest, ReadWriteNoEdges) { + Scope root = Scope::NewRootScope().ExitOnError(); + + MakeRead(root, "R"); + MakeWrite(root, "W"); + + std::vector<std::pair<int, int>> incompatible_pairs; + TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); + + EXPECT_EQ(incompatible_pairs.size(), 0); +} + +TEST(ResourceOperationSafetyAnalysisTest, ReadModify) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Node* read = MakeRead(root, "R"); + Node* modify = MakeModify(root, "M"); + + root.graph()->AddControlEdge(read, modify); + + std::vector<std::pair<int, int>> incompatible_pairs; + TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); + + EXPECT_EQ(incompatible_pairs.size(), 1); + std::pair<int, int> read_modify_pair = {read->id(), modify->id()}; + EXPECT_EQ(incompatible_pairs[0], read_modify_pair); +} + +TEST(ResourceOperationSafetyAnalysisTest, ModifyRead) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Node* read = MakeRead(root, "R"); + Node* modify = MakeModify(root, "M"); + + root.graph()->AddControlEdge(modify, read); + + std::vector<std::pair<int, int>> incompatible_pairs; + TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); + + ASSERT_EQ(incompatible_pairs.size(), 1); + std::pair<int, int> modify_read_pair = {modify->id(), read->id()}; + EXPECT_EQ(incompatible_pairs[0], modify_read_pair); +} + +TEST(ResourceOperationSafetyAnalysisTest, ModifyWrite) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Node* modify = MakeModify(root, "M"); + Node* write = MakeWrite(root, "W"); + + root.graph()->AddControlEdge(modify, write); + + std::vector<std::pair<int, int>> incompatible_pairs; + TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); + + EXPECT_EQ(incompatible_pairs.size(), 1); + std::pair<int, int> modify_write_pair = {modify->id(), write->id()}; + 
EXPECT_EQ(incompatible_pairs[0], modify_write_pair); +} + +TEST(ResourceOperationSafetyAnalysisTest, WriteModify) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Node* modify = MakeModify(root, "M"); + Node* write = MakeWrite(root, "W"); + + root.graph()->AddControlEdge(write, modify); + + std::vector<std::pair<int, int>> incompatible_pairs; + TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); + + ASSERT_EQ(incompatible_pairs.size(), 1); + std::pair<int, int> write_modify_pair = {write->id(), modify->id()}; + EXPECT_EQ(incompatible_pairs[0], write_modify_pair); +} + +TEST(ResourceOperationSafetyAnalysisTest, ReadModifyWrite) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Node* read = MakeRead(root, "R"); + Node* modify = MakeModify(root, "M"); + Node* write = MakeWrite(root, "W"); + + root.graph()->AddControlEdge(read, modify); + root.graph()->AddControlEdge(modify, write); + + std::vector<std::pair<int, int>> incompatible_pairs; + TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); + + EXPECT_EQ(incompatible_pairs.size(), 2); + std::pair<int, int> modify_write_pair = {modify->id(), write->id()}; + std::pair<int, int> read_modify_pair = {read->id(), modify->id()}; + EXPECT_EQ(incompatible_pairs[0], read_modify_pair); + EXPECT_EQ(incompatible_pairs[1], modify_write_pair); +} + +TEST(ResourceOperationSafetyAnalysisTest, WriteModifyRead) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Node* read = MakeRead(root, "R"); + Node* modify = MakeModify(root, "M"); + Node* write = MakeWrite(root, "W"); + + root.graph()->AddControlEdge(write, modify); + root.graph()->AddControlEdge(modify, read); + + std::vector<std::pair<int, int>> incompatible_pairs; + TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); + + ASSERT_EQ(incompatible_pairs.size(), 3); + + std::pair<int, int> write_modify_pair = {write->id(), modify->id()}; + std::pair<int, int> modify_read_pair = {modify->id(), 
read->id()}; + std::pair<int, int> write_read_pair = {write->id(), read->id()}; + EXPECT_EQ(incompatible_pairs[0], modify_read_pair); + EXPECT_EQ(incompatible_pairs[1], write_read_pair); + EXPECT_EQ(incompatible_pairs[2], write_modify_pair); +} + +TEST(ResourceOperationSafetyAnalysisTest, WriteReadModify) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Node* read = MakeRead(root, "R"); + Node* modify = MakeModify(root, "M"); + Node* write = MakeWrite(root, "W"); + + root.graph()->AddControlEdge(write, read); + root.graph()->AddControlEdge(read, modify); + + std::vector<std::pair<int, int>> incompatible_pairs; + TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); + + ASSERT_EQ(incompatible_pairs.size(), 3); + + std::pair<int, int> write_modify_pair = {write->id(), modify->id()}; + std::pair<int, int> write_read_pair = {write->id(), read->id()}; + std::pair<int, int> read_modify_pair = {read->id(), modify->id()}; + EXPECT_EQ(incompatible_pairs[0], read_modify_pair); + EXPECT_EQ(incompatible_pairs[1], write_read_pair); + EXPECT_EQ(incompatible_pairs[2], write_modify_pair); +} + +FunctionDefLibrary CreateFunctionDefLibWithConstFunction(const string& name) { + FunctionDefLibrary flib_def; + FunctionDef func = FunctionDefHelper::Create( + /*function_name=*/name, /*in_def=*/{}, /*out_def=*/{"out: float"}, + /*attr_def*/ + {}, /*node_def=*/{FunctionDefHelper::Const("one", 1.0f)}, + /*ret_def=*/{{"out", "out:output:0"}}); + *flib_def.add_function() = std::move(func); + return flib_def; +} + +Node* MakeCall(Graph* graph, const string& callee_name, const string& node_name, + Status* status) { + NodeDef call_node; + call_node.set_name(node_name); + call_node.set_op(callee_name); + return graph->AddNode(call_node, status); +} + +TEST(ResourceOperationSafetyAnalysisTest, CallRead) { + Scope root = Scope::NewRootScope().ExitOnError(); + + FunctionDefLibrary flib_def = + CreateFunctionDefLibWithConstFunction("Const_func"); + 
TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def)); + + Node* read = MakeRead(root, "R"); + Status status; + Node* call = MakeCall(root.graph(), "Const_func", "C", &status); + TF_ASSERT_OK(status); + + root.graph()->AddControlEdge(call, read); + + std::vector<std::pair<int, int>> incompatible_pairs; + TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); + + ASSERT_EQ(incompatible_pairs.size(), 1); + std::pair<int, int> call_read_edge = {call->id(), read->id()}; + EXPECT_EQ(incompatible_pairs[0], call_read_edge); +} + +TEST(ResourceOperationSafetyAnalysisTest, ReadCall) { + Scope root = Scope::NewRootScope().ExitOnError(); + + FunctionDefLibrary flib_def = + CreateFunctionDefLibWithConstFunction("Const_func"); + TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def)); + + Node* read = MakeRead(root, "R"); + Status status; + Node* call = MakeCall(root.graph(), "Const_func", "C", &status); + TF_ASSERT_OK(status); + + root.graph()->AddControlEdge(read, call); + + std::vector<std::pair<int, int>> incompatible_pairs; + TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); + + ASSERT_EQ(incompatible_pairs.size(), 1); + std::pair<int, int> read_call_edge = {read->id(), call->id()}; + EXPECT_EQ(incompatible_pairs[0], read_call_edge); +} + +TEST(ResourceOperationSafetyAnalysisTest, CallWrite) { + Scope root = Scope::NewRootScope().ExitOnError(); + + FunctionDefLibrary flib_def = + CreateFunctionDefLibWithConstFunction("Const_func"); + TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def)); + + Node* write = MakeWrite(root, "W"); + Status status; + Node* call = MakeCall(root.graph(), "Const_func", "C", &status); + TF_ASSERT_OK(status); + + root.graph()->AddControlEdge(call, write); + + std::vector<std::pair<int, int>> incompatible_pairs; + TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); + + ASSERT_EQ(incompatible_pairs.size(), 1); + std::pair<int, int> call_write_edge = {call->id(), write->id()}; + 
EXPECT_EQ(incompatible_pairs[0], call_write_edge); +} + +TEST(ResourceOperationSafetyAnalysisTest, WriteCall) { + Scope root = Scope::NewRootScope().ExitOnError(); + + FunctionDefLibrary flib_def = + CreateFunctionDefLibWithConstFunction("Const_func"); + TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def)); + + Node* write = MakeWrite(root, "W"); + Status status; + Node* call = MakeCall(root.graph(), "Const_func", "C", &status); + TF_ASSERT_OK(status); + + root.graph()->AddControlEdge(write, call); + + std::vector<std::pair<int, int>> incompatible_pairs; + TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); + + ASSERT_EQ(incompatible_pairs.size(), 1); + std::pair<int, int> write_call_edge = {write->id(), call->id()}; + EXPECT_EQ(incompatible_pairs[0], write_call_edge); +} + +TEST(ResourceOperationSafetyAnalysisTest, SymbolicGradientRead) { + Scope root = Scope::NewRootScope().ExitOnError(); + + FunctionDefLibrary flib_def = + CreateFunctionDefLibWithConstFunction("Const_func"); + TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def)); + + Node* read = MakeRead(root, "R"); + NameAttrList fn; + fn.set_name("Const_func"); + Node* symbolic_gradient = + ops::SymbolicGradient(root, /*input=*/{ops::Const(root, 1.0f)}, + /*Tout=*/{DT_FLOAT}, fn) + .output[0] + .node(); + + root.graph()->AddControlEdge(symbolic_gradient, read); + + std::vector<std::pair<int, int>> incompatible_pairs; + TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); + + ASSERT_EQ(incompatible_pairs.size(), 1); + std::pair<int, int> symbolic_gradient_read_edge = {symbolic_gradient->id(), + read->id()}; + EXPECT_EQ(incompatible_pairs[0], symbolic_gradient_read_edge); +} + +TEST(ResourceOperationSafetyAnalysisTest, WriteSymbolicGradient) { + Scope root = Scope::NewRootScope().ExitOnError(); + + FunctionDefLibrary flib_def = + CreateFunctionDefLibWithConstFunction("Const_func"); + TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def)); + + Node* write = 
MakeWrite(root, "W"); + NameAttrList fn; + fn.set_name("Const_func"); + Node* symbolic_gradient = + ops::SymbolicGradient(root, /*input=*/{ops::Const(root, 1.0f)}, + /*Tout=*/{DT_FLOAT}, fn) + .output[0] + .node(); + + root.graph()->AddControlEdge(write, symbolic_gradient); + + std::vector<std::pair<int, int>> incompatible_pairs; + TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); + + ASSERT_EQ(incompatible_pairs.size(), 1); + std::pair<int, int> write_symbolic_gradient_edge = {write->id(), + symbolic_gradient->id()}; + EXPECT_EQ(incompatible_pairs[0], write_symbolic_gradient_edge); +} + +TEST(ResourceOperationSafetyAnalysisTest, ChainOfOps) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Node* write_0 = MakeWrite(root, "W0"); + Node* neutral_0 = MakeNeutral(root, "N0"); + Node* read_0 = MakeRead(root, "R0"); + Node* write_1 = MakeWrite(root, "W1"); + Node* neutral_1 = MakeNeutral(root, "N1"); + Node* read_1 = MakeRead(root, "R1"); + + root.graph()->AddControlEdge(write_0, neutral_0); + root.graph()->AddControlEdge(neutral_0, read_0); + root.graph()->AddControlEdge(read_0, write_1); + root.graph()->AddControlEdge(write_1, neutral_1); + root.graph()->AddControlEdge(neutral_1, read_1); + + std::vector<std::pair<int, int>> incompatible_pairs; + TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); + + ASSERT_EQ(incompatible_pairs.size(), 5); + std::pair<int, int> write_0_read_0_pair = {write_0->id(), read_0->id()}; + std::pair<int, int> write_0_read_1_pair = {write_0->id(), read_1->id()}; + std::pair<int, int> write_1_read_1_pair = {write_1->id(), read_1->id()}; + std::pair<int, int> write_0_write_1_pair = {write_0->id(), write_1->id()}; + std::pair<int, int> read_0_read_1_pair = {read_0->id(), read_1->id()}; + + EXPECT_EQ(incompatible_pairs[0], write_0_read_0_pair); + EXPECT_EQ(incompatible_pairs[1], write_0_write_1_pair); + EXPECT_EQ(incompatible_pairs[2], write_0_read_1_pair); + 
EXPECT_EQ(incompatible_pairs[3], read_0_read_1_pair); + EXPECT_EQ(incompatible_pairs[4], write_1_read_1_pair); +} + +TEST(ResourceOperationSafetyAnalysisTest, DagOfOps) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Node* write_0 = MakeWrite(root, "W0"); + Node* write_1 = MakeWrite(root, "W1"); + Node* neutral = MakeNeutral(root, "N"); + Node* read_0 = MakeRead(root, "R0"); + Node* read_1 = MakeRead(root, "R1"); + + root.graph()->AddControlEdge(write_0, neutral); + root.graph()->AddControlEdge(write_1, neutral); + root.graph()->AddControlEdge(neutral, read_0); + root.graph()->AddControlEdge(neutral, read_1); + + std::vector<std::pair<int, int>> incompatible_pairs; + TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); + + ASSERT_EQ(incompatible_pairs.size(), 4); + std::pair<int, int> write_0_read_0_pair = {write_0->id(), read_0->id()}; + std::pair<int, int> write_0_read_1_pair = {write_0->id(), read_1->id()}; + std::pair<int, int> write_1_read_0_pair = {write_1->id(), read_0->id()}; + std::pair<int, int> write_1_read_1_pair = {write_1->id(), read_1->id()}; + + EXPECT_EQ(incompatible_pairs[0], write_0_read_0_pair); + EXPECT_EQ(incompatible_pairs[1], write_0_read_1_pair); + EXPECT_EQ(incompatible_pairs[2], write_1_read_0_pair); + EXPECT_EQ(incompatible_pairs[3], write_1_read_1_pair); +} + +TEST(ResourceOperationSafetyAnalysisTest, DagOfOpsWithRepeatedPaths) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Node* write_0 = MakeWrite(root, "W0"); + Node* write_1 = MakeWrite(root, "W1"); + Node* neutral = MakeNeutral(root, "N"); + Node* read_0 = MakeRead(root, "R0"); + Node* read_1 = MakeRead(root, "R1"); + + root.graph()->AddControlEdge(write_0, neutral); + root.graph()->AddControlEdge(write_1, neutral); + root.graph()->AddControlEdge(neutral, read_0); + root.graph()->AddControlEdge(neutral, read_1); + root.graph()->AddControlEdge(write_1, read_1); + + std::vector<std::pair<int, int>> incompatible_pairs; + 
TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); + + ASSERT_EQ(incompatible_pairs.size(), 4); + std::pair<int, int> write_0_read_0_pair = {write_0->id(), read_0->id()}; + std::pair<int, int> write_0_read_1_pair = {write_0->id(), read_1->id()}; + std::pair<int, int> write_1_read_0_pair = {write_1->id(), read_0->id()}; + std::pair<int, int> write_1_read_1_pair = {write_1->id(), read_1->id()}; + + EXPECT_EQ(incompatible_pairs[0], write_0_read_0_pair); + EXPECT_EQ(incompatible_pairs[1], write_0_read_1_pair); + EXPECT_EQ(incompatible_pairs[2], write_1_read_0_pair); + EXPECT_EQ(incompatible_pairs[3], write_1_read_1_pair); +} + +TEST(ResourceOperationSafetyAnalysisTest, Loop) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Output init_value = ops::Placeholder(root.WithOpName("init"), DT_FLOAT); + Output loop_cond = ops::Placeholder(root.WithOpName("init"), DT_BOOL); + Output enter_value = + ops::internal::Enter(root.WithOpName("enter"), init_value, "fr"); + ops::Merge iv(root.WithOpName("iv"), {enter_value, enter_value}); + ops::Switch latch(root.WithOpName("latch"), iv.output, loop_cond); + ops::internal::Exit exit(root.WithOpName("exit"), iv.output); + Output next_iteration = + ops::NextIteration(root.WithOpName("next_iteration"), latch.output_true); + TF_ASSERT_OK( + root.graph()->UpdateEdge(next_iteration.node(), 0, iv.output.node(), 1)); + + Node* write = MakeWrite(root, "W"); + Node* read = MakeRead(root, "R"); + + root.graph()->AddControlEdge(iv.output.node(), write); + root.graph()->AddControlEdge(write, read); + root.graph()->AddControlEdge(read, next_iteration.node()); + + std::vector<std::pair<int, int>> incompatible_pairs; + TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); + + ASSERT_EQ(incompatible_pairs.size(), 1); + + std::pair<int, int> write_read_pair = {write->id(), read->id()}; + EXPECT_EQ(incompatible_pairs[0], write_read_pair); +} + +bool IsResourceArgDef(const OpDef::ArgDef& arg_def) { + 
return arg_def.type() == DT_RESOURCE; +} +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_cluster_util.cc b/tensorflow/compiler/jit/xla_cluster_util.cc index 38adacd93b..4f2fabd658 100644 --- a/tensorflow/compiler/jit/xla_cluster_util.cc +++ b/tensorflow/compiler/jit/xla_cluster_util.cc @@ -17,6 +17,7 @@ limitations under the License. #include <unordered_map> +#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/graph/control_flow.h" #include "tensorflow/core/kernels/bounds_check.h" @@ -207,4 +208,27 @@ bool HasResourceInputOrOutput(const Node& node) { void RemoveFromXlaCluster(NodeDef* node_def) { node_def->mutable_attr()->erase(kXlaClusterAttr); } + +Status AdjustCycleDetectionGraphForResourceOps( + const Graph* graph, const FunctionLibraryDefinition* flib_def, + const std::function<Status(const Node&, bool*)>& resource_ops_to_ignore, + GraphCycles* cycles) { + std::vector<std::pair<int, int>> unsafe_deps; + TF_RETURN_IF_ERROR(ComputeIncompatibleResourceOperationPairs( + *graph, flib_def, resource_ops_to_ignore, &unsafe_deps)); + + // An edge {P,Q} in `unsafe_deps` denotes that P and Q, both of which are + // operations that interact with resource variables, must not be put in the + // same cluster. We enforce this constraint by creating a phantom node, X, + // and adding edges P->X and X->Q. MarkForCompilation then cannot cluster P + // and Q together since that would create a cycle with X. 
+ + for (std::pair<int, int> unsafe_dep : unsafe_deps) { + int phantom_node_id = cycles->NewNode(); + CHECK(cycles->InsertEdge(unsafe_dep.first, phantom_node_id)); + CHECK(cycles->InsertEdge(phantom_node_id, unsafe_dep.second)); + } + return Status::OK(); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_cluster_util.h b/tensorflow/compiler/jit/xla_cluster_util.h index 662a53d89e..b0439a63ca 100644 --- a/tensorflow/compiler/jit/xla_cluster_util.h +++ b/tensorflow/compiler/jit/xla_cluster_util.h @@ -55,6 +55,13 @@ void RemoveFromXlaCluster(NodeDef* node_def); // Returns true if `node` has a DT_RESOURCE typed input or output. bool HasResourceInputOrOutput(const Node& node); +// Adds edges to `cycles` to prevent clustering resource operations that cannot +// be legally clustered. +Status AdjustCycleDetectionGraphForResourceOps( + const Graph* graph, const FunctionLibraryDefinition* flib_def, + const std::function<Status(const Node&, bool*)>& resource_ops_to_ignore, + GraphCycles* cycles); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_ diff --git a/tensorflow/compiler/jit/xla_cluster_util_test.cc b/tensorflow/compiler/jit/xla_cluster_util_test.cc index 2cb351e1ec..65bbf3efe8 100644 --- a/tensorflow/compiler/jit/xla_cluster_util_test.cc +++ b/tensorflow/compiler/jit/xla_cluster_util_test.cc @@ -25,7 +25,6 @@ limitations under the License. 
#include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/graph/testlib.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index 7140d47a94..ef6b0e67d3 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -230,7 +230,7 @@ Status XlaCompilationCache::Compile( const std::map<int, OptionalTensor>& variable_args, OpKernelContext* ctx, const XlaCompiler::CompilationResult** compilation_result, xla::LocalExecutable** executable, - const XlaCompiler::CompileOptions* compile_options) { + const XlaCompiler::CompileOptions& compile_options) { return CompileImpl(options, function, constant_args, variable_args, ctx, compilation_result, executable, compile_options, false); } @@ -241,7 +241,7 @@ Status XlaCompilationCache::CompileSingleOp( const std::map<int, OptionalTensor>& variable_args, OpKernelContext* ctx, const XlaCompiler::CompilationResult** compilation_result, xla::LocalExecutable** executable, - const XlaCompiler::CompileOptions* compile_options) { + const XlaCompiler::CompileOptions& compile_options) { const NodeDef& def = ctx->op_kernel().def(); NameAttrList name; name.set_name(def.op()); @@ -256,7 +256,7 @@ Status XlaCompilationCache::CompileImpl( const std::map<int, OptionalTensor>& variable_args, OpKernelContext* ctx, const XlaCompiler::CompilationResult** compilation_result, xla::LocalExecutable** executable, - const XlaCompiler::CompileOptions* compile_options, + const XlaCompiler::CompileOptions& compile_options, bool compile_single_op) { CHECK_NE(executable, nullptr); VLOG(1) << "XlaCompilationCache::Compile " << DebugString(); @@ -324,13 +324,12 @@ Status XlaCompilationCache::CompileImpl( entry->compiled = true; if (compile_single_op) { - 
entry->compilation_status = compiler.CompileSingleOp( - compile_options ? *compile_options : XlaCompiler::CompileOptions(), - signature.name, ctx, args, &entry->compilation_result); + entry->compilation_status = + compiler.CompileSingleOp(compile_options, signature.name, ctx, args, + &entry->compilation_result); } else { entry->compilation_status = compiler.CompileFunction( - compile_options ? *compile_options : XlaCompiler::CompileOptions(), - function, args, &entry->compilation_result); + compile_options, function, args, &entry->compilation_result); } TF_RETURN_IF_ERROR(entry->compilation_status); CHECK_EQ(entry->executable.get(), nullptr); diff --git a/tensorflow/compiler/jit/xla_compilation_cache.h b/tensorflow/compiler/jit/xla_compilation_cache.h index fc5f008f4f..10ad87e38c 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.h +++ b/tensorflow/compiler/jit/xla_compilation_cache.h @@ -70,7 +70,7 @@ class XlaCompilationCache : public ResourceBase { OpKernelContext* ctx, const XlaCompiler::CompilationResult** compilation_result, xla::LocalExecutable** executable, - const XlaCompiler::CompileOptions* compile_options); + const XlaCompiler::CompileOptions& compile_options); // As above, but calls XlaCompiler::CompileSingleOp instead of // XlaCompiler::CompileFunction. 
@@ -80,7 +80,7 @@ class XlaCompilationCache : public ResourceBase { const std::map<int, OptionalTensor>& variable_args, OpKernelContext* ctx, const XlaCompiler::CompilationResult** compilation_result, xla::LocalExecutable** executable, - const XlaCompiler::CompileOptions* compile_options); + const XlaCompiler::CompileOptions& compile_options); xla::LocalClient* client() const { return client_; } const DeviceType& device_type() const { return device_type_; } @@ -96,7 +96,7 @@ class XlaCompilationCache : public ResourceBase { OpKernelContext* ctx, const XlaCompiler::CompilationResult** compilation_result, xla::LocalExecutable** executable, - const XlaCompiler::CompileOptions* compile_options, + const XlaCompiler::CompileOptions& compile_options, bool compile_single_op); // Takes `result` which has been compiled from a Tensorflow subgraph to a diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc index dd84fb34c1..3ba48e8c31 100644 --- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc +++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc @@ -177,7 +177,7 @@ Status XlaCompileOnDemandOp::Compile( std::map<int, OptionalTensor> variable_args = GetVariables(ctx); return cache->CompileSingleOp(options, constant_arguments, variable_args, ctx, - result, executable, &compile_options); + result, executable, compile_options); } void XlaCompileOnDemandOp::Compute(OpKernelContext* ctx) { diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc index 2027ec7737..ee07c5c964 100644 --- a/tensorflow/compiler/jit/xla_device_context.cc +++ b/tensorflow/compiler/jit/xla_device_context.cc @@ -184,18 +184,6 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor, return; } status = TransferLiteralToDevice(reshaped_cpu_tensor, device_tensor); - if (status.ok()) { - xla_tensor->set_host_tensor(*cpu_tensor); - 
host_to_device_stream_->ThenDoHostCallback([this, done]() { - // We must not call the done closure directly from DoHostCallback - // to avoid a deadlock. If done() is the callback that ends an - // Executor's run, the Executor may call XlaDevice::Sync() inside the - // callback. This deadlocks, because XlaDevice::Sync() waits for all - // stream activity to complete. - thread_pool_->Schedule([done]() { done(Status::OK()); }); - }); - return; - } } else { se::DeviceMemoryBase dev_dst_ptr = XlaTensor::DeviceMemoryFromTensor(*device_tensor); @@ -208,8 +196,9 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor, host_to_device_stream_.get(), block_status.error_message().c_str()); } } - xla_tensor->set_host_tensor(*cpu_tensor); - + if (status.ok()) { + xla_tensor->set_host_tensor(*cpu_tensor); + } done(status); } diff --git a/tensorflow/compiler/jit/xla_fusion_optimizer.cc b/tensorflow/compiler/jit/xla_fusion_optimizer.cc index 4b499b1613..915c5afa79 100644 --- a/tensorflow/compiler/jit/xla_fusion_optimizer.cc +++ b/tensorflow/compiler/jit/xla_fusion_optimizer.cc @@ -208,6 +208,8 @@ Status XlaFusionOptimizer::Optimize(grappler::Cluster* cluster, GraphCycles cycles; TF_RETURN_IF_ERROR(CreateCycleDetectionGraph(&graph, &cycles)); + TF_RETURN_IF_ERROR(AdjustCycleDetectionGraphForResourceOps( + &graph, &graph.flib_def(), /*resource_ops_to_ignore=*/{}, &cycles)); // TODO(hpucha): Make clustering more robust. There are two known issues that // we need to mitigate: (a) Non-resource variables can cause deadlocks diff --git a/tensorflow/compiler/jit/xla_fusion_optimizer_test.cc b/tensorflow/compiler/jit/xla_fusion_optimizer_test.cc index 5736760a87..b77b207908 100644 --- a/tensorflow/compiler/jit/xla_fusion_optimizer_test.cc +++ b/tensorflow/compiler/jit/xla_fusion_optimizer_test.cc @@ -14,6 +14,8 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/jit/xla_fusion_optimizer.h" +#include "tensorflow/cc/ops/resource_variable_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/core/graph/graph_def_builder.h" @@ -179,5 +181,28 @@ TEST_F(XlaFusionOptimizerTest, CompilableCycles) { EXPECT_EQ(clusters["A"], clusters["C"]); } +TEST_F(XlaFusionOptimizerTest, ResourcesClusteringDisallowed) { + Scope root = Scope::NewRootScope().ExitOnError(); + Output var_handle = + ops::VarHandleOp(root.WithOpName("Var"), DT_FLOAT, TensorShape({})); + Output to_assign = ops::Const(root.WithOpName("Const"), 10.0f); + Output begin = ops::Const(root.WithOpName("begin"), 0); + Output end = ops::Const(root.WithOpName("end"), 1); + Output strides = ops::Const(root.WithOpName("strides"), 1); + ops::ResourceStridedSliceAssign assign_1( + root.WithOpName("assign_1"), var_handle, begin, end, strides, to_assign); + ops::ResourceStridedSliceAssign assign_2( + root.WithOpName("assign_2"), var_handle, begin, end, strides, to_assign); + root.graph()->AddControlEdge(assign_1.operation.node(), + assign_2.operation.node()); + grappler::GrapplerItem item; + root.graph()->ToGraphDef(&item.graph); + + XlaFusionOptimizer optimizer; + GraphDef output; + TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); + auto clusters = GetClusters(output); + EXPECT_NE(clusters["assign_1"], clusters["assign_2"]); +} } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 235bef07b3..94e08b6efe 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -1191,3 +1191,19 @@ tf_xla_py_test( "//tensorflow/python:platform_test", ], ) + +tf_xla_py_test( + name = "xla_ops_test", + size = "small", + srcs = ["xla_ops_test.py"], + disabled_backends = 
["cpu_ondemand"], + deps = [ + ":xla_test", + "//tensorflow/compiler/tf2xla/python:xla", + "//tensorflow/python:array_ops", + "//tensorflow/python:errors", + "//tensorflow/python:framework", + "//tensorflow/python:platform_test", + "@absl_py//absl/testing:parameterized", + ], +) diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index 4a281c37e4..ed4940f204 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -1372,5 +1372,40 @@ class BinaryOpsTest(xla_test.XLATestCase): [[-4.0, 0.0, 4.0], [0.0, -5.0, 0.0]]], dtype=dtype)) + def testBroadcastTo(self): + for dtype in self.all_types: + x = np.random.randint(0, high=100, size=[2, 3]) + self._testBinary( + array_ops.broadcast_to, + x, + np.array([2, 3], dtype=np.int32), + expected=x) + self._testBinary( + array_ops.broadcast_to, + x, + np.array([6, 6], dtype=np.int32), + expected=np.tile(x, [3, 2])) + self._testBinary( + array_ops.broadcast_to, + x, + np.array([7, 4, 3], dtype=np.int32), + expected=np.tile(x, [7, 2, 1])) + self._testBinary( + array_ops.broadcast_to, + x, + np.array([7, 0, 3], dtype=np.int32), + expected=np.zeros([7, 0, 3], dtype=dtype)) + self._testBinary( + array_ops.broadcast_to, + x, + np.array([7, 1, 2, 9], dtype=np.int32), + expected=np.tile(x, [7, 1, 1, 3])) + self._testBinary( + array_ops.broadcast_to, + np.zeros([2, 0], dtype=dtype), + np.array([4, 0], dtype=np.int32), + expected=np.zeros([4, 0], dtype=dtype)) + + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/compiler/tests/image_ops_test.py b/tensorflow/compiler/tests/image_ops_test.py index 1a82fcbb2a..6fe5a66e0e 100644 --- a/tensorflow/compiler/tests/image_ops_test.py +++ b/tensorflow/compiler/tests/image_ops_test.py @@ -410,13 +410,14 @@ class ResizeBilinearTest(xla_test.XLATestCase): image_np, target_shape, expected=None, - large_tolerance=False): + large_tolerance=False, + align_corners=True): if expected 
is None: self.fail("expected must be specified") with self.cached_session() as sess, self.test_scope(): image = array_ops.placeholder(image_np.dtype) resized = gen_image_ops.resize_bilinear( - image, target_shape, align_corners=True) + image, target_shape, align_corners=align_corners) out = sess.run(resized, {image: image_np[np.newaxis, :, :, np.newaxis]}) if large_tolerance: self.assertAllClose( @@ -579,6 +580,27 @@ class ResizeBilinearTest(xla_test.XLATestCase): dtype=np.float32)), large_tolerance=True) + def testNonAlignCorners3x2To6x4(self): + input_data = [[64, 32], [32, 64], [50, 100]] + expected_data = [[64.0, 48.0, 32.0, 32.0], [48.0, 48.0, 48.0, 48.0], + [32.0, 48.0, 64.0, 64.0], [41.0, 61.5, 82.0, 82.0], + [50.0, 75.0, 100.0, 100.0], [50.0, 75.0, 100.0, 100.0]] + for dtype in self.float_types: + self._assertForwardOpMatchesExpected( + np.array(input_data, dtype=dtype), [6, 4], + expected=np.array(expected_data, dtype=np.float32), + align_corners=False) + + def testNonAlignCorners6x4To3x2(self): + input_data = [[127, 127, 64, 64], [127, 127, 64, 64], [64, 64, 127, 127], + [64, 64, 127, 127], [50, 50, 100, 100], [50, 50, 100, 100]] + expected_data = [[127, 64], [64, 127], [50, 100]] + for dtype in self.float_types: + self._assertForwardOpMatchesExpected( + np.array(input_data, dtype=dtype), [3, 2], + expected=np.array(expected_data, dtype=dtype), + align_corners=False) + class NonMaxSuppressionTest(xla_test.XLATestCase): diff --git a/tensorflow/compiler/tests/xla_ops_test.py b/tensorflow/compiler/tests/xla_ops_test.py new file mode 100644 index 0000000000..b2f026df6c --- /dev/null +++ b/tensorflow/compiler/tests/xla_ops_test.py @@ -0,0 +1,301 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for XLA op wrappers.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import numpy as np + +from tensorflow.compiler.tests import xla_test +from tensorflow.compiler.tf2xla.python import xla +from tensorflow.compiler.xla import xla_data_pb2 +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import function +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import googletest + + +class XlaOpsTest(xla_test.XLATestCase, parameterized.TestCase): + + def _assertOpOutputMatchesExpected(self, op, args, expected, + equality_fn=None): + with self.test_session() as session: + with self.test_scope(): + placeholders = [ + array_ops.placeholder(dtypes.as_dtype(arg.dtype), arg.shape) + for arg in args + ] + feeds = {placeholders[i]: args[i] for i in range(0, len(args))} + output = op(*placeholders) + result = session.run(output, feeds) + if not equality_fn: + equality_fn = self.assertAllClose + equality_fn(result, expected, rtol=1e-3) + + def testAdd(self): + for dtype in self.numeric_types: + self._assertOpOutputMatchesExpected( + xla.add, + args=(np.array([1, 2, 3], dtype=dtype), + np.array([4, 5, 6], dtype=dtype)), + expected=np.array([5, 7, 9], dtype=dtype)) + + self._assertOpOutputMatchesExpected( + lambda x, y: xla.add(x, y, broadcast_dims=(0,)), + args=(np.array([[1, 2], [3, 4]], dtype=dtype), + np.array([7, 
11], dtype=dtype)), + expected=np.array([[8, 9], [14, 15]], dtype=dtype)) + + self._assertOpOutputMatchesExpected( + lambda x, y: xla.add(x, y, broadcast_dims=(1,)), + args=(np.array([[1, 2], [3, 4]], dtype=dtype), + np.array([7, 11], dtype=dtype)), + expected=np.array([[8, 13], [10, 15]], dtype=dtype)) + + def testBroadcast(self): + for dtype in self.numeric_types: + v = np.arange(4, dtype=np.int32).astype(dtype).reshape([2, 2]) + self._assertOpOutputMatchesExpected( + lambda x: xla.broadcast(x, (7, 42)), + args=(v,), + expected=np.tile(v, (7, 42, 1, 1))) + + def testShiftRightLogical(self): + self._assertOpOutputMatchesExpected( + xla.shift_right_logical, + args=(np.array([-1, 16], dtype=np.int32), np.int32(4)), + expected=np.array([0x0FFFFFFF, 1], dtype=np.int32)) + + self._assertOpOutputMatchesExpected( + xla.shift_right_logical, + args=(np.array([0xFFFFFFFF, 16], dtype=np.uint32), np.uint32(4)), + expected=np.array([0x0FFFFFFF, 1], dtype=np.uint32)) + + def testShiftRightArithmetic(self): + self._assertOpOutputMatchesExpected( + xla.shift_right_arithmetic, + args=(np.array([-1, 16], dtype=np.int32), np.int32(4)), + expected=np.array([-1, 1], dtype=np.int32)) + + self._assertOpOutputMatchesExpected( + xla.shift_right_arithmetic, + args=(np.array([0xFFFFFFFF, 16], dtype=np.uint32), np.uint32(4)), + expected=np.array([0xFFFFFFFF, 1], dtype=np.uint32)) + + PRECISION_VALUES = (None, xla_data_pb2.PrecisionConfigProto.DEFAULT, + xla_data_pb2.PrecisionConfigProto.HIGH, + xla_data_pb2.PrecisionConfigProto.HIGHEST) + + @parameterized.parameters(*PRECISION_VALUES) + def testConv(self, precision): + for dtype in set(self.float_types).intersection( + set([dtypes.bfloat16.as_numpy_dtype, np.float32])): + + def conv_1d_fn(lhs, rhs): + dnums = xla_data_pb2.ConvolutionDimensionNumbers() + num_spatial_dims = 1 + dnums.input_batch_dimension = 0 + dnums.input_feature_dimension = 1 + dnums.output_batch_dimension = 0 + dnums.output_feature_dimension = 1 + 
dnums.kernel_output_feature_dimension = 0 + dnums.kernel_input_feature_dimension = 1 + dnums.input_spatial_dimensions.extend(range(2, 2 + num_spatial_dims)) + dnums.kernel_spatial_dimensions.extend(range(2, 2 + num_spatial_dims)) + dnums.output_spatial_dimensions.extend(range(2, 2 + num_spatial_dims)) + precision_config = None + if precision: + precision_config = xla_data_pb2.PrecisionConfigProto() + precision_config.operand_precision.extend([precision, precision]) + return xla.conv( + lhs, + rhs, + window_strides=(1,), + padding=((2, 1),), + lhs_dilation=(1,), + rhs_dilation=(2,), + dimension_numbers=dnums) + + self._assertOpOutputMatchesExpected( + conv_1d_fn, + args=( + np.array([[[3, 4, 5, 6]]], dtype=dtype), + np.array([[[-2, -3]]], dtype=dtype), + ), + expected=np.array([[[-9, -12, -21, -26, -10]]], dtype=dtype)) + + @parameterized.parameters(*PRECISION_VALUES) + def testDotGeneral(self, precision): + for dtype in self.float_types: + + def dot_fn(lhs, rhs): + dnums = xla_data_pb2.DotDimensionNumbers() + dnums.lhs_contracting_dimensions.append(2) + dnums.rhs_contracting_dimensions.append(1) + dnums.lhs_batch_dimensions.append(0) + dnums.rhs_batch_dimensions.append(0) + precision_config = None + if precision: + precision_config = xla_data_pb2.PrecisionConfigProto() + precision_config.operand_precision.extend([precision, precision]) + return xla.dot_general( + lhs, + rhs, + dimension_numbers=dnums, + precision_config=precision_config) + + lhs = np.array( + [ + [[1, 2], [3, 4]], + [[5, 6], [7, 8]], + ], dtype=dtype) + rhs = np.array( + [ + [[1, 2, 3], [4, 5, 6]], + [[7, 8, 9], [10, 11, 12]], + ], dtype=dtype) + self._assertOpOutputMatchesExpected( + dot_fn, + args=(lhs, rhs), + expected=np.array( + [ + [[9, 12, 15], [19, 26, 33]], + [[95, 106, 117], [129, 144, 159]], + ], + dtype=dtype)) + + def testNeg(self): + for dtype in self.numeric_types: + self._assertOpOutputMatchesExpected( + xla.neg, + args=(np.array([1, 2, 3], dtype=dtype),), + expected=np.array([-1, 
-2, -3], dtype=dtype)) + + def testPad(self): + for dtype in self.numeric_types: + + def pad_fn(x): + return xla.pad( + x, + padding_value=7, + padding_low=[2, 1], + padding_high=[1, 2], + padding_interior=[1, 0]) + + self._assertOpOutputMatchesExpected( + pad_fn, + args=(np.arange(4, dtype=np.int32).astype(dtype).reshape([2, 2]),), + expected=np.array( + [[7, 7, 7, 7, 7], [7, 7, 7, 7, 7], [7, 0, 1, 7, 7], + [7, 7, 7, 7, 7], [7, 2, 3, 7, 7], [7, 7, 7, 7, 7]], + dtype=dtype)) + + def testReduce(self): + for dtype in set(self.numeric_types).intersection( + set([dtypes.bfloat16.as_numpy_dtype, np.float32])): + + @function.Defun(dtype, dtype) + def sum_reducer(x, y): + return x + y + + def sum_reduction(dims): + + def fn(x): + return xla.reduce( + x, init_value=0, dimensions_to_reduce=dims, reducer=sum_reducer) + + return fn + + self._assertOpOutputMatchesExpected( + sum_reduction(dims=[]), + args=(np.arange(12, dtype=np.int32).astype(dtype).reshape([3, 4]),), + expected=np.arange(12, dtype=np.int32).astype(dtype).reshape([3, 4])) + self._assertOpOutputMatchesExpected( + sum_reduction(dims=[0]), + args=(np.arange(12, dtype=np.int32).astype(dtype).reshape([3, 4]),), + expected=np.array([12, 15, 18, 21], dtype=dtype)) + self._assertOpOutputMatchesExpected( + sum_reduction(dims=[1]), + args=(np.arange(12, dtype=np.int32).astype(dtype).reshape([3, 4]),), + expected=np.array([6, 22, 38], dtype=dtype)) + self._assertOpOutputMatchesExpected( + sum_reduction(dims=[0, 1]), + args=(np.arange(12, dtype=np.int32).astype(dtype).reshape([3, 4]),), + expected=dtype(66)) + + @function.Defun(dtype, dtype) + def mul_reducer(x, y): + return x * y + + def mul_reduction(dims): + + def fn(x): + return xla.reduce( + x, init_value=1, dimensions_to_reduce=dims, reducer=mul_reducer) + + return fn + + self._assertOpOutputMatchesExpected( + mul_reduction(dims=[0]), + args=(np.arange(12, dtype=np.int32).astype(dtype).reshape([3, 4]),), + expected=np.array([0, 45, 120, 231], dtype=dtype)) + + def 
testSelectAndScatter(self): + for dtype in set(self.numeric_types).intersection( + set([dtypes.bfloat16.as_numpy_dtype, np.float32])): + + @function.Defun(dtype, dtype) + def add_scatter(x, y): + return x + y + + @function.Defun(dtype, dtype) + def ge_select(x, y): + return x >= y + + def test_fn(operand, source): + return xla.select_and_scatter( + operand, + window_dimensions=[2, 3, 1, 1], + window_strides=[2, 2, 1, 1], + padding=[[0, 0]] * 4, + source=source, + init_value=0, + select=ge_select, + scatter=add_scatter) + + self._assertOpOutputMatchesExpected( + test_fn, + args=(np.array( + [[7, 2, 5, 3, 8], [3, 8, 9, 3, 4], [1, 5, 7, 5, 6], + [0, 6, 2, 10, 2]], + dtype=dtype).reshape((4, 5, 1, 1)), + np.array([[2, 6], [3, 1]], dtype=dtype).reshape((2, 2, 1, 1))), + expected=np.array( + [[0, 0, 0, 0, 0], [0, 0, 8, 0, 0], [0, 0, 3, 0, 0], + [0, 0, 0, 1, 0]], + dtype=dtype).reshape((4, 5, 1, 1))) + + def testTranspose(self): + for dtype in self.numeric_types: + v = np.arange(4, dtype=np.int32).astype(dtype).reshape([2, 2]) + self._assertOpOutputMatchesExpected( + lambda x: xla.transpose(x, [1, 0]), args=(v,), expected=v.T) + + +if __name__ == '__main__': + googletest.main() diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 85fd0c9217..92e577bb7b 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -39,6 +39,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:ops", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/strings", ], ) @@ -88,6 +89,7 @@ cc_library( "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/strings", ], ) @@ -221,13 +223,11 @@ cc_library( srcs = [ "literal_util.cc", "shape_util.cc", - "str_util.cc", "type_util.cc", ], hdrs = [ "literal_util.h", "shape_util.h", - "str_util.h", "type_util.h", ], visibility = [":friends"], @@ -256,6 +256,7 @@ cc_library( 
"//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/strings", ], ) @@ -307,6 +308,7 @@ tf_cc_test( "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/strings", ], ) @@ -374,19 +376,7 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", - ], -) - -tf_cc_test( - name = "str_util_test", - srcs = [ - "str_util_test.cc", - ], - deps = [ - ":common", - "//tensorflow/core:lib", - "//tensorflow/core:test", - "//tensorflow/core:test_main", + "@com_google_absl//absl/strings", ], ) @@ -459,6 +449,7 @@ cc_library( "//tensorflow/core:core_cpu_internal", "//tensorflow/core:graph", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/strings", ], ) @@ -482,6 +473,7 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:graph", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", ], ) @@ -609,3 +601,30 @@ tf_cc_test( "//tensorflow/core:test_main", ], ) + +cc_library( + name = "resource_operation_table", + srcs = ["resource_operation_table.cc"], + hdrs = ["resource_operation_table.h"], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:ops", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/algorithm:container", + ], +) + +tf_cc_test( + name = "resource_operation_table_test", + srcs = ["resource_operation_table_test.cc"], + deps = [ + ":resource_operation_table", + ":xla_compiler", + "//tensorflow/compiler/tf2xla/kernels:xla_ops", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", + ], +) diff --git a/tensorflow/compiler/tf2xla/const_analysis.cc b/tensorflow/compiler/tf2xla/const_analysis.cc index de1008803d..e8673d7790 100644 --- a/tensorflow/compiler/tf2xla/const_analysis.cc 
+++ b/tensorflow/compiler/tf2xla/const_analysis.cc @@ -23,11 +23,11 @@ limitations under the License. #include "tensorflow/core/graph/algorithm.h" namespace tensorflow { - // Backwards dataflow analysis that finds arguments to a graph that must be // compile-time constants. Status BackwardsConstAnalysis(const Graph& g, - std::vector<bool>* compile_time_const_args) { + std::vector<bool>* compile_time_const_args, + std::vector<bool>* compile_time_const_nodes) { // Operators that don't look at the data of their inputs, just the shapes. const std::unordered_set<string> metadata_ops = { "Rank", @@ -36,9 +36,16 @@ Status BackwardsConstAnalysis(const Graph& g, "Size", }; + std::vector<bool> compile_time_const_nodes_impl; + if (compile_time_const_nodes) { + CHECK_EQ(compile_time_const_nodes->size(), g.num_node_ids()); + } else { + compile_time_const_nodes_impl.resize(g.num_node_ids()); + compile_time_const_nodes = &compile_time_const_nodes_impl; + } + Status status; - std::unordered_set<const Node*> must_be_const; - auto visit = [&status, &metadata_ops, &must_be_const, + auto visit = [&status, &metadata_ops, compile_time_const_nodes, compile_time_const_args](Node* node) { if (!status.ok()) return; @@ -47,17 +54,19 @@ Status BackwardsConstAnalysis(const Graph& g, // If this node must be const, and it isn't a metadata op, then all of its // parents must be const. 
- if (must_be_const.find(node) != must_be_const.end()) { + if ((*compile_time_const_nodes)[node->id()]) { if (node->type_string() == "_Arg") { int index; status = GetNodeAttr(node->attrs(), "index", &index); if (!status.ok()) return; - compile_time_const_args->at(index) = true; + if (compile_time_const_args) { + (*compile_time_const_args)[index] = true; + } return; } for (const Edge* pred : node->in_edges()) { if (!pred->IsControlEdge()) { - must_be_const.insert(pred->src()); + (*compile_time_const_nodes)[pred->src()->id()] = true; } } return; @@ -80,7 +89,7 @@ Status BackwardsConstAnalysis(const Graph& g, for (Edge const* edge : node->in_edges()) { if (edge->dst_input() >= name_range->second.first && edge->dst_input() < name_range->second.second) { - must_be_const.insert(edge->src()); + (*compile_time_const_nodes)[edge->src()->id()] = true; } } } diff --git a/tensorflow/compiler/tf2xla/const_analysis.h b/tensorflow/compiler/tf2xla/const_analysis.h index 634b97d7e3..af57e5a403 100644 --- a/tensorflow/compiler/tf2xla/const_analysis.h +++ b/tensorflow/compiler/tf2xla/const_analysis.h @@ -23,10 +23,18 @@ limitations under the License. namespace tensorflow { -// Backwards dataflow analysis that finds arguments (_Arg nodes) to a graph that -// must be compile-time constants. +// Backwards dataflow analysis that finds nodes in a graph that must be +// compile-time constants for us to be able to lower the graph to XLA. +// +// The indices of the arguments to `graph` that must be constant are returned in +// `compile_time_const_arg_indices`, if `compile_time_const_arg_indices` is not +// null. +// +// The ids of the nodes in `graph` that must be constant are returned in +// `compile_time_const_nodes`, if `compile_time_const_nodes` is not null. 
Status BackwardsConstAnalysis(const Graph& graph, - std::vector<bool>* compile_time_const_args); + std::vector<bool>* compile_time_const_arg_indices, + std::vector<bool>* compile_time_const_nodes); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/const_analysis_test.cc b/tensorflow/compiler/tf2xla/const_analysis_test.cc index 992b12c06d..56065be894 100644 --- a/tensorflow/compiler/tf2xla/const_analysis_test.cc +++ b/tensorflow/compiler/tf2xla/const_analysis_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/function_ops.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" @@ -38,17 +39,23 @@ TEST(ConstAnalysisTest, Basics) { auto c = ops::Reshape(root, arg2, b); auto d = ops::Mul(root, c, ops::Sum(root, arg3, arg3)); - Graph graph(OpRegistry::Global()); - TF_ASSERT_OK(root.ToGraph(&graph)); + FixupSourceAndSinkEdges(root.graph()); std::vector<bool> const_args(4, false); - TF_ASSERT_OK(BackwardsConstAnalysis(graph, &const_args)); + std::vector<bool> const_nodes(root.graph()->num_node_ids(), false); + TF_ASSERT_OK( + BackwardsConstAnalysis(*root.graph(), &const_args, &const_nodes)); // Arg 0 doesn't need to be constant since the graph only uses its shape. // Arg 1 must be constant because it flows to the shape argument of a Reshape. // Arg 2 is used only as the value input to a Reshape and need not be const. // Arg 3 is used as the reduction-indices argument to Sum and must be const. 
EXPECT_EQ(const_args, std::vector<bool>({false, true, false, true})); + + EXPECT_FALSE(const_nodes[arg0.node()->id()]); + EXPECT_TRUE(const_nodes[arg1.node()->id()]); + EXPECT_FALSE(const_nodes[arg2.node()->id()]); + EXPECT_TRUE(const_nodes[arg3.node()->id()]); } // Regression test for a case where the backward const analysis did @@ -73,7 +80,8 @@ TEST(ConstAnalysisTest, TopologicalOrder) { TF_ASSERT_OK(root.ToGraph(&graph)); std::vector<bool> const_args(3, false); - TF_ASSERT_OK(BackwardsConstAnalysis(graph, &const_args)); + TF_ASSERT_OK(BackwardsConstAnalysis(graph, &const_args, + /*compile_time_const_nodes=*/nullptr)); EXPECT_EQ(const_args, std::vector<bool>({true, true, false})); } @@ -93,7 +101,8 @@ TEST(ConstAnalysisTest, DontFollowControlDependencies) { TF_ASSERT_OK(root.ToGraph(&graph)); std::vector<bool> const_args(2, false); - TF_ASSERT_OK(BackwardsConstAnalysis(graph, &const_args)); + TF_ASSERT_OK(BackwardsConstAnalysis(graph, &const_args, + /*compile_time_const_nodes=*/nullptr)); EXPECT_EQ(const_args, std::vector<bool>({false, true})); } diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.cc b/tensorflow/compiler/tf2xla/functionalize_cond.cc index f14cfca4ea..b5667ca0d3 100644 --- a/tensorflow/compiler/tf2xla/functionalize_cond.cc +++ b/tensorflow/compiler/tf2xla/functionalize_cond.cc @@ -22,6 +22,7 @@ limitations under the License. 
#include <vector> #include "absl/memory/memory.h" +#include "absl/strings/str_join.h" #include "absl/types/optional.h" #include "tensorflow/compiler/jit/union_find.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" @@ -52,11 +53,10 @@ string DebugString(CondStateMap::CondId cond_state) { if (cond_state == nullptr || cond_state->empty()) return "[]"; return strings::StrCat( "[", - tensorflow::str_util::Join( - *cond_state, ", ", - [](string* output, const CondStateMap::CondNode& node) { - strings::StrAppend(output, node.ToString()); - }), + absl::StrJoin(*cond_state, ", ", + [](string* output, const CondStateMap::CondNode& node) { + strings::StrAppend(output, node.ToString()); + }), "]"); } @@ -169,10 +169,10 @@ using CondArgNodes = std::vector<CondArgNode>; string DebugString(const CondArgNodes& nodes) { return strings::StrCat( "[", - tensorflow::str_util::Join(nodes, ", ", - [](string* output, const CondArgNode& node) { - strings::StrAppend(output, node.ToString()); - }), + absl::StrJoin(nodes, ", ", + [](string* output, const CondArgNode& node) { + strings::StrAppend(output, node.ToString()); + }), "]"); } @@ -387,8 +387,9 @@ Status Conditional::BuildArgumentNodes() { } if (!has_input) { return errors::Internal( - "Failed to functionalize control flow with merge '", m->name(), - "' that doesn't have input on ", Branch_Name(branch), " branch."); + "Failed to functionalize control flow with merge ", + FormatNodeForError(*m), " that doesn't have input on ", + Branch_Name(branch), " branch."); } } } @@ -469,8 +470,8 @@ Status Conditional::ExtractBodies(Graph* graph) { // but revisit to improve the testing to enable making this an // error. 
LOG(WARNING) << errors::InvalidArgument( - "Graph contains node ", src->name(), " that feeds into node ", - dst->name(), + "Graph contains node ", FormatNodeForError(*src), + " that feeds into node ", FormatNodeForError(*dst), " but these nodes are in different control contexts (", DebugString(src_id), " vs ", DebugString(dst_id), " (detected during out edge testing)"); @@ -512,8 +513,8 @@ Status Conditional::ExtractBodies(Graph* graph) { node_map.at(src->id()) = output->CopyNode(src); } else { return errors::InvalidArgument( - "Graph contains node ", src->name(), " that feeds into node ", - dst->name(), + "Graph contains node ", FormatNodeForError(*src), + " that feeds into node ", FormatNodeForError(*dst), " but these nodes are in different control contexts (", DebugString(src_id), " vs ", DebugString(dst_id), " (detected during in edge testing)"); @@ -675,7 +676,8 @@ Status Conditional::AddOutputEdges(Graph* graph) { int dst_input = edge->dst_input(); if (edge->src_output() > 0) { return errors::Unimplemented("Output of index (", edge->src_output(), - ") of merge node ", node->name()); + ") of merge node ", + FormatNodeForError(*node)); } bool control_edge = edge->IsControlEdge(); @@ -1060,7 +1062,8 @@ Status FunctionalizeCond::DetermineCondStateMerge(Node* dst) { CondStateMap::CondId prop = StateAlongEdge(e); auto id_or = JoinCondStatesMerge(prop, cond_state_map_.LookupId(dst)); - TF_RETURN_WITH_CONTEXT_IF_ERROR(id_or.status(), "for node ", dst->name()); + TF_RETURN_WITH_CONTEXT_IF_ERROR(id_or.status(), "for node ", + FormatNodeForError(*dst)); cond_state_map_.ResetId(dst, id_or.ValueOrDie()); } @@ -1090,7 +1093,8 @@ Status FunctionalizeCond::DetermineCondState(Node* dst) { // Joining the state between the current and propagated state. 
CondStateMap::CondId prop = StateAlongEdge(e); auto id_or = JoinCondStatesNonMerge(prop, cond_state_map_.LookupId(dst)); - TF_RETURN_WITH_CONTEXT_IF_ERROR(id_or.status(), "for node ", dst->name()); + TF_RETURN_WITH_CONTEXT_IF_ERROR(id_or.status(), "for node ", + FormatNodeForError(*dst)); cond_state_map_.ResetId(dst, id_or.ValueOrDie()); } } @@ -1117,7 +1121,7 @@ Status FunctionalizeCond::RemoveRedundantMerge(Node* node) { } if (non_dead_edge == nullptr) { - return errors::InvalidArgument("Merge node ", node->name(), + return errors::InvalidArgument("Merge node ", FormatNodeForError(*node), " has no non-dead inputs."); } cond_state_map_.MarkDead(node); @@ -1169,7 +1173,8 @@ Status FunctionalizeCond::RemoveRedundantSwitch(Node* node) { if (IsMerge(dst_node)) { auto id_or = JoinCondStatesMerge(dst_id, cond_state_map_.LookupId(dst_node)); - TF_RETURN_IF_ERROR(id_or.status()); + TF_RETURN_WITH_CONTEXT_IF_ERROR(id_or.status(), "for node ", + FormatNodeForError(*dst_node)); cond_state_map_.ResetId(dst_node, id_or.ValueOrDie()); } else { auto id_or = diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_util.h b/tensorflow/compiler/tf2xla/functionalize_control_flow_util.h index a0544b69e9..61940e3586 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow_util.h +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_util.h @@ -16,6 +16,7 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_UTIL_H_ #define TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_UTIL_H_ +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/graph/graph.h" @@ -43,11 +44,11 @@ xla::StatusOr<Node*> BuildRetvalNode(Graph* graph, DataType type, int index); template <typename T> string NodesToString(const T& nodes) { return strings::StrCat("{", - str_util::Join(nodes, ",", - [](string* output, const Node* node) { - strings::StrAppend(output, - node->name()); - }), + absl::StrJoin(nodes, ",", + [](string* output, const Node* node) { + strings::StrAppend(output, + node->name()); + }), "}"); } diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc index e4fdf0a618..ba37ed3337 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler.cc +++ b/tensorflow/compiler/tf2xla/graph_compiler.cc @@ -57,7 +57,8 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph, std::vector<bool> compile_time_constant_flags(expressions.size()); TF_RETURN_IF_ERROR( - BackwardsConstAnalysis(*graph, &compile_time_constant_flags)); + BackwardsConstAnalysis(*graph, &compile_time_constant_flags, + /*compile_time_const_nodes=*/nullptr)); args->resize(expressions.size()); for (int i = 0; i < args->size(); ++i) { diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index b1366e9e31..c1438f893f 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -22,6 +22,7 @@ tf_kernel_library( "bcast_ops.cc", "bias_ops.cc", "binary_ops.cc", + "broadcast_to_op.cc", "bucketize_op.cc", "cast_op.cc", "categorical_op.cc", @@ -100,6 +101,12 @@ tf_kernel_library( "unary_ops.cc", "unpack_op.cc", "variable_ops.cc", + "xla_broadcast_helper_op.cc", + "xla_conv_op.cc", + "xla_dot_op.cc", + "xla_pad_op.cc", + "xla_reduce_op.cc", + "xla_select_and_scatter_op.cc", ], hdrs = 
[ "index_ops.h", @@ -108,6 +115,8 @@ tf_kernel_library( deps = [ ":if_op", ":while_op", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/lib:batch_dot", diff --git a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc index ba3b1c9dab..2e383b1473 100644 --- a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc @@ -16,6 +16,7 @@ limitations under the License. // XLA-specific Ops for broadcasting used in gradient // code. +#include "absl/strings/str_join.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -51,8 +52,8 @@ class BCastArgsOp : public XlaOpKernel { BCast bcast(shapes[0], shapes[1]); OP_REQUIRES(ctx, bcast.IsValid(), errors::InvalidArgument( - "Incompatible shapes: [", str_util::Join(shapes[0], ","), - "] vs. [", str_util::Join(shapes[1], ","), "]")); + "Incompatible shapes: [", absl::StrJoin(shapes[0], ","), + "] vs. [", absl::StrJoin(shapes[1], ","), "]")); const int64 len = bcast.output_shape().size(); Tensor output(DT_INT32, TensorShape({len})); @@ -105,8 +106,8 @@ class BCastGradArgsOp : public XlaOpKernel { BCast bcast(shapes[0], shapes[1]); OP_REQUIRES(ctx, bcast.IsValid(), errors::InvalidArgument( - "Incompatible shapes: [", str_util::Join(shapes[0], ","), - "] vs. [", str_util::Join(shapes[1], ","), "]")); + "Incompatible shapes: [", absl::StrJoin(shapes[0], ","), + "] vs. 
[", absl::StrJoin(shapes[1], ","), "]")); Output(ctx, 0, bcast.grad_x_reduce_idx()); Output(ctx, 1, bcast.grad_y_reduce_idx()); } diff --git a/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc b/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc new file mode 100644 index 0000000000..4bd7c74dca --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc @@ -0,0 +1,101 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "absl/algorithm/container.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/bcast.h" + +namespace tensorflow { +namespace { + +class BroadcastToOp : public XlaOpKernel { + public: + explicit BroadcastToOp(OpKernelConstruction* context) + : XlaOpKernel(context) {} + + void Compile(XlaOpKernelContext* context) override { + const TensorShape input_shape = context->InputShape(0); + TensorShape output_shape; + OP_REQUIRES_OK(context, context->ConstantInputAsShape(1, &output_shape)); + + OP_REQUIRES(context, 
input_shape.dims() <= output_shape.dims(), + errors::InvalidArgument( + "Input rank (", input_shape.dims(), + ") must be less than or equal to the output rank (", + output_shape.dims(), ")")); + + auto input_dims = input_shape.dim_sizes(); + auto output_dims = output_shape.dim_sizes(); + + // Broadcasting is done right-to-left on right-aligned dimensions; reverse + // the two vectors so elements to be broadcast are aligned. + absl::c_reverse(input_dims); + absl::c_reverse(output_dims); + + std::vector<int64> broadcast_dims; + std::vector<int64> broadcast_shape; + for (int i = 0; i < output_shape.dims(); ++i) { + if (i < input_shape.dims()) { + OP_REQUIRES( + context, + (output_dims[i] == 0 && input_dims[i] == 0) || + (input_dims[i] != 0 && output_dims[i] % input_dims[i] == 0), + errors::InvalidArgument("invalid shape to broadcast from ", + input_shape.DebugString(), " to ", + output_shape.DebugString())); + + broadcast_dims.push_back(broadcast_shape.size()); + if (output_dims[i] == input_dims[i] || input_dims[i] == 1) { + broadcast_shape.push_back(output_dims[i]); + } + if (output_dims[i] != input_dims[i]) { + // Add dimensions [I, O/I], which we will later flatten to just + // [O]. We must do this in two phases since XLA broadcasting does not + // support tiling. 
+ broadcast_shape.push_back(input_dims[i]); + broadcast_shape.push_back(output_dims[i] / input_dims[i]); + } + } else { + broadcast_shape.push_back(output_dims[i]); + } + } + absl::c_reverse(broadcast_dims); + int broadcast_shape_size = broadcast_shape.size(); + for (int64& broadcast_dim : broadcast_dims) { + broadcast_dim = broadcast_shape_size - broadcast_dim - 1; + } + absl::c_reverse(broadcast_shape); + xla::XlaOp output = xla::Reshape( + xla::BroadcastInDim(context->Input(0), + xla::ShapeUtil::MakeShape( + context->input_xla_type(0), broadcast_shape), + broadcast_dims), + output_shape.dim_sizes()); + context->SetOutput(0, output); + } +}; + +REGISTER_XLA_OP(Name("BroadcastTo").CompileTimeConstInput("shape"), + BroadcastToOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc index 8d75624e74..8e071bf0b7 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc @@ -32,13 +32,13 @@ namespace { // // 1. S := (N - 1) / gcd(N-1, R-1) // 2. k := (R - 1) / gcd(N-1, R-1) -// 3. Convolution(kxk, stride=S, lhs_dilation=k, padding=k-1) +// 3. Convolution((2k-1)x(2k-1), stride=S, lhs_dilation=k, padding=k-1) // // For example, to Scale from 7x7 -> 15x15: // // 1. S := (7-1) / gcd(7-1, 15-1) = 6 / gcd(6, 14) = 6 / 2 = 3 // 2. k := (15 - 1) / gcd(7-1, 15-1) = 14 / gcd(6, 14) = 14 / 2 = 7 -// 3. Convolution(7x7, stride=3, lhs_dilation=3, padding=2) +// 3. Convolution(15x15, stride=3, lhs_dilation=7, padding=2) // // // The 7x7 -> 15x15 case is much too large to write out in full as an @@ -65,6 +65,8 @@ namespace { // 1/9 * 3 6 9 6 3 // 2 4 6 4 2 // 1 2 3 2 1 +// Note that the convolution kernel matrix is separable and thus we can instead +// use 2 consecutive 1D kernel of the dimension 2k-1, along each axis. 
// Computes the size of the convolutional kernel and stride to use when resizing // from in_size to out_size. @@ -76,7 +78,8 @@ struct ResizeConvolutionDims { std::vector<int64> stride; }; ResizeConvolutionDims ComputeResizeConvolutionParameters( - gtl::ArraySlice<int64> in_size, gtl::ArraySlice<int64> out_size) { + gtl::ArraySlice<int64> in_size, gtl::ArraySlice<int64> out_size, + bool align_corners) { CHECK_EQ(in_size.size(), out_size.size()); int num_spatial_dims = in_size.size(); ResizeConvolutionDims dims; @@ -92,15 +95,32 @@ ResizeConvolutionDims ComputeResizeConvolutionParameters( // entry before resizing. dims.stride[i] = dims.kernel_size[i] = 1; } else { - int64 gcd = MathUtil::GCD(static_cast<uint64>(in_size[i] - 1), - static_cast<uint64>(out_size[i] - 1)); - dims.stride[i] = (in_size[i] - 1) / gcd; - dims.kernel_size[i] = (out_size[i] - 1) / gcd; + // The scaling factor changes depending on the alignment of corners. + const int64 in_size_factor = align_corners ? in_size[i] - 1 : in_size[i]; + const int64 out_size_factor = + align_corners ? out_size[i] - 1 : out_size[i]; + + int64 gcd = MathUtil::GCD(static_cast<uint64>(in_size_factor), + static_cast<uint64>(out_size_factor)); + dims.stride[i] = in_size_factor / gcd; + dims.kernel_size[i] = out_size_factor / gcd; } } return dims; } +// The upper padding of the input needed by ConvGeneralDilated calls is +// determined by solving two related relationships (assuming rhs_dilation == 0): +// 1. dilated_input_dim = lower_padding + upper_padding +// + lhs_dilation * (in_size - 1) + 1 +// 2. 
dilated_input_dim = (2 * dims.kernel-size - 1) +// + dims.stride * (out_size - 1) +int64 CalculateUpperPadding(int64 in_size, int64 out_size, int64 kernel_size, + int64 stride) { + return (2 * kernel_size - 1) + (out_size - 1) * stride - (kernel_size - 1) - + 1 - (kernel_size * (in_size - 1)); +} + // Form a 2D convolution kernel like: // 1 2 3 2 1 // 2 4 6 4 2 @@ -171,7 +191,8 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder, const int num_spatial_dims, std::vector<int64> in_size, std::vector<int64> out_size, - const int64 channels) { + const int64 channels, + const bool align_corners) { // Picture for a 1x3 to 1x4 resize: // stride = 2, kernel size = 3 // Input: @@ -196,27 +217,82 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder, dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims); ResizeConvolutionDims dims = - ComputeResizeConvolutionParameters(in_size, out_size); + ComputeResizeConvolutionParameters(in_size, out_size, align_corners); xla::XlaOp output; - // Split convolutions into independent dimensions if they wmuld be a very + + // Concatenation and padding below currently assumes num_spatial_dims is 2 to + // prevent needless code complexity. + CHECK_EQ(num_spatial_dims, 2) + << "ResizeUsingDilationAndConvolution pads only 2 dimensions currently."; + std::vector<int64> upper_padding(num_spatial_dims); + for (int i = 0; i < num_spatial_dims; ++i) { + upper_padding[i] = dims.kernel_size[i] - 1; + } + xla::XlaOp input_data = input; + + if (!align_corners) { + // When Tensorflow does not align_corners, the resize indexing can access + // beyond the upper bound and is instead clamped to prevent out of bounds + // reads. This is conceptually the same as extending the edges of the input. + // We emulate this by copying the last row/column of the input. + // Calculate what padding would be needed then determine how far to extend + // the border before lhs dilation. 
+ std::vector<int64> num_extended(num_spatial_dims); + upper_padding[0] = CalculateUpperPadding( + in_size[0], out_size[0], dims.kernel_size[0], dims.stride[0]); + upper_padding[1] = CalculateUpperPadding( + in_size[1], out_size[1], dims.kernel_size[1], dims.stride[1]); + num_extended[0] = upper_padding[0] / (dims.kernel_size[0]); + num_extended[1] = upper_padding[1] / (dims.kernel_size[1]); + + if (num_extended[0] > 0) { + auto slice = + xla::Slice(input_data, {0, in_size[0] - 1, 0, 0}, + {1, in_size[0], in_size[1], channels}, {1, 1, 1, 1}); + for (int i = 0; i < num_extended[0]; i++) { + input_data = xla::ConcatInDim(builder, {input_data, slice}, 1); + } + } + + if (num_extended[1] > 0) { + auto slice = + xla::Slice(input_data, {0, 0, in_size[1] - 1, 0}, + {1, in_size[0] + num_extended[0], in_size[1], channels}, + {1, 1, 1, 1}); + for (int i = 0; i < num_extended[1]; i++) { + input_data = xla::ConcatInDim(builder, {input_data, slice}, 2); + } + } + + // Setting in_size to (in_size + num_extended) due to the above Slice and + // ConcatInDim. Recalculate needed padding after the above Slice/Concat. + upper_padding[0] = + CalculateUpperPadding(in_size[0] + num_extended[0], out_size[0], + dims.kernel_size[0], dims.stride[0]); + upper_padding[1] = + CalculateUpperPadding(in_size[1] + num_extended[1], out_size[1], + dims.kernel_size[1], dims.stride[1]); + } + + // Split convolutions into independent dimensions if they would be a very // large kernel. 
if (dims.kernel_size[0] * dims.kernel_size[1] < kMax2DKernelSize) { xla::XlaOp kernel = MakeBilinearResizeKernel(builder, dims.kernel_size, channels); - output = xla::ConvGeneralDilated( - input, kernel, dims.stride, - /*padding=*/ - {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, - {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}}, - /*lhs_dilation=*/dims.kernel_size, - /*rhs_dilation=*/{1, 1}, dimension_numbers); + output = + xla::ConvGeneralDilated(input_data, kernel, dims.stride, + /*padding=*/ + {{dims.kernel_size[0] - 1, upper_padding[0]}, + {dims.kernel_size[1] - 1, upper_padding[1]}}, + /*lhs_dilation=*/dims.kernel_size, + /*rhs_dilation=*/{1, 1}, dimension_numbers); } else { xla::XlaOp kernel0 = MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 0); output = xla::ConvGeneralDilated( - input, kernel0, {dims.stride[0], 1}, + input_data, kernel0, {dims.stride[0], 1}, /*padding=*/ - {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, {0, 0}}, + {{dims.kernel_size[0] - 1, upper_padding[0]}, {0, 0}}, /*lhs_dilation=*/{dims.kernel_size[0], 1}, /*rhs_dilation=*/{1, 1}, dimension_numbers); xla::XlaOp kernel1 = @@ -224,7 +300,7 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder, output = xla::ConvGeneralDilated( output, kernel1, {1, dims.stride[1]}, /*padding=*/ - {{0, 0}, {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}}, + {{0, 0}, {dims.kernel_size[1] - 1, upper_padding[1]}}, /*lhs_dilation=*/{1, dims.kernel_size[1]}, /*rhs_dilation=*/{1, 1}, dimension_numbers); } @@ -245,9 +321,10 @@ xla::XlaOp ResizeUsingDilationAndConvolutionGradOp(xla::XlaBuilder* builder, const int num_spatial_dims, std::vector<int64> in_size, std::vector<int64> grad_size, - const int64 channels) { + const int64 channels, + const bool align_corners) { ResizeConvolutionDims dims = - ComputeResizeConvolutionParameters(in_size, grad_size); + ComputeResizeConvolutionParameters(in_size, grad_size, align_corners); // To form the backward convolution, 
we keep the kernel unchanged (it is // already symmetric) and swap the roles of strides and LHS dilation. @@ -341,10 +418,6 @@ class ResizeBilinearOp : public XlaOpKernel { public: explicit ResizeBilinearOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("align_corners", &align_corners_)); - OP_REQUIRES( - ctx, align_corners_ == true, - errors::Unimplemented( - "ResizeBilinear with align_corners=False is not yet implemented")); } void Compile(XlaOpKernelContext* ctx) override { @@ -377,20 +450,19 @@ class ResizeBilinearOp : public XlaOpKernel { // If in_size[i] > 1 and out_size[i] == 1, slice out the first input in // dimension i. - std::vector<int64> slice_size = in_size; bool slice_input = false; for (int i = 0; i < num_spatial_dims; ++i) { if (in_size[i] > 1 && out_size[i] == 1) { // If in_size[i] > 1 but out_size[i] == 1, then we slice out the first // entry before resizing. slice_input = true; - slice_size[i] = 1; + in_size[i] = 1; } } if (slice_input) { - input = xla::Slice(input, {0, 0, 0, 0}, - {batch, slice_size[0], slice_size[1], channels}, - {1, 1, 1, 1}); + input = + xla::Slice(input, {0, 0, 0, 0}, + {batch, in_size[0], in_size[1], channels}, {1, 1, 1, 1}); } // Output is always type float. @@ -406,6 +478,9 @@ class ResizeBilinearOp : public XlaOpKernel { // operations along different dimensions. // Given sufficient numerical stability and a<e<c and b<f<d, bilinear resize // from image of size axb -> cxd is same as resizing axb -> exf -> cxd. + // This does not work in the case of align_corners_=false because of special + // padding requirements that cause multiple resizes to be very different + // from a single resize. // // This makes the convolutions kernels smaller and the operation faster. 
xla::XlaOp output = input; @@ -415,21 +490,24 @@ class ResizeBilinearOp : public XlaOpKernel { (static_cast<float>(out_size[0]) - 1) / ((in_size[0] - 1) * 2), (static_cast<float>(out_size[1]) - 1) / ((in_size[1] - 1) * 2)}; if ((k[0] == std::floor(k[0])) && (k[1] == std::floor(k[1])) && - k[0] > 1 && k[1] > 1) { + k[0] > 1 && k[1] > 1 && align_corners_) { std::vector<int64> next_out_size = {(in_size[0] - 1) * 2 + 1, (in_size[1] - 1) * 2 + 1}; - output = ResizeUsingDilationAndConvolution( - b, input, num_spatial_dims, in_size, next_out_size, channels); + output = ResizeUsingDilationAndConvolution(b, input, num_spatial_dims, + in_size, next_out_size, + channels, align_corners_); input = output; in_size = next_out_size; } else { - output = ResizeUsingDilationAndConvolution( - b, input, num_spatial_dims, in_size, out_size, channels); + output = ResizeUsingDilationAndConvolution(b, input, num_spatial_dims, + in_size, out_size, + channels, align_corners_); in_size = out_size; } } else { output = ResizeUsingDilationAndConvolution(b, input, num_spatial_dims, - in_size, out_size, channels); + in_size, out_size, channels, + align_corners_); in_size = out_size; } } @@ -509,17 +587,20 @@ class ResizeBilinearGradOp : public XlaOpKernel { std::vector<int64> next_grad_size = {(in_size[0] - 1) * 2 + 1, (in_size[1] - 1) * 2 + 1}; output = ResizeUsingDilationAndConvolutionGradOp( - b, grad, num_spatial_dims, in_size, next_grad_size, channels); + b, grad, num_spatial_dims, in_size, next_grad_size, channels, + align_corners_); grad = output; in_size = next_grad_size; } else { output = ResizeUsingDilationAndConvolutionGradOp( - b, grad, num_spatial_dims, in_size, grad_size, channels); + b, grad, num_spatial_dims, in_size, grad_size, channels, + align_corners_); in_size = grad_size; } } else { output = ResizeUsingDilationAndConvolutionGradOp( - b, grad, num_spatial_dims, in_size, grad_size, channels); + b, grad, num_spatial_dims, in_size, grad_size, channels, + align_corners_); in_size 
= grad_size; } } diff --git a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc index d4d180aff8..f6f158a73b 100644 --- a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc @@ -199,59 +199,6 @@ class MaxPool3DOp : public MaxPoolOp { }; REGISTER_XLA_OP(Name("MaxPool3D"), MaxPool3DOp); -// Divide each element of an image by the count of elements that contributed to -// that element during pooling. -static xla::XlaOp AvgPoolDivideByCount( - XlaOpKernelContext* ctx, const xla::XlaOp& output, DataType dtype, - const TensorShape& input_shape, xla::Padding padding, - const std::vector<int64>& ksize, const std::vector<int64>& stride, - int num_spatial_dims, TensorFormat data_format) { - if (padding == xla::Padding::kValid) { - // In VALID padding, all windows have the same number of elements - // contributing to each average. Divide by the window size everywhere to - // get the average. - int64 window_size = std::accumulate(ksize.begin(), ksize.end(), 1, - [](int64 a, int64 b) { return a * b; }); - - auto divisor = - XlaHelpers::IntegerLiteral(ctx->builder(), dtype, window_size); - return xla::Div(output, divisor); - } else { - // For SAME padding, the padding shouldn't be included in the - // counts. We use another ReduceWindow to find the right counts. - - // TODO(phawkins): use a less brute-force way to compute this. Only - // the boundary regions will have interesting values here. 
- - std::vector<int64> input_dim_sizes(num_spatial_dims); - std::vector<int64> window_dims(num_spatial_dims); - std::vector<int64> window_ksize(num_spatial_dims); - std::vector<int64> window_stride(num_spatial_dims); - for (int i = 0; i < num_spatial_dims; ++i) { - int dim = GetTensorSpatialDimIndex(num_spatial_dims + 2, data_format, i); - input_dim_sizes[i] = input_shape.dim_size(dim); - window_dims[i] = dim; - window_ksize[i] = ksize[dim]; - window_stride[i] = stride[dim]; - } - - // Build a matrix of all 1s, with the same width/height as the input. - const DataType accumulation_type = XlaHelpers::SumAccumulationType(dtype); - auto ones = xla::Broadcast( - XlaHelpers::One(ctx->builder(), accumulation_type), input_dim_sizes); - - // Perform a ReduceWindow with the same window size, strides, and padding - // to count the number of contributions to each result element. - auto reduce = xla::ReduceWindow( - ones, XlaHelpers::Zero(ctx->builder(), accumulation_type), - *ctx->GetOrCreateAdd(accumulation_type), window_ksize, window_stride, - xla::Padding::kSame); - auto counts = XlaHelpers::ConvertElementType(ctx->builder(), reduce, dtype); - - return xla::Div(output, counts, window_dims); - } -} - class AvgPoolOp : public PoolingOp { public: AvgPoolOp(OpKernelConstruction* ctx, int num_spatial_dims) @@ -463,78 +410,31 @@ class AvgPoolGradOp : public XlaOpKernel { errors::InvalidArgument("out_backprop must be ", num_dims(), "-dimensional")); - int depth_dim = GetTensorFeatureDimIndex(num_dims(), data_format_); - int64 depth = out_backprop_shape.dim_size(depth_dim); - - // We can think of average-pooling as: - // * a convolution with a kernel consisting entirely of 1s, where the - // input feature and output feature are equal, and 0s everywhere else. - // * followed by dividing by the counts. 
- // - // This then gives us an algorithm to build the gradient: - // * divide out_backprop by the counts, followed by - // * Conv2DBackpropInput specialized for that kernel, which simplifies to - // a Pad and a ReduceWindow. - // - // For an explanation of backpropagation for convolution, see the comments - // in third_party/tensorflow/core/kernels/conv_grad_ops.h - - // TF filter shape is [ H, W, ..., inC, outC ] - std::vector<int64> filter_dims(num_dims()); - for (int i = 0; i < num_spatial_dims_; ++i) { - int dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); - filter_dims[i] = ksize_[dim]; - } - filter_dims[num_dims() - 2] = depth; - filter_dims[num_dims() - 1] = depth; - TensorShape filter_shape(filter_dims); - - // Reuse the logic from Conv2DBackpropInput to compute padding. - ConvBackpropDimensions dims; - OP_REQUIRES_OK( - ctx, ConvBackpropComputeDimensions( - type_string(), /*num_spatial_dims=*/num_spatial_dims_, - gradients_shape, filter_shape, out_backprop_shape, stride_, - padding_, data_format_, &dims)); - - // The input gradients are computed by a convolution of the output gradients - // and the filter, with some appropriate padding. See the comment at the top - // of conv_grad_ops.h for details. - xla::XlaBuilder* const b = ctx->builder(); auto out_backprop = ctx->Input(1); - auto dtype = input_type(1); + std::vector<int64> stride_int64s(stride_.begin(), stride_.end()); xla::Padding xla_padding = (padding_ == VALID) ? xla::Padding::kValid : xla::Padding::kSame; - - // Divide the out_backprop values by the counts for each spatial position. - std::vector<int64> stride_int64s(stride_.begin(), stride_.end()); - auto out_backprop_div = AvgPoolDivideByCount( - ctx, out_backprop, dtype, gradients_shape, xla_padding, ksize_, - stride_int64s, num_spatial_dims_, data_format_); - - // Pad the gradients in the spatial dimensions. We use the same padding - // as Conv2DBackpropInput. 
- xla::PaddingConfig padding_config = xla::MakeNoPaddingConfig(num_dims()); - for (int i = 0; i < num_spatial_dims_; ++i) { - int dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); - auto* padding = padding_config.mutable_dimensions(dim); - padding->set_edge_padding_low(dims.spatial_dims[i].pad_before); - padding->set_edge_padding_high(dims.spatial_dims[i].pad_after); - padding->set_interior_padding(dims.spatial_dims[i].stride - 1); - } - - auto zero = XlaHelpers::Zero(b, dtype); - auto padded_gradients = xla::Pad(out_backprop_div, zero, padding_config); - - // in_backprop = padded_gradients <conv> ones - std::vector<int64> ones(num_dims(), 1LL); - auto accumulation_type = XlaHelpers::SumAccumulationType(dtype); - auto in_backprop = xla::ReduceWindow( - XlaHelpers::ConvertElementType(b, padded_gradients, accumulation_type), - XlaHelpers::Zero(b, accumulation_type), - *ctx->GetOrCreateAdd(accumulation_type), ksize_, - /* window_strides=*/ones, xla::Padding::kValid); - ctx->SetOutput(0, XlaHelpers::ConvertElementType(b, in_backprop, dtype)); + xla::PrimitiveType xla_reduction_type; + auto reduction_type = XlaHelpers::SumAccumulationType(ctx->input_type(1)); + OP_REQUIRES_OK( + ctx, DataTypeToPrimitiveType(reduction_type, &xla_reduction_type)); + auto converted_out_backprop = + xla::ConvertElementType(out_backprop, xla_reduction_type); + auto xla_data_format = + XlaTensorFormat(data_format_, gradients_shape.dims() - 2); + auto padding_values = + MakeSpatialPadding(gradients_shape.dim_sizes(), ksize_, stride_int64s, + xla_padding, xla_data_format); + auto in_backprop = + xla::AvgPoolGrad(converted_out_backprop, gradients_shape.dim_sizes(), + ksize_, stride_int64s, padding_values, xla_data_format, + /*counts_include_padding=*/padding_ == VALID); + // Convert the pooling result back to the input type before returning it. 
+ xla::PrimitiveType xla_out_backprop_type; + OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(ctx->input_type(1), + &xla_out_backprop_type)); + ctx->SetOutput(0, + xla::ConvertElementType(in_backprop, xla_out_backprop_type)); } protected: diff --git a/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc b/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc index b11a4ce36d..8102faad28 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc @@ -32,41 +32,30 @@ class ReduceWindowOp : public XlaOpKernel { explicit ReduceWindowOp(OpKernelConstruction* context) : XlaOpKernel(context) { OP_REQUIRES_OK(context, context->GetAttr("computation", &computation_)); - OP_REQUIRES_OK(context, - context->GetAttr("window_dimensions", &window_dimensions_)); - OP_REQUIRES_OK(context, - context->GetAttr("window_strides", &window_strides_)); - OP_REQUIRES_OK(context, context->GetAttr("padding_low", &padding_low_)); - OP_REQUIRES_OK(context, context->GetAttr("padding_high", &padding_high_)); } void Compile(XlaOpKernelContext* context) override { const TensorShape input_shape = context->InputShape(0); const DataType dtype = context->input_type(0); + std::vector<int64> window_dimensions; + std::vector<int64> window_strides; + OP_REQUIRES_OK(context, context->ConstantInputAsIntVector( + "window_dimensions", &window_dimensions)); + OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("window_strides", + &window_strides)); + const int rank = input_shape.dims(); - OP_REQUIRES(context, rank == window_dimensions_.size(), + OP_REQUIRES(context, rank == window_dimensions.size(), errors::InvalidArgument( "The size of window_dimensions must be equal to the input " "rank (", - window_dimensions_.size(), " vs. ", rank, ")")); - OP_REQUIRES(context, rank == window_strides_.size(), + window_dimensions.size(), " vs. 
", rank, ")")); + OP_REQUIRES(context, rank == window_strides.size(), errors::InvalidArgument( "The size of window_strides must be equal to the input " "rank (", - window_strides_.size(), " vs. ", rank, ")")); - OP_REQUIRES(context, rank == padding_low_.size(), - errors::InvalidArgument( - "The size of padding_low must be equal to the input " - "rank (", - padding_low_.size(), " vs. ", rank, ")")); - OP_REQUIRES(context, rank == padding_high_.size(), - errors::InvalidArgument( - "The size of padding_high must be equal to the input " - "rank (", - padding_high_.size(), " vs. ", rank, ")")); - - xla::XlaBuilder* builder = context->builder(); + window_strides.size(), " vs. ", rank, ")")); // Build the reducer function. XlaCompiler::Argument reducer_arg; @@ -78,6 +67,7 @@ class ReduceWindowOp : public XlaOpKernel { compile_options.use_tuple_arg = false; compile_options.resolve_compile_time_constants = false; compile_options.is_entry_computation = false; + compile_options.always_return_tuple = false; XlaCompiler::CompilationResult reducer; OP_REQUIRES_OK(context, context->compiler()->CompileFunction( compile_options, *computation_, @@ -86,51 +76,47 @@ class ReduceWindowOp : public XlaOpKernel { xla::Shape scalar_shape; OP_REQUIRES_OK(context, TensorShapeToXLAShape(dtype, TensorShape(), &scalar_shape)); + OP_REQUIRES( + context, + xla::ShapeUtil::Compatible(reducer.xla_output_shape, scalar_shape), + errors::InvalidArgument( + "Invalid output shape of ReduceWindow reducer. Expected ", + xla::ShapeUtil::HumanString(scalar_shape), " got ", + xla::ShapeUtil::HumanString(reducer.xla_output_shape))); + + const TensorShape padding_shape = context->InputShape("padding"); OP_REQUIRES(context, - xla::ShapeUtil::Compatible( - reducer.xla_output_shape, - xla::ShapeUtil::MakeTupleShape({scalar_shape})), + TensorShapeUtils::IsMatrix(padding_shape) && + padding_shape.dim_size(1) == 2, errors::InvalidArgument( - "Invalid output shape of ReduceWindow reducer. 
Expected ", - xla::ShapeUtil::HumanString(scalar_shape), " got ", - xla::ShapeUtil::HumanString(reducer.xla_output_shape))); - - // Wraps the reducer in a computation that unpacks the output tuple. - xla::XlaComputation wrapper; - { - std::unique_ptr<xla::XlaBuilder> cb = - builder->CreateSubBuilder("wrapper"); - auto x = xla::Parameter(cb.get(), 0, scalar_shape, "x"); - auto y = xla::Parameter(cb.get(), 1, scalar_shape, "y"); - auto outputs = xla::Call(cb.get(), *reducer.computation, {x, y}); - xla::GetTupleElement(outputs, 0); - xla::StatusOr<xla::XlaComputation> result = cb->Build(); - OP_REQUIRES_OK(context, result.status()); - wrapper = std::move(result.ValueOrDie()); - } - - std::vector<std::pair<int64, int64>> padding(rank); - for (int i = 0; i < rank; ++i) { - padding[i] = {padding_low_[i], padding_high_[i]}; + "padding must be a matrix with minor dimension 2, got ", + padding_shape.DebugString())); + xla::Literal padding_literal; + OP_REQUIRES_OK(context, context->ConstantInputAsInt64Literal( + "padding", &padding_literal)); + std::vector<std::pair<int64, int64>> padding(padding_shape.dim_size(0)); + for (int i = 0; i < padding.size(); ++i) { + padding[i] = {padding_literal.Get<int64>({i, 0}), + padding_literal.Get<int64>({i, 1})}; } xla::XlaOp output = xla::ReduceWindowWithGeneralPadding( - context->Input(0), context->Input(1), wrapper, window_dimensions_, - window_strides_, padding); + context->Input(0), context->Input(1), *reducer.computation, + window_dimensions, window_strides, padding); context->SetOutput(0, output); } private: const NameAttrList* computation_; - std::vector<int64> window_dimensions_; - std::vector<int64> window_strides_; - std::vector<int64> padding_low_; - std::vector<int64> padding_high_; TF_DISALLOW_COPY_AND_ASSIGN(ReduceWindowOp); }; -REGISTER_XLA_OP(Name("XlaReduceWindow"), ReduceWindowOp); +REGISTER_XLA_OP(Name("XlaReduceWindow") + .CompileTimeConstInput("window_dimensions") + .CompileTimeConstInput("window_strides") + 
.CompileTimeConstInput("padding"), + ReduceWindowOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc index 6a71b8ca36..598248563b 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc @@ -15,6 +15,7 @@ limitations under the License. // XLA-specific reduction Ops. +#include "absl/strings/str_join.h" #include "tensorflow/compiler/tf2xla/kernels/reduction_ops.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" @@ -66,7 +67,7 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) { OP_REQUIRES_OK(ctx, ctx->ConstantInputReshapedToIntVector(1, &axes)); VLOG(1) << "data shape: " << data_shape.DebugString(); - VLOG(1) << "axes : " << str_util::Join(axes, ","); + VLOG(1) << "axes : " << absl::StrJoin(axes, ","); gtl::InlinedVector<bool, 4> bitmap(data_shape.dims(), false); std::vector<int64> xla_axes; diff --git a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc index 025ba82741..d6bd927135 100644 --- a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc @@ -15,6 +15,7 @@ limitations under the License. // XLA-specific Ops for softmax. +#include "absl/strings/match.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" @@ -25,7 +26,6 @@ limitations under the License. 
#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/lib/strings/str_util.h" namespace tensorflow { namespace { @@ -33,7 +33,7 @@ namespace { class SoftmaxOp : public XlaOpKernel { public: explicit SoftmaxOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { - log_ = str_util::StartsWith(type_string(), "Log"); + log_ = absl::StartsWith(type_string(), "Log"); } void Compile(XlaOpKernelContext* ctx) override { diff --git a/tensorflow/compiler/tf2xla/kernels/xla_broadcast_helper_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_broadcast_helper_op.cc new file mode 100644 index 0000000000..412afeaaad --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/xla_broadcast_helper_op.cc @@ -0,0 +1,115 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "absl/algorithm/container.h" +#include "absl/strings/str_join.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { +namespace { + +class XlaBroadcastHelperOp : public XlaOpKernel { + public: + explicit XlaBroadcastHelperOp(OpKernelConstruction* context) + : XlaOpKernel(context) {} + + void Compile(XlaOpKernelContext* context) override { + xla::XlaOp lhs = context->Input(0); + xla::XlaOp rhs = context->Input(1); + const TensorShape lhs_shape = context->InputShape(0); + const TensorShape rhs_shape = context->InputShape(1); + + const bool broadcast_lhs = lhs_shape.dims() < rhs_shape.dims(); + const TensorShape* min_rank_shape = broadcast_lhs ? &lhs_shape : &rhs_shape; + const TensorShape* max_rank_shape = broadcast_lhs ? 
&rhs_shape : &lhs_shape; + + std::vector<int64> broadcast_dims; + OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("broadcast_dims", + &broadcast_dims)); + if (broadcast_dims.empty()) { + OP_REQUIRES( + context, + lhs_shape.dims() == rhs_shape.dims() || lhs_shape.dims() == 0 || + rhs_shape.dims() == 0, + errors::InvalidArgument( + "If broadcast_dims is empty, both " + "arguments must have equal rank; " + "argument shapes, or at least one argument must be a scalar: ", + lhs_shape.DebugString(), " and ", rhs_shape.DebugString())); + context->SetOutput(0, lhs); + context->SetOutput(1, rhs); + return; + } + + OP_REQUIRES( + context, broadcast_dims.size() == min_rank_shape->dims(), + errors::InvalidArgument( + "broadcast_dims must have size equal to the smaller argument rank; " + "broadcast_dims: [", + absl::StrJoin(broadcast_dims, ","), "]; argument shapes: ", + lhs_shape.DebugString(), " and ", rhs_shape.DebugString())); + std::vector<int64> sorted_broadcast_dims = broadcast_dims; + absl::c_sort(sorted_broadcast_dims); + std::set<int64> dims_set(broadcast_dims.begin(), broadcast_dims.end()); + OP_REQUIRES(context, + dims_set.size() == broadcast_dims.size() && + broadcast_dims == sorted_broadcast_dims, + errors::InvalidArgument( + "Duplicate or nonmonotonic dimension in broadcast_dims; " + "broadcast_dims: [", + absl::StrJoin(broadcast_dims, ","), "]")); + + std::vector<int64> broadcast_shape(max_rank_shape->dims(), 1LL); + for (int i = 0; i < broadcast_dims.size(); ++i) { + const int dim = broadcast_dims[i]; + OP_REQUIRES( + context, dim >= 0 && dim < broadcast_shape.size(), + errors::InvalidArgument( + "Invalid broadcast dimension (", dim, "); broadcast_dims: [", + absl::StrJoin(broadcast_dims, ","), "]; argument shapes: ", + lhs_shape.DebugString(), " and ", rhs_shape.DebugString())); + broadcast_shape[dim] = min_rank_shape->dim_size(i); + } + xla::PrimitiveType type = context->input_xla_type(0); + xla::Shape broadcast_xla_shape = + 
xla::ShapeUtil::MakeShape(type, broadcast_shape); + if (broadcast_lhs) { + lhs = xla::BroadcastInDim(lhs, broadcast_xla_shape, broadcast_dims); + } else { + rhs = xla::BroadcastInDim(rhs, broadcast_xla_shape, broadcast_dims); + } + context->SetOutput(0, lhs); + context->SetOutput(1, rhs); + } + + private: + xla::DotDimensionNumbers dnums_; + + TF_DISALLOW_COPY_AND_ASSIGN(XlaBroadcastHelperOp); +}; + +REGISTER_XLA_OP( + Name("XlaBroadcastHelper").CompileTimeConstInput("broadcast_dims"), + XlaBroadcastHelperOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc new file mode 100644 index 0000000000..8848623868 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc @@ -0,0 +1,101 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace { + +class XlaConvOp : public XlaOpKernel { + public: + explicit XlaConvOp(OpKernelConstruction* context) : XlaOpKernel(context) { + string dnums_attr; + OP_REQUIRES_OK(context, context->GetAttr("dimension_numbers", &dnums_attr)); + OP_REQUIRES( + context, dnums_.ParsePartialFromString(dnums_attr), + errors::InvalidArgument("Error parsing convolution dimension numbers")); + string precision_config_attr; + OP_REQUIRES_OK( + context, context->GetAttr("precision_config", &precision_config_attr)); + OP_REQUIRES( + context, + precision_config_.ParsePartialFromString(precision_config_attr), + errors::InvalidArgument("Error parsing convolution dimension numbers")); + } + + void Compile(XlaOpKernelContext* context) override { + const TensorShape lhs_shape = context->InputShape(0); + const TensorShape rhs_shape = context->InputShape(1); + const TensorShape padding_shape = context->InputShape("padding"); + std::vector<int64> window_strides; + std::vector<int64> lhs_dilation; + std::vector<int64> rhs_dilation; + int64 feature_group_count; + OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("window_strides", + &window_strides)); + OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("lhs_dilation", + &lhs_dilation)); + OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("rhs_dilation", + &rhs_dilation)); + OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar( + "feature_group_count", &feature_group_count)); + + OP_REQUIRES(context, + 
TensorShapeUtils::IsMatrix(padding_shape) && + padding_shape.dim_size(1) == 2, + errors::InvalidArgument( + "padding must be a matrix with minor dimension 2, got ", + padding_shape.DebugString())); + xla::Literal padding_literal; + OP_REQUIRES_OK(context, context->ConstantInputAsInt64Literal( + "padding", &padding_literal)); + std::vector<std::pair<int64, int64>> padding(padding_shape.dim_size(0)); + for (int i = 0; i < padding.size(); ++i) { + padding[i] = {padding_literal.Get<int64>({i, 0}), + padding_literal.Get<int64>({i, 1})}; + } + + // We do only minimal checking, relying on XLA to check the shape + // invariants. + xla::XlaOp output = xla::ConvGeneralDilated( + context->Input(0), context->Input(1), window_strides, padding, + lhs_dilation, rhs_dilation, dnums_, feature_group_count, + &precision_config_); + context->SetOutput(0, output); + } + + private: + xla::ConvolutionDimensionNumbers dnums_; + xla::PrecisionConfigProto precision_config_; + + TF_DISALLOW_COPY_AND_ASSIGN(XlaConvOp); +}; + +REGISTER_XLA_OP(Name("XlaConv") + .CompileTimeConstInput("window_strides") + .CompileTimeConstInput("lhs_dilation") + .CompileTimeConstInput("rhs_dilation") + .CompileTimeConstInput("feature_group_count") + .CompileTimeConstInput("padding"), + XlaConvOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc new file mode 100644 index 0000000000..2fed53e5c0 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc @@ -0,0 +1,65 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace { + +class XlaDotOp : public XlaOpKernel { + public: + explicit XlaDotOp(OpKernelConstruction* context) : XlaOpKernel(context) { + string dnums_attr; + OP_REQUIRES_OK(context, context->GetAttr("dimension_numbers", &dnums_attr)); + OP_REQUIRES( + context, dnums_.ParsePartialFromString(dnums_attr), + errors::InvalidArgument("Error parsing convolution dimension numbers")); + string precision_config_attr; + OP_REQUIRES_OK( + context, context->GetAttr("precision_config", &precision_config_attr)); + OP_REQUIRES( + context, + precision_config_.ParsePartialFromString(precision_config_attr), + errors::InvalidArgument("Error parsing convolution dimension numbers")); + } + + void Compile(XlaOpKernelContext* context) override { + const TensorShape lhs_shape = context->InputShape(0); + const TensorShape rhs_shape = context->InputShape(1); + + // We do only minimal checking, relying on XLA to check the shape + // invariants. 
+ xla::XlaOp output = xla::DotGeneral(context->Input(0), context->Input(1), + dnums_, &precision_config_); + context->SetOutput(0, output); + } + + private: + xla::DotDimensionNumbers dnums_; + xla::PrecisionConfigProto precision_config_; + + TF_DISALLOW_COPY_AND_ASSIGN(XlaDotOp); +}; + +REGISTER_XLA_OP(Name("XlaDot"), XlaDotOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/xla_pad_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_pad_op.cc new file mode 100644 index 0000000000..59502d83c7 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/xla_pad_op.cc @@ -0,0 +1,105 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "absl/algorithm/container.h" +#include "absl/strings/str_join.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace { + +class XlaPadOp : public XlaOpKernel { + public: + explicit XlaPadOp(OpKernelConstruction* context) : XlaOpKernel(context) {} + + void Compile(XlaOpKernelContext* context) override { + const TensorShape input_shape = context->InputShape("input"); + const TensorShape padding_value_shape = + context->InputShape("padding_value"); + + std::vector<int64> padding_low; + std::vector<int64> padding_high; + std::vector<int64> padding_interior; + OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("padding_low", + &padding_low)); + OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("padding_high", + &padding_high)); + OP_REQUIRES_OK(context, context->ConstantInputAsIntVector( + "padding_interior", &padding_interior)); + + OP_REQUIRES(context, TensorShapeUtils::IsScalar(padding_value_shape), + errors::InvalidArgument("padding_value must be a scalar")); + const int rank = input_shape.dims(); + OP_REQUIRES(context, rank == padding_low.size(), + errors::InvalidArgument( + "The size of padding_low must be equal to the input " + "rank (", + padding_low.size(), " vs. ", rank, ")")); + OP_REQUIRES(context, rank == padding_high.size(), + errors::InvalidArgument( + "The size of padding_high must be equal to the input " + "rank (", + padding_high.size(), " vs. 
", rank, ")")); + OP_REQUIRES(context, rank == padding_interior.size(), + errors::InvalidArgument( + "The size of padding_interior must be equal to the input " + "rank (", + padding_interior.size(), " vs. ", rank, ")")); + + auto non_negative = [](int64 x) { return x >= 0; }; + OP_REQUIRES( + context, absl::c_all_of(padding_low, non_negative), + errors::InvalidArgument("padding_low must be non-negative, got [", + absl::StrJoin(padding_low, ","), "]")); + OP_REQUIRES( + context, absl::c_all_of(padding_high, non_negative), + errors::InvalidArgument("padding_high must be non-negative, got [", + absl::StrJoin(padding_high, ","), "]")); + OP_REQUIRES( + context, absl::c_all_of(padding_interior, non_negative), + errors::InvalidArgument("padding_interior must be non-negative, got [", + absl::StrJoin(padding_interior, ","), "]")); + + xla::PaddingConfig padding_config; + for (int i = 0; i < rank; ++i) { + auto* dim = padding_config.add_dimensions(); + dim->set_edge_padding_low(padding_low[i]); + dim->set_edge_padding_high(padding_high[i]); + dim->set_interior_padding(padding_interior[i]); + } + + xla::XlaOp output = + xla::Pad(context->Input("input"), context->Input("padding_value"), + padding_config); + context->SetOutput(0, output); + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(XlaPadOp); +}; + +REGISTER_XLA_OP(Name("XlaPad") + .CompileTimeConstInput("padding_low") + .CompileTimeConstInput("padding_high") + .CompileTimeConstInput("padding_interior"), + XlaPadOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/xla_reduce_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_reduce_op.cc new file mode 100644 index 0000000000..fc2425f37b --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/xla_reduce_op.cc @@ -0,0 +1,102 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "absl/algorithm/container.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace { + +class XlaReduceOp : public XlaOpKernel { + public: + explicit XlaReduceOp(OpKernelConstruction* context) : XlaOpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("reducer", &reducer_)); + OP_REQUIRES_OK(context, context->GetAttr("dimensions_to_reduce", + &dimensions_to_reduce_)); + std::set<int64> dims_set(dimensions_to_reduce_.begin(), + dimensions_to_reduce_.end()); + OP_REQUIRES( + context, dims_set.size() == dimensions_to_reduce_.size(), + errors::InvalidArgument("Duplicate dimension in dimensions_to_reduce " + "argument to XlaReduce")); + } + + void Compile(XlaOpKernelContext* context) override { + const TensorShape input_shape = context->InputShape("input"); + const TensorShape init_value_shape = context->InputShape("init_value"); + const DataType dtype = context->input_type(0); + + const int rank = input_shape.dims(); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(init_value_shape), + errors::InvalidArgument("init_value must be a scalar")); + + auto dim_in_range 
= [rank](int64 dim) { return dim >= 0 && dim < rank; }; + OP_REQUIRES(context, + rank >= dimensions_to_reduce_.size() && + absl::c_all_of(dimensions_to_reduce_, dim_in_range), + errors::InvalidArgument( + "Invalid dimensions_to_reduce argument to XlaReduce")); + + // Build the reducer function. + XlaCompiler::Argument reducer_arg; + reducer_arg.kind = XlaCompiler::Argument::kParameter; + reducer_arg.type = dtype; + reducer_arg.shape = TensorShape(); + + XlaCompiler::CompileOptions compile_options; + compile_options.use_tuple_arg = false; + compile_options.always_return_tuple = false; + compile_options.resolve_compile_time_constants = false; + compile_options.is_entry_computation = false; + XlaCompiler::CompilationResult reducer; + OP_REQUIRES_OK(context, context->compiler()->CompileFunction( + compile_options, *reducer_, + {reducer_arg, reducer_arg}, &reducer)); + + xla::Shape scalar_shape; + OP_REQUIRES_OK(context, + TensorShapeToXLAShape(dtype, TensorShape(), &scalar_shape)); + OP_REQUIRES( + context, + xla::ShapeUtil::Compatible(reducer.xla_output_shape, scalar_shape), + errors::InvalidArgument( + "Invalid output shape of XlaReduce reducer. 
Expected ", + xla::ShapeUtil::HumanString(scalar_shape), " got ", + xla::ShapeUtil::HumanString(reducer.xla_output_shape))); + + xla::XlaOp output = + xla::Reduce(context->Input("input"), context->Input("init_value"), + *reducer.computation, dimensions_to_reduce_); + context->SetOutput(0, output); + } + + private: + const NameAttrList* reducer_; + std::vector<int64> dimensions_to_reduce_; + + TF_DISALLOW_COPY_AND_ASSIGN(XlaReduceOp); +}; + +REGISTER_XLA_OP(Name("XlaReduce"), XlaReduceOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/xla_select_and_scatter_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_select_and_scatter_op.cc new file mode 100644 index 0000000000..089776fcf7 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/xla_select_and_scatter_op.cc @@ -0,0 +1,147 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/kernels/while_op.h" + +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace { + +class XlaSelectAndScatterOp : public XlaOpKernel { + public: + explicit XlaSelectAndScatterOp(OpKernelConstruction* context) + : XlaOpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("select", &select_computation_)); + OP_REQUIRES_OK(context, context->GetAttr("scatter", &scatter_computation_)); + } + + void Compile(XlaOpKernelContext* context) override { + const TensorShape input_shape = context->InputShape(0); + const DataType dtype = context->input_type(0); + + std::vector<int64> window_dimensions; + std::vector<int64> window_strides; + OP_REQUIRES_OK(context, context->ConstantInputAsIntVector( + "window_dimensions", &window_dimensions)); + OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("window_strides", + &window_strides)); + + const int rank = input_shape.dims(); + OP_REQUIRES(context, rank == window_dimensions.size(), + errors::InvalidArgument( + "The size of window_dimensions must be equal to the input " + "rank (", + window_dimensions.size(), " vs. ", rank, ")")); + OP_REQUIRES(context, rank == window_strides.size(), + errors::InvalidArgument( + "The size of window_strides must be equal to the input " + "rank (", + window_strides.size(), " vs. 
", rank, ")")); + + XlaCompiler::CompileOptions compile_options; + compile_options.use_tuple_arg = false; + compile_options.resolve_compile_time_constants = false; + compile_options.is_entry_computation = false; + compile_options.always_return_tuple = false; + + // Build the select function. + XlaCompiler::Argument select_arg; + select_arg.kind = XlaCompiler::Argument::kParameter; + select_arg.type = dtype; + select_arg.shape = TensorShape(); + + XlaCompiler::CompilationResult select; + OP_REQUIRES_OK(context, context->compiler()->CompileFunction( + compile_options, *select_computation_, + {select_arg, select_arg}, &select)); + + xla::Shape select_output_shape = xla::ShapeUtil::MakeShape(xla::PRED, {}); + OP_REQUIRES( + context, + xla::ShapeUtil::Compatible(select.xla_output_shape, + select_output_shape), + errors::InvalidArgument( + "Invalid output shape of XlaSelectAndScatter select. Expected ", + xla::ShapeUtil::HumanString(select_output_shape), " got ", + xla::ShapeUtil::HumanString(select.xla_output_shape))); + + // Build the scatter function. + XlaCompiler::Argument scatter_arg; + scatter_arg.kind = XlaCompiler::Argument::kParameter; + scatter_arg.type = dtype; + scatter_arg.shape = TensorShape(); + + XlaCompiler::CompilationResult scatter; + OP_REQUIRES_OK(context, context->compiler()->CompileFunction( + compile_options, *scatter_computation_, + {scatter_arg, scatter_arg}, &scatter)); + + xla::Shape scalar_shape; + OP_REQUIRES_OK(context, + TensorShapeToXLAShape(dtype, TensorShape(), &scalar_shape)); + OP_REQUIRES( + context, + xla::ShapeUtil::Compatible(scatter.xla_output_shape, scalar_shape), + errors::InvalidArgument( + "Invalid output shape of scatter. 
Expected ", + xla::ShapeUtil::HumanString(scalar_shape), " got ", + xla::ShapeUtil::HumanString(scatter.xla_output_shape))); + + const TensorShape padding_shape = context->InputShape("padding"); + OP_REQUIRES(context, + TensorShapeUtils::IsMatrix(padding_shape) && + padding_shape.dim_size(1) == 2, + errors::InvalidArgument( + "padding must be a matrix with minor dimension 2, got ", + padding_shape.DebugString())); + xla::Literal padding_literal; + OP_REQUIRES_OK(context, context->ConstantInputAsInt64Literal( + "padding", &padding_literal)); + std::vector<std::pair<int64, int64>> padding(padding_shape.dim_size(0)); + for (int i = 0; i < padding.size(); ++i) { + padding[i] = {padding_literal.Get<int64>({i, 0}), + padding_literal.Get<int64>({i, 1})}; + } + + xla::XlaOp output = xla::SelectAndScatterWithGeneralPadding( + context->Input("operand"), *select.computation, window_dimensions, + window_strides, padding, context->Input("source"), + context->Input("init_value"), *scatter.computation); + context->SetOutput(0, output); + } + + private: + const NameAttrList* select_computation_; + const NameAttrList* scatter_computation_; + + TF_DISALLOW_COPY_AND_ASSIGN(XlaSelectAndScatterOp); +}; + +REGISTER_XLA_OP(Name("XlaSelectAndScatter") + .CompileTimeConstInput("window_dimensions") + .CompileTimeConstInput("window_strides") + .CompileTimeConstInput("padding"), + XlaSelectAndScatterOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/BUILD b/tensorflow/compiler/tf2xla/lib/BUILD index cb7a40e23d..99511e9914 100644 --- a/tensorflow/compiler/tf2xla/lib/BUILD +++ b/tensorflow/compiler/tf2xla/lib/BUILD @@ -25,8 +25,8 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/core:lib", ], ) @@ -44,8 +44,8 
@@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:constants", "//tensorflow/core:lib", ], @@ -78,8 +78,8 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", "//tensorflow/compiler/xla/client/lib:constants", "//tensorflow/compiler/xla/client/lib:math", @@ -119,6 +119,7 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:constants", diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.cc b/tensorflow/compiler/tf2xla/lib/batch_dot.cc index f666d22ea4..d8c050d09e 100644 --- a/tensorflow/compiler/tf2xla/lib/batch_dot.cc +++ b/tensorflow/compiler/tf2xla/lib/batch_dot.cc @@ -27,7 +27,8 @@ limitations under the License. 
namespace tensorflow { xla::XlaOp BatchDot(xla::XlaOp x, xla::XlaOp y, bool transpose_x, - bool transpose_y, bool conjugate_x, bool conjugate_y) { + bool transpose_y, bool conjugate_x, bool conjugate_y, + xla::PrecisionConfigProto::Precision precision) { xla::XlaBuilder* builder = x.builder(); return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> { TF_ASSIGN_OR_RETURN(xla::Shape x_shape, builder->GetShape(x)); @@ -95,6 +96,10 @@ xla::XlaOp BatchDot(xla::XlaOp x, xla::XlaOp y, bool transpose_x, y = xla::Conj(y); } + xla::PrecisionConfigProto precision_proto; + precision_proto.add_operand_precision(precision); + precision_proto.add_operand_precision(precision); + // If there are no batch dimensions, use a regular Dot. // TODO(b/69062148) Remove this code when Dot emitters can be passed // dimensions to transpose directly (i.e. without requiring a Transpose @@ -102,7 +107,7 @@ xla::XlaOp BatchDot(xla::XlaOp x, xla::XlaOp y, bool transpose_x, if (batch_dimension_numbers.empty()) { auto lhs = transpose_x ? xla::Transpose(x, {1, 0}) : x; auto rhs = transpose_y ? xla::Transpose(y, {1, 0}) : y; - return xla::Dot(lhs, rhs); + return xla::Dot(lhs, rhs, &precision_proto); } xla::DotDimensionNumbers dot_dnums; @@ -112,7 +117,8 @@ xla::XlaOp BatchDot(xla::XlaOp x, xla::XlaOp y, bool transpose_x, dot_dnums.add_lhs_batch_dimensions(batch_dimension_number); dot_dnums.add_rhs_batch_dimensions(batch_dimension_number); } - return xla::DotGeneral(x, y, dot_dnums); + + return xla::DotGeneral(x, y, dot_dnums, &precision_proto); }); } diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.h b/tensorflow/compiler/tf2xla/lib/batch_dot.h index 8757b16a1c..6cfccd5553 100644 --- a/tensorflow/compiler/tf2xla/lib/batch_dot.h +++ b/tensorflow/compiler/tf2xla/lib/batch_dot.h @@ -17,7 +17,7 @@ limitations under the License. 
#define TENSORFLOW_COMPILER_TF2XLA_LIB_BATCH_DOT_H_ #include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" namespace tensorflow { @@ -45,7 +45,9 @@ namespace tensorflow { // output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :]) xla::XlaOp BatchDot(xla::XlaOp x, xla::XlaOp y, bool transpose_x = false, bool transpose_y = false, bool conjugate_x = false, - bool conjugate_y = false); + bool conjugate_y = false, + xla::PrecisionConfigProto::Precision precision = + xla::PrecisionConfigProto::DEFAULT); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.cc b/tensorflow/compiler/tf2xla/lib/cholesky.cc index 87d73eb3f0..67fb56510c 100644 --- a/tensorflow/compiler/tf2xla/lib/cholesky.cc +++ b/tensorflow/compiler/tf2xla/lib/cholesky.cc @@ -49,7 +49,8 @@ namespace { // l[..., j+1:, j] = (a[..., j+1:, j] - np.dot(l[..., j+1:, :j], row_t)) / // l[..., j, j] // return l -xla::XlaOp CholeskyUnblocked(xla::XlaOp a) { +xla::XlaOp CholeskyUnblocked(xla::XlaOp a, + xla::PrecisionConfigProto::Precision precision) { xla::XlaBuilder* builder = a.builder(); return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> { TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); @@ -101,7 +102,8 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a) { // np.dot(row, np.swapaxes(row, -1, -2)) auto diag_dot = BatchDot(row, row, /*transpose_x=*/false, - /*transpose_y=*/true); + /*transpose_y=*/true, /*conjugate_x=*/false, + /*conjugate_y=*/false, precision); // l[..., i, i] = np.sqrt(a[..., i, i] - np.dot(row, // np.swapaxes(row, -1, -2))) auto l_ii = @@ -121,7 +123,8 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a) { // r.T) auto dot = BatchDot(body_l, row, /*transpose_x=*/false, - /*transpose_y=*/true); + /*transpose_y=*/true, /*conjugate_x=*/false, + /*conjugate_y=*/false, precision); // np.dot(l[..., i+1:, :i], r.T) auto dot_ip1 = 
xla::Select(xla::Le(mask_range_col, i), mask_zeros_col, dot); @@ -145,7 +148,8 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a) { } // namespace -xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size) { +xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size, + xla::PrecisionConfigProto::Precision precision) { xla::XlaBuilder* builder = a.builder(); return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> { TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); @@ -181,14 +185,15 @@ xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size) { auto lhs = SliceInMinorDims(l, {i, 0}, {n, i}); auto rhs = SliceInMinorDims(l, {i, 0}, {i + k, i}); auto delta = BatchDot(lhs, rhs, /*transpose_x=*/false, - /*transpose_y=*/true); + /*transpose_y=*/true, /*conjugate_x=*/false, + /*conjugate_y=*/false, precision); auto before = SliceInMinorDims(a, {i, i}, {n, i + k}); a = UpdateSliceInMinorDims(a, before - delta, {i, i}); } // l[i:i+k, i:i+k] = cholesky_unblocked(a[i:i+k, i:i+k]) auto x = SliceInMinorDims(a, {i, i}, {i + k, i + k}); - auto factorized = CholeskyUnblocked(x); + auto factorized = CholeskyUnblocked(x, precision); l = UpdateSliceInMinorDims(l, factorized, {i, i}); if (i + k < n) { diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.h b/tensorflow/compiler/tf2xla/lib/cholesky.h index 1bef9bb166..60cd7ded53 100644 --- a/tensorflow/compiler/tf2xla/lib/cholesky.h +++ b/tensorflow/compiler/tf2xla/lib/cholesky.h @@ -17,7 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_TF2XLA_LIB_CHOLESKY_H_ #include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" namespace tensorflow { @@ -30,7 +30,9 @@ namespace tensorflow { // TODO(phawkins): check for negative values on the diagonal and return an // error, instead of silently yielding NaNs. 
// TODO(znado): handle the complex Hermitian case -xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size = 256); +xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size = 256, + xla::PrecisionConfigProto::Precision precision = + xla::PrecisionConfigProto::HIGHEST); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/qr.cc b/tensorflow/compiler/tf2xla/lib/qr.cc index fc0c1ee838..b6f30d8d49 100644 --- a/tensorflow/compiler/tf2xla/lib/qr.cc +++ b/tensorflow/compiler/tf2xla/lib/qr.cc @@ -149,7 +149,8 @@ struct QRBlockResult { xla::XlaOp taus; // Shape: [..., n] xla::XlaOp vs; // Shape: [..., m, n] }; -xla::StatusOr<QRBlockResult> QRBlock(xla::XlaOp a) { +xla::StatusOr<QRBlockResult> QRBlock( + xla::XlaOp a, xla::PrecisionConfigProto::Precision precision) { xla::XlaBuilder* builder = a.builder(); TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); const int num_dims = xla::ShapeUtil::Rank(a_shape); @@ -190,8 +191,12 @@ xla::StatusOr<QRBlockResult> QRBlock(xla::XlaOp a) { auto v_broadcast = xla::Reshape(v, shape); // a[:, :] -= tau * np.dot(v[:, np.newaxis], // np.dot(v[np.newaxis, :], a[:, :])) - auto vva = BatchDot(v_broadcast, a); - vva = BatchDot(v_broadcast, vva, /*transpose_x=*/true); + auto vva = + BatchDot(v_broadcast, a, /*transpose_x=*/false, /*transpose_y=*/false, + /*conjugate_x=*/false, /*conjugate_y=*/false, precision); + vva = + BatchDot(v_broadcast, vva, /*transpose_x=*/true, /*transpose_y=*/false, + /*conjugate_x=*/false, /*conjugate_y=*/false, precision); a = a - xla::Mul(tau, vva, /*broadcast_dimensions=*/batch_dim_indices); @@ -251,7 +256,8 @@ xla::StatusOr<QRBlockResult> QRBlock(xla::XlaOp a) { // vs. 
xla::StatusOr<xla::XlaOp> ComputeWYRepresentation( xla::PrimitiveType type, gtl::ArraySlice<int64> batch_dims, xla::XlaOp vs, - xla::XlaOp taus, int64 m, int64 n) { + xla::XlaOp taus, int64 m, int64 n, + xla::PrecisionConfigProto::Precision precision) { std::vector<int64> batch_dim_indices(batch_dims.size()); std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0); int64 n_index = batch_dims.size() + 1; @@ -272,9 +278,12 @@ xla::StatusOr<xla::XlaOp> ComputeWYRepresentation( auto beta = DynamicSliceInMinorDims(taus, {j}, {1}); // yv has shape [..., n, 1] - auto yv = BatchDot(y, v, /*transpose_x=*/true); + auto yv = BatchDot(y, v, /*transpose_x=*/true, /*transpose_y=*/false, + /*conjugate_x=*/false, /*conjugate_y=*/false, precision); // wyv has shape [..., m, 1] - auto wyv = BatchDot(w, yv); + auto wyv = + BatchDot(w, yv, /*transpose_x=*/false, /*transpose_y=*/false, + /*conjugate_x=*/false, /*conjugate_y=*/false, precision); auto z = xla::Mul( -beta, v + wyv, @@ -321,8 +330,9 @@ xla::StatusOr<xla::XlaOp> ComputeWYRepresentation( // return (q, a) // TODO(phawkins): consider using UT transformations (in the form I - V U V') // rather than WY transformations. 
-xla::StatusOr<QRDecompositionResult> QRDecomposition(xla::XlaOp a, - int64 block_size) { +xla::StatusOr<QRDecompositionResult> QRDecomposition( + xla::XlaOp a, int64 block_size, + xla::PrecisionConfigProto::Precision precision) { xla::XlaBuilder* builder = a.builder(); TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); const int num_dims = xla::ShapeUtil::Rank(a_shape); @@ -352,29 +362,36 @@ xla::StatusOr<QRDecompositionResult> QRDecomposition(xla::XlaOp a, int64 k = std::min(block_size, p - i); auto a_block = SliceInMinorDims(a, {i, i}, {m, i + k}); - TF_ASSIGN_OR_RETURN(auto qr_block, QRBlock(a_block)); + TF_ASSIGN_OR_RETURN(auto qr_block, QRBlock(a_block, precision)); a = UpdateSliceInMinorDims(a, qr_block.r, {i, i}); // Compute the I-WY block representation of a product of Householder // matrices. - TF_ASSIGN_OR_RETURN(auto w, - ComputeWYRepresentation(type, batch_dims, qr_block.vs, - qr_block.taus, m - i, k)); + TF_ASSIGN_OR_RETURN( + auto w, ComputeWYRepresentation(type, batch_dims, qr_block.vs, + qr_block.taus, m - i, k, precision)); auto y = qr_block.vs; // a[i:, i+k:] += np.dot(Y, np.dot(W.T, a[i:, i+k:])) auto a_panel = SliceInMinorDims(a, {i, i + k}, {m, n}); - auto a_update = BatchDot(w, a_panel, /*transpose_x=*/true); - a_update = BatchDot(y, a_update); + auto a_update = + BatchDot(w, a_panel, /*transpose_x=*/true, /*transpose_y=*/false, + /*conjugate_x=*/false, /*conjugate_y=*/false, precision); + a_update = + BatchDot(y, a_update, /*transpose_x=*/false, /*transpose_y=*/false, + /*conjugate_x=*/false, /*conjugate_y=*/false, precision); a_panel = a_panel + a_update; a = UpdateSliceInMinorDims(a, a_panel, {i, i + k}); // q[:, i:] += np.dot(np.dot(q[:, i:], W), Y.T)) auto q_panel = SliceInMinorDims(q, {0, i}, {m, m}); - auto q_update = BatchDot(q_panel, w); - q_update = - BatchDot(q_update, y, /*transpose_x=*/false, /*transpose_y=*/true); + auto q_update = + BatchDot(q_panel, w, /*transpose_x=*/false, /*transpose_y=*/false, + 
/*conjugate_x=*/false, /*conjugate_y=*/false, precision); + q_update = BatchDot(q_update, y, /*transpose_x=*/false, + /*transpose_y=*/true, /*conjugate_x=*/false, + /*conjugate_y=*/false, precision); q_panel = q_panel + q_update; q = UpdateSliceInMinorDims(q, q_panel, {0, i}); } diff --git a/tensorflow/compiler/tf2xla/lib/qr.h b/tensorflow/compiler/tf2xla/lib/qr.h index abd2316ac9..05565477b6 100644 --- a/tensorflow/compiler/tf2xla/lib/qr.h +++ b/tensorflow/compiler/tf2xla/lib/qr.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_TF2XLA_LIB_QR_H_ #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" namespace tensorflow { @@ -32,8 +33,10 @@ struct QRDecompositionResult { xla::XlaOp r; }; -xla::StatusOr<QRDecompositionResult> QRDecomposition(xla::XlaOp a, - int64 block_size = 128); +xla::StatusOr<QRDecompositionResult> QRDecomposition( + xla::XlaOp a, int64 block_size = 128, + xla::PrecisionConfigProto::Precision precision = + xla::PrecisionConfigProto::HIGHEST); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc index febb638e5e..37b2240b45 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc +++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc @@ -110,8 +110,9 @@ xla::XlaOp DiagonalBlocks(xla::XlaOp a, int64 block_size) { }); } -xla::XlaOp InvertDiagonalBlocks(xla::XlaOp diag_blocks, bool lower, - bool transpose_a, bool conjugate_a) { +xla::XlaOp InvertDiagonalBlocks( + xla::XlaOp diag_blocks, bool lower, bool transpose_a, bool conjugate_a, + xla::PrecisionConfigProto::Precision precision) { xla::XlaBuilder* builder = diag_blocks.builder(); return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> { // Input is a batch of square lower triangular square matrices. 
Its shape is @@ -215,7 +216,10 @@ xla::XlaOp InvertDiagonalBlocks(xla::XlaOp diag_blocks, bool lower, dnums.add_rhs_batch_dimensions(0); dnums.add_lhs_contracting_dimensions(2); dnums.add_rhs_contracting_dimensions(1); - auto update = -DotGeneral(input_row, body_out, dnums); + xla::PrecisionConfigProto precision_proto; + precision_proto.add_operand_precision(precision); + precision_proto.add_operand_precision(precision); + auto update = -DotGeneral(input_row, body_out, dnums, &precision_proto); body_out = DynamicUpdateSlice(body_out, update, start_indices); @@ -238,10 +242,10 @@ xla::XlaOp InvertDiagonalBlocks(xla::XlaOp diag_blocks, bool lower, }); } -xla::XlaOp SolveWithInvertedDiagonalBlocks(xla::XlaOp a, xla::XlaOp b, - xla::XlaOp inv_diag_blocks, - bool left_side, bool lower, - bool transpose_a, bool conjugate_a) { +xla::XlaOp SolveWithInvertedDiagonalBlocks( + xla::XlaOp a, xla::XlaOp b, xla::XlaOp inv_diag_blocks, bool left_side, + bool lower, bool transpose_a, bool conjugate_a, + xla::PrecisionConfigProto::Precision precision) { xla::XlaBuilder* builder = a.builder(); return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> { TF_ASSIGN_OR_RETURN(xla::Shape blocks_shape, @@ -307,9 +311,13 @@ xla::XlaOp SolveWithInvertedDiagonalBlocks(xla::XlaOp a, xla::XlaOp b, auto a_row = MaybeConjugate(SliceInMinorDims(a, start, end), conjugate_a); if (left_side) { - remainder = b_row - BatchDot(a_row, x, transpose_a, false); + remainder = b_row - BatchDot(a_row, x, transpose_a, false, + /*conjugate_x=*/false, + /*conjugate_y=*/false, precision); } else { - remainder = b_row - BatchDot(x, a_row, false, transpose_a); + remainder = b_row - BatchDot(x, a_row, false, transpose_a, + /*conjugate_x=*/false, + /*conjugate_y=*/false, precision); } } @@ -319,9 +327,13 @@ xla::XlaOp SolveWithInvertedDiagonalBlocks(xla::XlaOp a, xla::XlaOp b, xla::ConstantR0WithType(builder, xla::S32, j * block_size); std::vector<xla::XlaOp> update_starts = {start_index, zero}; if 
(left_side) { - x_update = BatchDot(inv_block, remainder, transpose_a, false); + x_update = + BatchDot(inv_block, remainder, transpose_a, false, + /*conjugate_x=*/false, /*conjugate_y=*/false, precision); } else { - x_update = BatchDot(remainder, inv_block, false, transpose_a); + x_update = + BatchDot(remainder, inv_block, false, transpose_a, + /*conjugate_x=*/false, /*conjugate_y=*/false, precision); std::swap(update_starts[0], update_starts[1]); } x = DynamicUpdateSliceInMinorDims(x, x_update, /*starts=*/update_starts); @@ -333,7 +345,8 @@ xla::XlaOp SolveWithInvertedDiagonalBlocks(xla::XlaOp a, xla::XlaOp b, xla::XlaOp TriangularSolve(xla::XlaOp a, xla::XlaOp b, bool left_side, bool lower, bool transpose_a, bool conjugate_a, - int64 block_size) { + int64 block_size, + xla::PrecisionConfigProto::Precision precision) { xla::XlaBuilder* builder = a.builder(); return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> { TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); @@ -388,12 +401,13 @@ xla::XlaOp TriangularSolve(xla::XlaOp a, xla::XlaOp b, bool left_side, auto diag_blocks = DiagonalBlocks(a, block_size); // We invert these blocks in parallel using batched matrix-vector products - auto inv_diag_blocks = - InvertDiagonalBlocks(diag_blocks, lower, transpose_a, conjugate_a); + auto inv_diag_blocks = InvertDiagonalBlocks(diag_blocks, lower, transpose_a, + conjugate_a, precision); // We now find the solution using GEMMs - auto x = SolveWithInvertedDiagonalBlocks(a, b, inv_diag_blocks, left_side, - lower, transpose_a, conjugate_a); + auto x = + SolveWithInvertedDiagonalBlocks(a, b, inv_diag_blocks, left_side, lower, + transpose_a, conjugate_a, precision); return x; }); diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.h b/tensorflow/compiler/tf2xla/lib/triangular_solve.h index 555760b7ef..ac42a48352 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve.h +++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.h @@ -17,7 +17,7 
@@ limitations under the License. #define TENSORFLOW_COMPILER_TF2XLA_LIB_TRIANGULAR_SOLVE_H_ #include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" namespace tensorflow { @@ -59,7 +59,9 @@ namespace tensorflow { // blocking is used. xla::XlaOp TriangularSolve(xla::XlaOp a, xla::XlaOp b, bool left_side, bool lower, bool transpose_a, bool conjugate_a, - int64 block_size = 128); + int64 block_size = 128, + xla::PrecisionConfigProto::Precision precision = + xla::PrecisionConfigProto::HIGHEST); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/ops/BUILD b/tensorflow/compiler/tf2xla/ops/BUILD index ace6fd1d8e..4dce0a2102 100644 --- a/tensorflow/compiler/tf2xla/ops/BUILD +++ b/tensorflow/compiler/tf2xla/ops/BUILD @@ -11,6 +11,8 @@ cc_library( srcs = ["xla_ops.cc"], deps = [ "//tensorflow/core:framework", + "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", ], alwayslink = 1, ) diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc index a59c77f5c3..2cd9ae799f 100644 --- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc +++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc @@ -13,11 +13,97 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "absl/algorithm/container.h" #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/lib/core/errors.h" namespace tensorflow { +namespace { + +// Helper shape function for operators that return an output with the same rank +// as their first input. 
+Status UnchangedRank(shape_inference::InferenceContext* c) { + if (c->RankKnown(c->input(0))) { + c->set_output(0, c->UnknownShapeOfRank(c->Rank(c->input(0)))); + } else { + c->set_output(0, c->input(0)); + } + return Status::OK(); +} + +REGISTER_OP("XlaBroadcastHelper") + .Input("lhs: T") + .Input("rhs: T") + .Input("broadcast_dims: Tindices") + .Attr("T: numbertype") + .Attr("Tindices: {int32, int64}") + .Output("lhs_output: T") + .Output("rhs_output: T") + .SetShapeFn(shape_inference::UnknownShape) + .Doc(R"doc( +Helper operator for performing XLA-style broadcasts + +Broadcasts `lhs` and `rhs` to the same rank, by adding size 1 dimensions to +whichever of `lhs` and `rhs` has the lower rank, using XLA's broadcasting rules +for binary operators. + +lhs: the LHS input tensor +rhs: the RHS input tensor +broadcast_dims: an XLA-style broadcast dimension specification +lhs_output: the broadcasted LHS tensor +rhs_output: the broadcasted RHS tensor +)doc"); + +REGISTER_OP("XlaConv") + .Input("lhs: T") + .Input("rhs: T") + .Input("window_strides: Tindices") + .Input("padding: Tindices") + .Input("lhs_dilation: Tindices") + .Input("rhs_dilation: Tindices") + .Input("feature_group_count: Tindices") + .Attr("T: numbertype") + .Attr("Tindices: {int32, int64}") + .Attr("dimension_numbers: string") + .Attr("precision_config: string") + .Output("output: T") + .SetShapeFn(UnchangedRank) + .Doc(R"doc( +Wraps the XLA ConvGeneralDilated operator, documented at + https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution +. + +lhs: the input tensor +rhs: the kernel tensor +window_strides: the inter-window strides +padding: the padding to apply at the start and end of each input dimensions +lhs_dilation: dilation to apply between input elements +rhs_dilation: dilation to apply between kernel elements +feature_group_count: number of feature groups for grouped convolution. +dimension_numbers: a serialized xla::ConvolutionDimensionNumbers proto. 
+precision_config: a serialized xla::PrecisionConfigProto proto. +)doc"); + +REGISTER_OP("XlaDot") + .Input("lhs: T") + .Input("rhs: T") + .Attr("T: numbertype") + .Attr("dimension_numbers: string") + .Attr("precision_config: string") + .Output("output: T") + .SetShapeFn(shape_inference::UnknownShape) + .Doc(R"doc( +Wraps the XLA DotGeneral operator, documented at + https://www.tensorflow.org/performance/xla/operation_semantics#dotgeneral +. + +lhs: the LHS tensor +rhs: the RHS tensor +dimension_numbers: a serialized xla::DotDimensionNumbers proto. +precision_config: a serialized xla::PrecisionConfigProto proto. +)doc"); REGISTER_OP("XlaDynamicUpdateSlice") .Input("input: T") @@ -73,6 +159,29 @@ else_branch: A function takes 'inputs' and returns a list of tensors. whose types are the same as what then_branch returns. )doc"); +REGISTER_OP("XlaPad") + .Input("input: T") + .Input("padding_value: T") + .Input("padding_low: Tindices") + .Input("padding_high: Tindices") + .Input("padding_interior: Tindices") + .Output("output: T") + .Attr("T: type") + .Attr("Tindices: {int32, int64}") + .SetShapeFn(UnchangedRank) + .Doc(R"doc( +Wraps the XLA Pad operator, documented at + https://www.tensorflow.org/performance/xla/operation_semantics#pad +. + +input: A `Tensor` of type T. +padding_value: A scalar `Tensor` of type T. +padding_low: the padding to apply at the start of each input dimension. +padding_high: the padding to apply at the end of each input dimension. +padding_interior: the padding to apply between each input element. +output: A `Tensor` of type T. +)doc"); + REGISTER_OP("XlaRecv") .Output("tensor: dtype") .Attr("dtype: type") @@ -98,17 +207,58 @@ tensor_name: A string key that identifies the channel. shape: The shape of the tensor.
)doc"); +REGISTER_OP("XlaReduce") + .Input("input: T") + .Input("init_value: T") + .Attr("T: numbertype") + .Attr("dimensions_to_reduce: list(int)") + .Attr("reducer: func") + .Output("output: T") + .SetShapeFn([](shape_inference::InferenceContext* c) { + if (c->RankKnown(c->input(0))) { + int rank = c->Rank(c->input(0)); + std::vector<int64> dimensions_to_reduce; + TF_RETURN_IF_ERROR( + c->GetAttr("dimensions_to_reduce", &dimensions_to_reduce)); + std::set<int64> dims_set(dimensions_to_reduce.begin(), + dimensions_to_reduce.end()); + auto dim_in_range = [rank](int64 dim) { + return dim >= 0 && dim < rank; + }; + if (rank < dimensions_to_reduce.size() || + dims_set.size() != dimensions_to_reduce.size() || + !absl::c_all_of(dimensions_to_reduce, dim_in_range)) { + return errors::InvalidArgument( + "Invalid dimensions_to_reduce argument to XlaReduce"); + } + c->set_output( + 0, c->UnknownShapeOfRank(rank - dimensions_to_reduce.size())); + } else { + c->set_output(0, c->input(0)); + } + return Status::OK(); + }) + .Doc(R"doc( +Wraps the XLA Reduce operator, documented at + https://www.tensorflow.org/performance/xla/operation_semantics#reduce . 
+ +input: the input tensor +init_value: a scalar representing the initial value for the reduction +reducer: a reducer function to apply +dimensions_to_reduce: dimension numbers over which to reduce +)doc"); + REGISTER_OP("XlaReduceWindow") .Input("input: T") .Input("init_value: T") + .Input("window_dimensions: Tindices") + .Input("window_strides: Tindices") + .Input("padding: Tindices") .Attr("T: numbertype") + .Attr("Tindices: {int32, int64}") .Attr("computation: func") - .Attr("window_dimensions: list(int)") - .Attr("window_strides: list(int)") - .Attr("padding_low: list(int)") - .Attr("padding_high: list(int)") .Output("output: T") - .SetShapeFn(shape_inference::UnknownShape) + .SetShapeFn(UnchangedRank) .Doc(R"doc( Wraps the XLA ReduceWindow operator, documented at https://www.tensorflow.org/performance/xla/operation_semantics#reducewindow . @@ -118,8 +268,35 @@ init_value: a scalar representing the initial value for the reduction computation: a reducer function to apply window_dimensions: the shape of the window window_strides: the inter-window strides -padding_low: the padding to apply at the start of each input dimensions -padding_high: the padding to apply at the end of each input dimension. +padding: the padding to apply at the start and end of each input dimensions +)doc"); + +REGISTER_OP("XlaSelectAndScatter") + .Input("operand: T") + .Input("window_dimensions: Tindices") + .Input("window_strides: Tindices") + .Input("padding: Tindices") + .Input("source: T") + .Input("init_value: T") + .Attr("T: numbertype") + .Attr("Tindices: {int32, int64}") + .Attr("select: func") + .Attr("scatter: func") + .Output("output: T") + .SetShapeFn(UnchangedRank) + .Doc(R"doc( +Wraps the XLA SelectAndScatter operator, documented at + https://www.tensorflow.org/performance/xla/operation_semantics#selectandscatter +. 
+ +operand: the input tensor +window_dimensions: the shape of the window +window_strides: the inter-window strides +padding: the padding to apply at the start and end of each input dimensions +source: a tensor of values to scatter +init_value: a scalar representing the initial value for the output tensor +select: a selection function to apply +scatter: a scatter function to apply )doc"); REGISTER_OP("XlaSend") @@ -179,4 +356,5 @@ body: A function that takes a list of tensors and returns another list of tensors. Both lists have the same types as specified by T. )doc"); +} // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/python/BUILD b/tensorflow/compiler/tf2xla/python/BUILD index 42b6292f79..69ca394360 100644 --- a/tensorflow/compiler/tf2xla/python/BUILD +++ b/tensorflow/compiler/tf2xla/python/BUILD @@ -28,5 +28,6 @@ py_library( srcs = ["xla.py"], deps = [ "//tensorflow/compiler/tf2xla/ops:gen_xla_ops", + "//tensorflow/compiler/xla:xla_data_proto_py", ], ) diff --git a/tensorflow/compiler/tf2xla/python/xla.py b/tensorflow/compiler/tf2xla/python/xla.py index 2fc47dffb8..3626de375e 100644 --- a/tensorflow/compiler/tf2xla/python/xla.py +++ b/tensorflow/compiler/tf2xla/python/xla.py @@ -15,11 +15,12 @@ """Experimental library that exposes XLA operations directly in TensorFlow. It is sometimes useful to be able to build HLO programs directly from -TensorFlow. This file provides Tensorflow operators that map as closely as -possible to HLO operators. +TensorFlow. This file provides Tensorflow operators that mirror the semantics of +HLO operators as closely as possible. -There is no promise of backward or forward compatibility for operators defined -in this module. +Note: There is no promise of backward or forward compatibility for operators +defined in this module. This is primarily because the underlying HLO operators +do not promise backward or forward compatibility. 
""" from __future__ import absolute_import @@ -27,11 +28,298 @@ from __future__ import division from __future__ import print_function from tensorflow.compiler.tf2xla.ops import gen_xla_ops +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import bitwise_ops +from tensorflow.python.ops import gen_math_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops + +# TODO(phawkins): provide wrappers for all XLA operators. Currently the missing +# ops include: +# infeed/outfeed (available via tf.contrib.tpu) +# collectives, e.g., cross-replica-sum (available via tf.contrib.tpu) +# conditional +# gather/scatter +# collapse + +# This file reuses builtin names (following XLA's names, so we can call things +# like xla.max), so we capture the builtin versions here. +# pylint: disable=redefined-builtin +_max = max +_min = min +_slice = slice # pylint: disable=invalid-name + +constant = constant_op.constant + +# Unary operators. + +# For most arithmetic operators there is a TensorFlow operator +# that exactly corresponds to each XLA operator. Rather than defining +# XLA-specific variants, we reuse the corresponding TensorFlow operator. +# TODO(phawkins): It would be even better to have TensorFlow operators that 1:1 +# wrap every HLO operator, because that would allow us to be confident that the +# semantics match. + + +def _unary_op(fn): + """Wrapper that restricts `fn` to have the correct signature.""" + + def unary_op_wrapper(x, name=None): + return fn(x, name=name) + + return unary_op_wrapper + + +abs = _unary_op(math_ops.abs) +# TODO(phawkins): implement clz. 
+conj = _unary_op(math_ops.conj) +cos = _unary_op(math_ops.cos) +ceil = _unary_op(math_ops.ceil) +digamma = _unary_op(math_ops.digamma) +erf = _unary_op(math_ops.erf) +erfc = _unary_op(math_ops.erfc) +# TODO(phawkins): implement erfinv +exp = _unary_op(math_ops.exp) +expm1 = _unary_op(math_ops.expm1) +floor = _unary_op(math_ops.floor) +imag = _unary_op(math_ops.imag) +is_finite = _unary_op(math_ops.is_finite) +lgamma = _unary_op(math_ops.lgamma) +log = _unary_op(math_ops.log) +log1p = _unary_op(math_ops.log1p) +logical_not = _unary_op(math_ops.logical_not) +neg = _unary_op(math_ops.neg) +real = _unary_op(math_ops.real) +# TODO(phawkins): unlike xla::Round, this rounds to even instead of zero for +# numbers halfway between two integers. +round = _unary_op(math_ops.round) +sin = _unary_op(math_ops.sin) +sign = _unary_op(math_ops.sign) +tanh = _unary_op(math_ops.tanh) + +# Binary operators + +# The main difference between TensorFlow and XLA binary ops is the broadcasting +# semantics. TensorFlow uses Numpy-style broadcasting semantics, whereas XLA +# requires an explicit specification of which dimensions to broadcast if the +# arguments have different ranks. + + +def _broadcasting_binary_op(fn): + """Wraps a binary Tensorflow operator and performs XLA-style broadcasting.""" + + def broadcasting_binary_op_wrapper(x, y, broadcast_dims=None, name=None): + """Inner wrapper function.""" + broadcast_dims = broadcast_dims or [] + broadcast_dims = ops.convert_to_tensor(broadcast_dims, dtypes.int64) + # Rather than relying on having static shape information in the TensorFlow + # graph, we use an XlaBroadcastHelper op that can compute the correct shapes + # at JIT compilation time. + x, y = gen_xla_ops.xla_broadcast_helper(x, y, broadcast_dims) + return fn(x, y, name=name) + + return broadcasting_binary_op_wrapper + + +# Map from TF signed types to TF unsigned types. 
+_SIGNED_TO_UNSIGNED_TABLE = { + dtypes.int8: dtypes.uint8, + dtypes.int16: dtypes.uint16, + dtypes.int32: dtypes.uint32, + dtypes.int64: dtypes.uint64, +} + +# Map from TF unsigned types to TF signed types. +_UNSIGNED_TO_SIGNED_TABLE = { + dtypes.uint8: dtypes.int8, + dtypes.uint16: dtypes.int16, + dtypes.uint32: dtypes.int32, + dtypes.uint64: dtypes.int64, +} + + +def _shift_right_logical_helper(x, y, name=None): + """Performs an integer right logical shift irrespective of input type.""" + assert y.dtype == x.dtype + dtype = x.dtype + signed = dtype in _SIGNED_TO_UNSIGNED_TABLE + if signed: + unsigned_dtype = _SIGNED_TO_UNSIGNED_TABLE[dtype] + x = math_ops.cast(x, unsigned_dtype) + y = math_ops.cast(y, unsigned_dtype) + output = bitwise_ops.right_shift(x, y, name=name) + if signed: + output = math_ops.cast(output, dtype) + return output + + +def _shift_right_arithmetic_helper(x, y, name=None): + """Performs an integer right arithmetic shift irrespective of input type.""" + assert y.dtype == x.dtype + dtype = x.dtype + unsigned = dtype in _UNSIGNED_TO_SIGNED_TABLE + if unsigned: + signed_dtype = _UNSIGNED_TO_SIGNED_TABLE[dtype] + x = math_ops.cast(x, signed_dtype) + y = math_ops.cast(y, signed_dtype) + output = bitwise_ops.right_shift(x, y, name=name) + if unsigned: + output = math_ops.cast(output, dtype) + return output + + +add = _broadcasting_binary_op(math_ops.add) +sub = _broadcasting_binary_op(math_ops.sub) +mul = _broadcasting_binary_op(math_ops.mul) +div = _broadcasting_binary_op(math_ops.div) +rem = _broadcasting_binary_op(gen_math_ops.mod) +max = _broadcasting_binary_op(math_ops.maximum) +min = _broadcasting_binary_op(math_ops.minimum) +atan2 = _broadcasting_binary_op(math_ops.atan2) +complex = _broadcasting_binary_op(math_ops.complex) +logical_and = _broadcasting_binary_op(math_ops.logical_and) +logical_or = _broadcasting_binary_op(math_ops.logical_or) +logical_xor = _broadcasting_binary_op(math_ops.logical_xor) +eq = 
_broadcasting_binary_op(math_ops.equal) +ne = _broadcasting_binary_op(math_ops.not_equal) +ge = _broadcasting_binary_op(math_ops.greater_equal) +gt = _broadcasting_binary_op(math_ops.greater) +le = _broadcasting_binary_op(math_ops.less_equal) +lt = _broadcasting_binary_op(math_ops.less) +pow = _broadcasting_binary_op(math_ops.pow) +shift_left = _broadcasting_binary_op(bitwise_ops.left_shift) +shift_right_logical = _broadcasting_binary_op(_shift_right_logical_helper) +shift_right_arithmetic = _broadcasting_binary_op(_shift_right_arithmetic_helper) + + +def _binary_op(fn): + """Wrapper that restricts `fn` to have the correct signature.""" + + def binary_op_wrapper(x, y, name=None): + return fn(x, y, name=name) + + return binary_op_wrapper + + +transpose = _binary_op(array_ops.transpose) +rev = _binary_op(array_ops.reverse) + +bitcast_convert_type = array_ops.bitcast + + +def broadcast(x, dims, name=None): + x = ops.convert_to_tensor(x) + shape = array_ops.concat( + [constant_op.constant(dims), + array_ops.shape(x)], axis=0) + return array_ops.broadcast_to(x, shape, name=name) + + +def clamp(a, x, b, name=None): + return min(max(a, x, name=name), b, name=name) + + +concatenate = array_ops.concat + + +def conv(lhs, + rhs, + window_strides, + padding, + lhs_dilation, + rhs_dilation, + dimension_numbers, + feature_group_count=1, + precision_config=None, + name=None): + """Wraps the XLA ConvGeneralDilated operator. + + ConvGeneralDilated is the most general form of XLA convolution and is + documented at + https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution + + Args: + lhs: the input tensor + rhs: the kernel tensor + window_strides: the inter-window strides + padding: the padding to apply at the start and end of each input dimensions + lhs_dilation: dilation to apply between input elements + rhs_dilation: dilation to apply between kernel elements + dimension_numbers: a `ConvolutionDimensionNumbers` proto. 
+ feature_group_count: number of feature groups for grouped convolution. + precision_config: a `PrecisionConfigProto` proto. + name: an optional name for the operator + + Returns: + A tensor representing the output of the convolution. + """ + precision_config_proto = "" + if precision_config: + precision_config_proto = precision_config.SerializeToString() + return gen_xla_ops.xla_conv( + lhs, + rhs, + window_strides=window_strides, + padding=padding, + lhs_dilation=lhs_dilation, + rhs_dilation=rhs_dilation, + feature_group_count=feature_group_count, + dimension_numbers=dimension_numbers.SerializeToString(), + precision_config=precision_config_proto, + name=name) + + +convert_element_type = math_ops.cast + + +def dot(lhs, rhs, name=None): + return math_ops.tensordot(lhs, rhs, axes=1, name=name) + + +def dot_general(lhs, rhs, dimension_numbers, precision_config=None, name=None): + precision_config_proto = "" + if precision_config: + precision_config_proto = precision_config.SerializeToString() + return gen_xla_ops.xla_dot( + lhs, + rhs, + dimension_numbers=dimension_numbers.SerializeToString(), + precision_config=precision_config_proto, + name=name) + + +def dynamic_slice(x, starts, sizes, name=None): + # TODO(phawkins): the Slice operator lowers to DynamicSlice if `starts` is not + # a compile-time constant. This doesn't exactly mimic the semantics of dynamic + # slice if the slice is out of bounds. + return array_ops.slice(x, starts, sizes, name=name) -# TODO(phawkins): provide wrappers for all XLA operators. dynamic_update_slice = gen_xla_ops.xla_dynamic_update_slice +# TODO(phawkins): generalize tf.pad to support interior padding, and then remove +# the XLA-specific pad operator. 
+pad = gen_xla_ops.xla_pad + + +def random_normal(mu, sigma, dims, name=None): + mu = ops.convert_to_tensor(mu) + return random_ops.random_normal( + dims, mean=mu, stddev=sigma, dtype=mu.dtype, name=name) + + +def random_uniform(minval, maxval, dims, name=None): + minval = ops.convert_to_tensor(minval) + return random_ops.random_uniform( + dims, minval, maxval, dtype=minval.dtype, name=name) + + +recv = gen_xla_ops.xla_recv +reduce = gen_xla_ops.xla_reduce + def reduce_window(operand, init, @@ -61,22 +349,38 @@ def reduce_window(operand, """ window_strides = window_strides or [1] * len(window_dimensions) padding = padding or [(0, 0)] * len(window_dimensions) - padding_low = [x for (x, _) in padding] - padding_high = [y for (_, y) in padding] return gen_xla_ops.xla_reduce_window( - operand, - init, - reducer, - window_dimensions, - window_strides, - padding_low, - padding_high, + input=operand, + init_value=init, + window_dimensions=window_dimensions, + window_strides=window_strides, + padding=padding, + computation=reducer, name=name) -recv = gen_xla_ops.xla_recv +def reshape(x, new_sizes, dimensions=None, name=None): + if dimensions is not None: + x = array_ops.transpose(x, dimensions) + x = array_ops.reshape(x, new_sizes, name=name) + return x + + +def select(condition, x, y, name=None): + return array_ops.where(condition, x, y, name) + + +select_and_scatter = gen_xla_ops.xla_select_and_scatter send = gen_xla_ops.xla_send -sort = gen_xla_ops.xla_sort +def slice(x, start_dims, limit_dims, strides): + spec = [ + _slice(start, limit, stride) + for (start, limit, stride) in zip(start_dims, limit_dims, strides) + ] + return x[tuple(spec)] + + +sort = gen_xla_ops.xla_sort while_loop = gen_xla_ops.xla_while diff --git a/tensorflow/compiler/tf2xla/resource_operation_table.cc b/tensorflow/compiler/tf2xla/resource_operation_table.cc new file mode 100644 index 0000000000..32ba6df2e6 --- /dev/null +++ b/tensorflow/compiler/tf2xla/resource_operation_table.cc @@ -0,0 +1,130 @@ 
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/resource_operation_table.h" +#include "absl/algorithm/container.h" +#include "tensorflow/core/lib/gtl/flatmap.h" + +namespace tensorflow { +/*static*/ StringPiece XlaResourceOpInfo::XlaResourceOpKindToString( + XlaResourceOpKind op_kind) { + switch (op_kind) { + case XlaResourceOpKind::kRead: + return "Read"; + case XlaResourceOpKind::kWrite: + return "Write"; + case XlaResourceOpKind::kReadWrite: + return "Modify"; + } +} + +static gtl::FlatMap<StringPiece, XlaResourceOpInfo>* CreateResourceOpInfoMap() { + gtl::FlatMap<StringPiece, XlaResourceOpInfo>* result = + new gtl::FlatMap<StringPiece, XlaResourceOpInfo>; + + auto add = [&](StringPiece op, XlaResourceOpKind op_kind, + XlaResourceKind resource_kind) { + auto insert_result = + result->insert({op, XlaResourceOpInfo(op_kind, resource_kind)}); + CHECK(insert_result.second); + }; + + auto kRead = XlaResourceOpKind::kRead; + auto kWrite = XlaResourceOpKind::kWrite; + auto kReadWrite = XlaResourceOpKind::kReadWrite; + + auto kVariable = XlaResourceKind::kVariable; + auto kStack = XlaResourceKind::kStack; + auto kTensorArray = XlaResourceKind::kTensorArray; + + // clang-format off + add("AssignAddVariableOp" , kReadWrite, kVariable); + add("AssignSubVariableOp" , kReadWrite, kVariable); + add("AssignVariableOp" 
, kWrite, kVariable); + add("ReadVariableOp" , kRead, kVariable); + add("ResourceApplyAdaMax" , kReadWrite, kVariable); + add("ResourceApplyAdadelta" , kReadWrite, kVariable); + add("ResourceApplyAdagrad" , kReadWrite, kVariable); + add("ResourceApplyAdagradDA" , kReadWrite, kVariable); + add("ResourceApplyAdam" , kReadWrite, kVariable); + add("ResourceApplyAddSign" , kReadWrite, kVariable); + add("ResourceApplyCenteredRMSProp" , kReadWrite, kVariable); + add("ResourceApplyFtrl" , kReadWrite, kVariable); + add("ResourceApplyFtrlV2" , kReadWrite, kVariable); + add("ResourceApplyGradientDescent" , kReadWrite, kVariable); + add("ResourceApplyMomentum" , kReadWrite, kVariable); + add("ResourceApplyPowerSign" , kReadWrite, kVariable); + add("ResourceApplyProximalAdagrad" , kReadWrite, kVariable); + add("ResourceApplyProximalGradientDescent" , kReadWrite, kVariable); + add("ResourceApplyRMSProp" , kReadWrite, kVariable); + add("ResourceGather" , kRead, kVariable); + add("ResourceScatterAdd" , kReadWrite, kVariable); + add("ResourceScatterDiv" , kReadWrite, kVariable); + add("ResourceScatterMax" , kReadWrite, kVariable); + add("ResourceScatterMin" , kReadWrite, kVariable); + add("ResourceScatterMul" , kReadWrite, kVariable); + add("ResourceScatterNdAdd" , kReadWrite, kVariable); + add("ResourceScatterNdUpdate" , kReadWrite, kVariable); + add("ResourceScatterSub" , kReadWrite, kVariable); + add("ResourceScatterUpdate" , kReadWrite, kVariable); + add("ResourceStridedSliceAssign" , kReadWrite, kVariable); + add("VarIsInitializedOp" , kRead, kVariable); + add("VariableShape" , kRead, kVariable); + + add("StackV2" , kWrite, kStack); + add("StackCloseV2" , kRead, kStack); + add("StackPopV2" , kReadWrite, kStack); + add("StackPushV2" , kReadWrite, kStack); + + add("TensorArrayV3" , kWrite, kTensorArray); + add("TensorArrayConcatV3" , kRead, kTensorArray); + add("TensorArrayGatherV3" , kRead, kTensorArray); + add("TensorArrayScatterV3" , kWrite, kTensorArray); + 
add("TensorArrayGradV3" , kRead, kTensorArray); + add("TensorArrayCloseV3" , kRead, kTensorArray); + add("TensorArrayReadV3" , kRead, kTensorArray); + add("TensorArraySizeV3" , kRead, kTensorArray); + add("TensorArraySplitV3" , kWrite, kTensorArray); + add("TensorArrayWriteV3" , kWrite, kTensorArray); + // clang-format on + + return result; +} + +static const gtl::FlatMap<StringPiece, XlaResourceOpInfo>& +GetStaticResourceOpInfoMap() { + static gtl::FlatMap<StringPiece, XlaResourceOpInfo>* op_info_map = + CreateResourceOpInfoMap(); + return *op_info_map; +} + +const XlaResourceOpInfo* GetResourceOpInfoForOp(StringPiece op) { + const gtl::FlatMap<StringPiece, XlaResourceOpInfo>& op_infos = + GetStaticResourceOpInfoMap(); + auto it = op_infos.find(op); + return it == op_infos.end() ? nullptr : &it->second; +} + +namespace resource_op_table_internal { +std::vector<StringPiece> GetKnownResourceOps() { + std::vector<StringPiece> result; + for (const auto& p : GetStaticResourceOpInfoMap()) { + result.push_back(p.first); + } + absl::c_sort(result); + return result; +} +} // namespace resource_op_table_internal +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/resource_operation_table.h b/tensorflow/compiler/tf2xla/resource_operation_table.h new file mode 100644 index 0000000000..7f627a64c6 --- /dev/null +++ b/tensorflow/compiler/tf2xla/resource_operation_table.h @@ -0,0 +1,71 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_RESOURCE_OPERATION_TABLE_H_ +#define TENSORFLOW_COMPILER_TF2XLA_RESOURCE_OPERATION_TABLE_H_ + +#include <string> +#include <vector> + +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/logging.h" + +// Exposes information about the resource operations supported by tf2xla in a +// structured form. + +namespace tensorflow { +enum class XlaResourceOpKind { + kRead, // Only reads from resources. + kWrite, // Only writes to resources. + kReadWrite // Reads from and writes to resources. +}; + +enum class XlaResourceKind { + kVariable, // Operates on resource variables. + kStack, // Operates on stacks. + kTensorArray // Operates on tensor arrays. +}; + +class XlaResourceOpInfo { + public: + explicit XlaResourceOpInfo(XlaResourceOpKind op_kind, + XlaResourceKind resource_kind) + : op_kind_(op_kind), resource_kind_(resource_kind) {} + + XlaResourceOpKind kind() const { return op_kind_; } + XlaResourceKind resource_kind() const { return resource_kind_; } + + static StringPiece XlaResourceOpKindToString(XlaResourceOpKind op_kind); + + private: + XlaResourceOpKind op_kind_; + XlaResourceKind resource_kind_; +}; + +// Returns a XlaResourceOpInfo describing `op` if it is a resource operation +// supported by tf2xla, otherwise returns null (i.e. if this returns null then +// `op` is either not a resource operation or is unsupported by XLA). +const XlaResourceOpInfo* GetResourceOpInfoForOp(StringPiece op); + +namespace resource_op_table_internal { +// NB! Implementation detail exposed for unit testing, do not use. +// +// Returns the set of resource operations known by this module. 
+std::vector<StringPiece> GetKnownResourceOps(); +} // namespace resource_op_table_internal + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_RESOURCE_OPERATION_TABLE_H_ diff --git a/tensorflow/compiler/tf2xla/resource_operation_table_test.cc b/tensorflow/compiler/tf2xla/resource_operation_table_test.cc new file mode 100644 index 0000000000..0343f80de9 --- /dev/null +++ b/tensorflow/compiler/tf2xla/resource_operation_table_test.cc @@ -0,0 +1,66 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/resource_operation_table.h" + +#include "absl/algorithm/container.h" +#include "absl/strings/str_join.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { +bool IsResourceArgDef(const OpDef::ArgDef& arg_def) { + return arg_def.type() == DT_RESOURCE; +} + +bool HasResourceInputOrOutput(const OpDef& op_def) { + return absl::c_any_of(op_def.input_arg(), IsResourceArgDef) || + absl::c_any_of(op_def.output_arg(), IsResourceArgDef); +} + +TEST(ResourceOperationTableTest, HaveAllResourceOps) { + gtl::FlatMap<string, bool> known_resource_ops; + for (StringPiece known_resource_op : + resource_op_table_internal::GetKnownResourceOps()) { + ASSERT_TRUE( + known_resource_ops.insert({string(known_resource_op), false}).second); + } + + std::vector<string> xla_op_names = XlaOpRegistry::GetAllRegisteredOps(); + for (const string& xla_op_name : xla_op_names) { + const OpDef* op_def; + TF_ASSERT_OK(OpRegistry::Global()->LookUpOpDef(xla_op_name, &op_def)); + if (HasResourceInputOrOutput(*op_def)) { + EXPECT_EQ(known_resource_ops.count(xla_op_name), 1) + << "Unknown resource op " << xla_op_name; + known_resource_ops[xla_op_name] = true; + } + } + + std::vector<string> unnecessary_resource_ops; + for (const auto& pair : known_resource_ops) { + if (!pair.second) { + unnecessary_resource_ops.push_back(pair.first); + } + } + + EXPECT_TRUE(unnecessary_resource_ops.empty()) + << "Stale resource ops:\n" + << absl::StrJoin(unnecessary_resource_ops, "\n"); +} +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/sharding_util.cc b/tensorflow/compiler/tf2xla/sharding_util.cc index 66835e69b2..2d7eb8b915 100644 --- a/tensorflow/compiler/tf2xla/sharding_util.cc +++ 
b/tensorflow/compiler/tf2xla/sharding_util.cc @@ -14,9 +14,9 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/tf2xla/sharding_util.h" +#include "absl/strings/match.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/util/device_name_utils.h" @@ -65,8 +65,8 @@ xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice( if (explicit_sharding.has_value()) { return explicit_sharding; } else if (!parsed_device.has_type || !parsed_device.has_id || - !str_util::StrContains(parsed_device.type, - kDeviceSuffixReplicatedCore)) { + !absl::StrContains(parsed_device.type, + kDeviceSuffixReplicatedCore)) { return absl::optional<xla::OpSharding>(); } else { const int core = parsed_device.id; diff --git a/tensorflow/compiler/tf2xla/str_util.cc b/tensorflow/compiler/tf2xla/str_util.cc deleted file mode 100644 index 2b0834fe7b..0000000000 --- a/tensorflow/compiler/tf2xla/str_util.cc +++ /dev/null @@ -1,44 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/compiler/tf2xla/str_util.h" - -#include <string> -#include <utility> -#include <vector> - -namespace tensorflow { -namespace str_util { - -static void ReplaceAll(string* text, StringPiece from, StringPiece to) { - size_t pos = 0; - while ((pos = text->find(from.data(), pos, from.size())) != string::npos) { - text->replace(pos, from.size(), to.data(), to.size()); - pos += to.size(); - if (from.empty()) { - pos++; // Match at the beginning of the text and after every byte - } - } -} - -void ReplaceAllPairs(string* text, - const std::vector<std::pair<string, string>>& replace) { - for (const std::pair<string, string>& from_to : replace) { - ReplaceAll(text, from_to.first, from_to.second); - } -} - -} // namespace str_util -} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/str_util.h b/tensorflow/compiler/tf2xla/str_util.h deleted file mode 100644 index 51f25009d7..0000000000 --- a/tensorflow/compiler/tf2xla/str_util.h +++ /dev/null @@ -1,42 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// String utilities that are esoteric enough that they don't belong in -// third_party/tensorflow/core/lib/strings/str_util.h, but are still generally -// useful under xla. 
- -#ifndef TENSORFLOW_COMPILER_TF2XLA_STR_UTIL_H_ -#define TENSORFLOW_COMPILER_TF2XLA_STR_UTIL_H_ - -#include <string> -#include <utility> -#include <vector> - -#include "tensorflow/core/lib/core/stringpiece.h" - -namespace tensorflow { -namespace str_util { - -// Replace all non-overlapping occurrences of the given (from,to) pairs in-place -// in text. If from is empty, it matches at the beginning of the text and after -// every byte. Each (from,to) replacement pair is processed in the order it is -// given. -void ReplaceAllPairs(string* text, - const std::vector<std::pair<string, string>>& replace); - -} // namespace str_util -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_TF2XLA_STR_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/str_util_test.cc b/tensorflow/compiler/tf2xla/str_util_test.cc deleted file mode 100644 index 8817f6902a..0000000000 --- a/tensorflow/compiler/tf2xla/str_util_test.cc +++ /dev/null @@ -1,60 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/compiler/tf2xla/str_util.h" - -#include <string> -#include <utility> -#include <vector> - -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/platform/test.h" - -namespace tensorflow { -namespace str_util { - -class ReplaceAllPairsTest : public ::testing::Test { - protected: - void ExpectReplaceAllPairs( - string text, const std::vector<std::pair<string, string>>& replace, - StringPiece want) { - ReplaceAllPairs(&text, replace); - EXPECT_EQ(text, want); - } -}; - -TEST_F(ReplaceAllPairsTest, Simple) { - ExpectReplaceAllPairs("", {}, ""); - ExpectReplaceAllPairs("", {{"", ""}}, ""); - ExpectReplaceAllPairs("", {{"", "X"}}, "X"); - ExpectReplaceAllPairs("", {{"", "XYZ"}}, "XYZ"); - ExpectReplaceAllPairs("", {{"", "XYZ"}, {"", "_"}}, "_X_Y_Z_"); - ExpectReplaceAllPairs("", {{"", "XYZ"}, {"", "_"}, {"_Y_", "a"}}, "_XaZ_"); - ExpectReplaceAllPairs("banana", {}, "banana"); - ExpectReplaceAllPairs("banana", {{"", ""}}, "banana"); - ExpectReplaceAllPairs("banana", {{"", "_"}}, "_b_a_n_a_n_a_"); - ExpectReplaceAllPairs("banana", {{"", "__"}}, "__b__a__n__a__n__a__"); - ExpectReplaceAllPairs("banana", {{"a", "a"}}, "banana"); - ExpectReplaceAllPairs("banana", {{"a", ""}}, "bnn"); - ExpectReplaceAllPairs("banana", {{"a", "X"}}, "bXnXnX"); - ExpectReplaceAllPairs("banana", {{"a", "XX"}}, "bXXnXXnXX"); - ExpectReplaceAllPairs("banana", {{"a", "XX"}, {"XnX", "z"}}, "bXzzX"); - ExpectReplaceAllPairs("a{{foo}}b{{bar}}c{{foo}}", - {{"{{foo}}", "0"}, {"{{bar}}", "123456789"}}, - "a0b123456789c0"); -} - -} // namespace str_util -} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/tf2xla.cc b/tensorflow/compiler/tf2xla/tf2xla.cc index 48568c825b..f34af2d67d 100644 --- a/tensorflow/compiler/tf2xla/tf2xla.cc +++ b/tensorflow/compiler/tf2xla/tf2xla.cc @@ -22,6 +22,7 @@ limitations under the License. 
#include <utility> #include <vector> +#include "absl/strings/str_join.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" @@ -40,7 +41,6 @@ limitations under the License. #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -197,8 +197,8 @@ Status RewriteAndPruneGraph( if (!missing_feeds.empty() || !missing_fetches.empty()) { return errors::Aborted( "Post graph-pruning", - ", missing feeds: ", str_util::Join(missing_feeds, ", "), - ", missing fetches: ", str_util::Join(missing_fetches, ", ")); + ", missing feeds: ", absl::StrJoin(missing_feeds, ", "), + ", missing fetches: ", absl::StrJoin(missing_fetches, ", ")); } return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/tf2xla_supported_ops.cc b/tensorflow/compiler/tf2xla/tf2xla_supported_ops.cc index 7aca889a26..567d212b5e 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_supported_ops.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_supported_ops.cc @@ -20,11 +20,11 @@ limitations under the License. 
#include <string> #include <vector> +#include "absl/strings/str_join.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/framework/kernel_def.pb.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/util/command_line_flags.h" @@ -54,10 +54,10 @@ void PrintSupportedOps(const string& device, const string& regen_run) { } std::sort(types.begin(), types.end()); constraints.push_back("`" + constraint.name() + "={" + - str_util::Join(types, ",") + "}`"); + absl::StrJoin(types, ",") + "}`"); } std::cout << "`" << kdef->op() << "` | " - << str_util::Join(constraints, "<br>") << std::endl; + << absl::StrJoin(constraints, "<br>") << std::endl; } std::cout << "\nTo regenerate this table, run:\n\n```shell\n" @@ -76,7 +76,7 @@ void SupportedOpsMain(int argc, char** argv, const char* regen_run) { {"device", &device, "Name of the compilation device for which to print supported ops, " "one of: " + - str_util::Join(device_names, ",")}, + absl::StrJoin(device_names, ",")}, }; string usage = Flags::Usage(argv[0], flag_list); bool parsed_flags_ok = Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/tf2xla/tf2xla_util_test.cc b/tensorflow/compiler/tf2xla/tf2xla_util_test.cc index ae51446204..2b1f724dc7 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util_test.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/tf2xla_util.h" +#include "absl/strings/match.h" #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/data_flow_ops.h" #include "tensorflow/cc/ops/function_ops.h" @@ -25,16 +26,15 @@ limitations under the License. 
#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { namespace { -void ExpectErrorContains(const Status& status, StringPiece str) { +void ExpectErrorContains(const Status& status, absl::string_view str) { EXPECT_NE(Status::OK(), status); - EXPECT_TRUE(str_util::StrContains(status.error_message(), str)) + EXPECT_TRUE(absl::StrContains(status.error_message(), str)) << "expected error: " << status.error_message() << " to contain: " << str; } diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index 7227df9649..6e5a0198f6 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "absl/strings/match.h" #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/data_flow_ops.h" #include "tensorflow/cc/ops/function_ops.h" @@ -38,7 +39,6 @@ limitations under the License. 
#include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/public/version.h" @@ -309,10 +309,10 @@ TEST_F(XlaCompilerTest, HasSaneErrorOnNonCompileTimeConstantInputToReshape) { std::move(graph), args, &result); EXPECT_FALSE(status.ok()); EXPECT_TRUE( - str_util::StrContains(status.error_message(), "depends on a parameter")) + absl::StrContains(status.error_message(), "depends on a parameter")) << status.error_message(); EXPECT_TRUE( - str_util::StrContains(status.error_message(), "[[{{node C}} = Reshape")) + absl::StrContains(status.error_message(), "[[{{node C}} = Reshape")) << status.error_message(); } @@ -727,8 +727,7 @@ TEST_F(XlaCompilerTest, UndefinedFunctionFails) { compiler.CompileFunction(XlaCompiler::CompileOptions(), name_attr, /*args=*/{}, &result); EXPECT_FALSE(status.ok()); - EXPECT_TRUE(str_util::StrContains(StringPiece(status.error_message()), - "is not defined.")) + EXPECT_TRUE(absl::StrContains(status.error_message(), "is not defined.")) << status.error_message(); } @@ -807,12 +806,10 @@ TEST_F(XlaCompilerTest, LocalFunctionWithWrongArgumentsFail) { ASSERT_FALSE(status.ok()); // Flib lookup failure. - EXPECT_TRUE(str_util::StrContains(StringPiece(status.error_message()), - "is not defined.")) + EXPECT_TRUE(absl::StrContains(status.error_message(), "is not defined.")) << status.error_message(); // Local flib lookup failure. 
- EXPECT_TRUE(str_util::StrContains(StringPiece(status.error_message()), - "Attr T is not found")) + EXPECT_TRUE(absl::StrContains(status.error_message(), "Attr T is not found")) << status.error_message(); } @@ -1078,9 +1075,9 @@ TEST_F(XlaCompilerTest, FunctionWithInvalidOp) { status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "fill", std::move(graph), args, &result); ASSERT_FALSE(status.ok()); - EXPECT_TRUE(str_util::StrContains(status.error_message(), "InvalidOp")) + EXPECT_TRUE(absl::StrContains(status.error_message(), "InvalidOp")) << status.error_message(); - EXPECT_TRUE(str_util::StrContains(status.error_message(), "{{node fill_fn}}")) + EXPECT_TRUE(absl::StrContains(status.error_message(), "{{node fill_fn}}")) << status.error_message(); } @@ -1103,10 +1100,10 @@ TEST_F(XlaCompilerTest, NodeWithInvalidDataType) { status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "invalid_type", std::move(graph), args, &result); ASSERT_FALSE(status.ok()); - EXPECT_TRUE(str_util::StrContains(status.error_message(), - "is not in the list of allowed values")) + EXPECT_TRUE(absl::StrContains(status.error_message(), + "is not in the list of allowed values")) << status.error_message(); - EXPECT_TRUE(str_util::StrContains(status.error_message(), "{{node Shape}}")) + EXPECT_TRUE(absl::StrContains(status.error_message(), "{{node Shape}}")) << status.error_message(); } @@ -1130,9 +1127,9 @@ TEST_F(XlaCompilerTest, SingleOpWithoutInputs) { std::move(graph_copy), args, &result); ASSERT_FALSE(status.ok()); EXPECT_TRUE( - str_util::StrContains(status.error_message(), - "The following nodes are unreachable " - "from the source in the graph: {{node NoOp}}")) + absl::StrContains(status.error_message(), + "The following nodes are unreachable " + "from the source in the graph: {{node NoOp}}")) << status.error_message(); } diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index 31a41f8719..9e8f5f2a1a 100644 --- 
a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -99,6 +99,25 @@ Status XlaOpKernelContext::ConstantInput(int index, index, context_->input(index).shape().dim_sizes(), constant_literal); } +static xla::StatusOr<int> InputIndex(XlaOpKernelContext* context, + StringPiece name) { + int start, stop; + TF_RETURN_IF_ERROR(context->op_kernel().InputRange(name, &start, &stop)); + if (stop != start + 1) { + return errors::InvalidArgument("OpKernel used list-valued input name '", + name, + "' when single-valued input was " + "expected"); + } + return start; +} + +Status XlaOpKernelContext::ConstantInput(StringPiece name, + xla::Literal* constant_literal) { + TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name)); + return ConstantInput(index, constant_literal); +} + Status XlaOpKernelContext::ConstantInputReshaped( int index, gtl::ArraySlice<int64> new_dims, xla::Literal* constant_literal) { @@ -246,6 +265,12 @@ Status XlaOpKernelContext::ConstantInputAsIntScalar(int index, int64* out) { return LiteralToInt64Scalar(literal, out); } +Status XlaOpKernelContext::ConstantInputAsIntScalar(StringPiece name, + int64* out) { + TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name)); + return ConstantInputAsIntScalar(index, out); +} + Status XlaOpKernelContext::ConstantInputAsFloatScalar(int index, double* out) { xla::Literal literal; TF_RETURN_IF_ERROR(ConstantInput(index, &literal)); @@ -280,6 +305,12 @@ Status XlaOpKernelContext::ConstantInputAsIntVector(int index, return LiteralToInt64Vector(literal, out); } +Status XlaOpKernelContext::ConstantInputAsIntVector(StringPiece name, + std::vector<int64>* out) { + TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name)); + return ConstantInputAsIntVector(index, out); +} + Status XlaOpKernelContext::ConstantInputReshapedToIntVector( int index, std::vector<int64>* out) { xla::Literal literal; @@ -313,6 +344,12 @@ Status XlaOpKernelContext::ConstantInputAsInt64Literal(int index, } } +Status 
XlaOpKernelContext::ConstantInputAsInt64Literal(StringPiece name, + xla::Literal* out) { + TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name)); + return ConstantInputAsInt64Literal(index, out); +} + // TODO(phawkins): validate that the dimensions form a valid shape, fail // gracefully if they do not. Status XlaOpKernelContext::ConstantInputAsShape(int index, TensorShape* shape) { diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h index 3f21a2bf41..3e26ba4f01 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.h +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h @@ -106,6 +106,7 @@ class XlaOpKernelContext { // expression cannot be evaluated, e.g., because it depends on unbound // parameters, returns a non-OK status. Status ConstantInput(int index, xla::Literal* constant_literal); + Status ConstantInput(StringPiece name, xla::Literal* constant_literal); // Evaluates input `index`, reshapes it to `new_shape` if new_shape != // InputShape(index), and stores it in `*constant_literal`. If the input @@ -117,12 +118,14 @@ class XlaOpKernelContext { // Converts a constant scalar int32 or int64 tensor into an int64. Status ConstantInputAsIntScalar(int index, int64* out); + Status ConstantInputAsIntScalar(StringPiece name, int64* out); // Converts a constant scalar float32 or float64 tensor into a float64. Status ConstantInputAsFloatScalar(int index, double* out); // Converts a constant 1D int32 or int64 tensor into a vector of int64s. Status ConstantInputAsIntVector(int index, std::vector<int64>* out); + Status ConstantInputAsIntVector(StringPiece name, std::vector<int64>* out); // Reshapes and converts a constant int32 or int64 tensor into a vector of // int64s. @@ -130,6 +133,7 @@ class XlaOpKernelContext { // Converts a constant int32 or int64 Tensor into an xla int64 Literal. 
Status ConstantInputAsInt64Literal(int index, xla::Literal* out); + Status ConstantInputAsInt64Literal(StringPiece name, xla::Literal* out); // Converts a constant 1D int32 or int64 tensor into a TensorShape. Status ConstantInputAsShape(int index, TensorShape* shape); diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc index 46785bc1f0..e25c7e8c9e 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.cc +++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc @@ -325,6 +325,17 @@ std::vector<const KernelDef*> XlaOpRegistry::DeviceKernels( return kernels; } +/*static*/ std::vector<string> XlaOpRegistry::GetAllRegisteredOps() { + std::vector<string> ops; + XlaOpRegistry& registry = Instance(); + mutex_lock lock(registry.mutex_); + for (const auto& pair : registry.ops_) { + ops.push_back(pair.first); + } + std::sort(ops.begin(), ops.end()); + return ops; +} + /* static */ const std::unordered_set<string>* XlaOpRegistry::CompileTimeConstantInputs(const string& op) { XlaOpRegistry& registry = Instance(); diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h index fc14834ca6..6ce0e2580b 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.h +++ b/tensorflow/compiler/tf2xla/xla_op_registry.h @@ -128,6 +128,9 @@ class XlaOpRegistry { const string& compilation_device_name, bool include_compilation_only_kernels); + // Returns all operations for which there are XLA kernels on any device. + static std::vector<string> GetAllRegisteredOps(); + // Returns the set of compile-time constant inputs to 'op'. Returns nullptr // if the op is not registered. 
static const std::unordered_set<string>* CompileTimeConstantInputs( diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index 1a8fa627a0..26bd1ac4f7 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -113,6 +113,7 @@ cc_library( ":statusor", ":types", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -173,6 +174,7 @@ cc_library( "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/strings", ], ) @@ -237,11 +239,11 @@ cc_library( ":types", ":util", ":xla_data_proto", - "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:regexp_internal", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", ], ) @@ -259,6 +261,7 @@ tf_cc_test( ":xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:test_main", + "@com_google_absl//absl/strings", ], ) @@ -301,6 +304,7 @@ cc_library( ":xla_data_proto", "//tensorflow/core:lib", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -320,6 +324,7 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -341,6 +346,7 @@ cc_library( ":xla_data_proto", "//tensorflow/core:lib", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -359,6 +365,7 @@ cc_library( ":literal_util", ":util", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -370,6 +377,7 @@ cc_library( deps = [ ":util", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -379,8 +387,8 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":types", - "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/strings", ], ) @@ -391,6 +399,7 @@ cc_library( ":status", 
":types", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -413,6 +422,7 @@ cc_library( ":types", "//tensorflow/core:lib", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -457,6 +467,7 @@ cc_library( ":array2d", ":types", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -510,6 +521,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:regexp_internal", "//tensorflow/core:test", + "@com_google_absl//absl/strings", ], ) @@ -529,6 +541,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -559,6 +572,7 @@ cc_library( ":types", ":xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -629,6 +643,7 @@ cc_library( ":types", ":xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/compiler/xla/array.h b/tensorflow/compiler/xla/array.h index 2d5d078aa7..c8e483712e 100644 --- a/tensorflow/compiler/xla/array.h +++ b/tensorflow/compiler/xla/array.h @@ -27,12 +27,12 @@ limitations under the License. 
#include <type_traits> #include <vector> +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/bits.h" #include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -507,9 +507,7 @@ class Array { } } - pieces.push_back( - tensorflow::strings::AlphaNum(values_[calculate_index(index)]) - .data()); + pieces.push_back(absl::StrCat(values_[calculate_index(index)])); // Emit comma if it isn't the last element if (index.back() != sizes_.back() - 1) { @@ -527,7 +525,7 @@ class Array { } } } while (next_index(&index)); - return tensorflow::str_util::Join(pieces, ""); + return absl::StrJoin(pieces, ""); } private: diff --git a/tensorflow/compiler/xla/array2d.h b/tensorflow/compiler/xla/array2d.h index 340f94fab7..782c966b4c 100644 --- a/tensorflow/compiler/xla/array2d.h +++ b/tensorflow/compiler/xla/array2d.h @@ -25,11 +25,10 @@ limitations under the License. #include <vector> #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/array.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/bits.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/compiler/xla/array4d.h b/tensorflow/compiler/xla/array4d.h index a75fffc605..14e7bf1814 100644 --- a/tensorflow/compiler/xla/array4d.h +++ b/tensorflow/compiler/xla/array4d.h @@ -26,12 +26,11 @@ limitations under the License. 
#include <string> #include <vector> +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/array.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD index c8b2a1ac73..9ad8ee2014 100644 --- a/tensorflow/compiler/xla/client/BUILD +++ b/tensorflow/compiler/xla/client/BUILD @@ -77,6 +77,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/core:lib", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -90,6 +91,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", ], ) @@ -216,6 +218,7 @@ cc_library( "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc index 25608d6616..1fdf8f6260 100644 --- a/tensorflow/compiler/xla/client/client.cc +++ b/tensorflow/compiler/xla/client/client.cc @@ -19,6 +19,7 @@ limitations under the License. #include <utility> #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" @@ -26,7 +27,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/types.h" @@ -400,7 +400,7 @@ StatusOr<string> Client::ExecutionStatsAsString( int64 nanoseconds = profile.compute_time_ns(); int64 cycle_count = profile.compute_cycle_count(); double gflops = total_flops / nanoseconds; - return tensorflow::strings::StrCat( + return absl::StrCat( "[Execution Statistics] flop count: ", computation_stats.flop_count(), ", transcendental count: ", computation_stats.transcendental_count(), ", compute execution time: ", nanoseconds, " nsec", diff --git a/tensorflow/compiler/xla/client/compile_only_client.cc b/tensorflow/compiler/xla/client/compile_only_client.cc index b6012a0352..040344c9a6 100644 --- a/tensorflow/compiler/xla/client/compile_only_client.cc +++ b/tensorflow/compiler/xla/client/compile_only_client.cc @@ -41,7 +41,7 @@ CompileOnlyClient::CompileAheadOfTime( metadata); } -int64 CompileOnlyClient::PointerSizeForTriple(tensorflow::StringPiece triple) { +int64 CompileOnlyClient::PointerSizeForTriple(absl::string_view triple) { llvm::Triple llvm_triple( llvm::Triple::normalize(llvm::StringRef(triple.data(), triple.size()))); if (llvm_triple.isArch64Bit()) { diff --git a/tensorflow/compiler/xla/client/compile_only_client.h b/tensorflow/compiler/xla/client/compile_only_client.h index a551edeab0..d0c83cbfcc 100644 --- a/tensorflow/compiler/xla/client/compile_only_client.h +++ b/tensorflow/compiler/xla/client/compile_only_client.h @@ -57,7 +57,7 @@ class CompileOnlyClient : public Client { std::unique_ptr<AotCompilationMetadata>* metadata = nullptr); // Returns the size of a pointer in bytes for a given triple. 
- static int64 PointerSizeForTriple(tensorflow::StringPiece triple); + static int64 PointerSizeForTriple(absl::string_view triple); private: CompileOnlyService* compiler_service_; diff --git a/tensorflow/compiler/xla/client/executable_build_options.cc b/tensorflow/compiler/xla/client/executable_build_options.cc index 5fe28c33df..5a73408db5 100644 --- a/tensorflow/compiler/xla/client/executable_build_options.cc +++ b/tensorflow/compiler/xla/client/executable_build_options.cc @@ -77,7 +77,7 @@ const absl::optional<string>& ExecutableBuildOptions::generate_hlo_graph() } ExecutableBuildOptions& ExecutableBuildOptions::set_dump_optimized_hlo_proto_to( - tensorflow::StringPiece dirpath) { + absl::string_view dirpath) { dump_optimized_hlo_proto_to_ = string(dirpath); return *this; } @@ -89,8 +89,8 @@ ExecutableBuildOptions::dump_optimized_hlo_proto_to() const { ExecutableBuildOptions& ExecutableBuildOptions::set_dump_unoptimized_hlo_proto_to( - tensorflow::StringPiece dirpath) { - dump_unoptimized_hlo_proto_to_ = dirpath.ToString(); + absl::string_view dirpath) { + dump_unoptimized_hlo_proto_to_ = string(dirpath); return *this; } @@ -100,7 +100,7 @@ ExecutableBuildOptions::dump_unoptimized_hlo_proto_to() const { } ExecutableBuildOptions& ExecutableBuildOptions::set_dump_per_pass_hlo_proto_to( - tensorflow::StringPiece dirpath) { + absl::string_view dirpath) { dump_per_pass_hlo_proto_to_ = string(dirpath); return *this; } diff --git a/tensorflow/compiler/xla/client/executable_build_options.h b/tensorflow/compiler/xla/client/executable_build_options.h index 904d230981..888d2f28eb 100644 --- a/tensorflow/compiler/xla/client/executable_build_options.h +++ b/tensorflow/compiler/xla/client/executable_build_options.h @@ -16,11 +16,11 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_EXECUTABLE_BUILD_OPTIONS_H_ #define TENSORFLOW_COMPILER_XLA_CLIENT_EXECUTABLE_BUILD_OPTIONS_H_ +#include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/stringpiece.h" namespace xla { @@ -62,19 +62,19 @@ class ExecutableBuildOptions { // If set, specifies a dirpath to dump the end-of-optimization-pipeline HLO // protobuf to (as in DebugOptions). ExecutableBuildOptions& set_dump_optimized_hlo_proto_to( - tensorflow::StringPiece dirpath); + absl::string_view dirpath); const absl::optional<string>& dump_optimized_hlo_proto_to() const; // If set, specifies a dirpath to dump the start-of-optimization-pipeline HLO // protobuf to (as in DebugOptions). ExecutableBuildOptions& set_dump_unoptimized_hlo_proto_to( - tensorflow::StringPiece dirpath); + absl::string_view dirpath); const absl::optional<string>& dump_unoptimized_hlo_proto_to() const; // If set, specifies a dirpath to dump the per-pass-in-pipeline HLO protobufs // to (as in DebugOptions). 
ExecutableBuildOptions& set_dump_per_pass_hlo_proto_to( - tensorflow::StringPiece dirpath); + absl::string_view dirpath); const absl::optional<string>& dump_per_pass_hlo_proto_to() const; // If true, specifies that we should record an HLO profile during execution @@ -83,7 +83,7 @@ class ExecutableBuildOptions { ExecutableBuildOptions& set_hlo_profile(bool enabled); absl::optional<bool> hlo_profile() const; - void add_disabled_hlo_pass(tensorflow::StringPiece pass_name) { + void add_disabled_hlo_pass(absl::string_view pass_name) { disabled_hlo_passes_.push_back(std::string(pass_name)); } const tensorflow::gtl::ArraySlice<std::string> disabled_hlo_passes() const { diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD index 4d233741bd..8736f18dcf 100644 --- a/tensorflow/compiler/xla/client/lib/BUILD +++ b/tensorflow/compiler/xla/client/lib/BUILD @@ -31,7 +31,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -221,5 +221,6 @@ cc_library( "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.cc b/tensorflow/compiler/xla/client/lib/arithmetic.cc index 9225b1acd6..e86c10f030 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic.cc +++ b/tensorflow/compiler/xla/client/lib/arithmetic.cc @@ -17,6 +17,7 @@ limitations under the License. #include <string> +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" @@ -24,7 +25,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace xla { namespace { @@ -39,7 +39,7 @@ XlaComputation CreateScalarComputation(const string& name, PrimitiveType type, b = builder->CreateSubBuilder(name); } else { b = builder->CreateSubBuilder( - tensorflow::strings::StrCat(name, "_", PrimitiveType_Name(type))); + absl::StrCat(name, "_", PrimitiveType_Name(type))); } const Shape scalar = ShapeUtil::MakeShape(type, {}); diff --git a/tensorflow/compiler/xla/client/lib/testing.cc b/tensorflow/compiler/xla/client/lib/testing.cc index 081fec7ad9..6861521acc 100644 --- a/tensorflow/compiler/xla/client/lib/testing.cc +++ b/tensorflow/compiler/xla/client/lib/testing.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/testing.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/literal.h" @@ -23,7 +24,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/types.h" @@ -61,8 +61,7 @@ XlaOp BuildFakeDataOpOnDevice(const Shape& shape, XlaBuilder* builder) { std::unique_ptr<GlobalData> MakeFakeDataViaDeviceOrDie(const Shape& shape, Client* client) { - XlaBuilder b( - tensorflow::strings::StrCat("make_fake_", ShapeUtil::HumanString(shape))); + XlaBuilder b(absl::StrCat("make_fake_", ShapeUtil::HumanString(shape))); BuildFakeDataOpOnDevice(shape, &b); XlaComputation computation = b.Build().ConsumeValueOrDie(); diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index 4e7ef66dc5..9f902d7298 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -23,6 +23,9 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/memory/memory.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/client/sharding_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/execution_options_util.h" @@ -31,12 +34,11 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/gtl/flatset.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/mutex.h" namespace xla { -using tensorflow::strings::StrCat; +using absl::StrCat; namespace { @@ -223,8 +225,7 @@ XlaComputation XlaBuilder::BuildAndNoteError() { auto build_status = Build(); if (!build_status.ok()) { parent_builder_->ReportError( - AddStatus(build_status.status(), - tensorflow::strings::StrCat("error from: ", name_))); + AddStatus(build_status.status(), absl::StrCat("error from: ", name_))); return {}; } return build_status.ConsumeValueOrDie(); @@ -705,8 +706,7 @@ XlaOp XlaBuilder::Collapse(const XlaOp& operand, TF_ASSIGN_OR_RETURN(const Shape& original_shape, GetShape(operand)); VLOG(3) << "original shape: " << ShapeUtil::HumanString(original_shape); - VLOG(3) << "dims to collapse: " - << tensorflow::str_util::Join(dimensions, ","); + VLOG(3) << "dims to collapse: " << absl::StrJoin(dimensions, ","); std::vector<int64> new_sizes; for (int i = 0; i < ShapeUtil::Rank(original_shape); ++i) { @@ -717,8 +717,7 @@ XlaOp XlaBuilder::Collapse(const XlaOp& operand, } } - VLOG(3) << "new sizes: [" << tensorflow::str_util::Join(new_sizes, ",") - << "]"; + VLOG(3) << "new sizes: [" << absl::StrJoin(new_sizes, ",") << "]"; return Reshape(operand, new_sizes); }); @@ -1013,7 +1012,7 @@ StatusOr<Window> XlaBuilder::MakeWindow( return Status::OK(); } else { return InvalidArgument( - "%s", tensorflow::strings::StrCat( + "%s", absl::StrCat( "Window has different number of window dimensions than of ", x_name, "\nNumber of window dimensions: ", window_dimensions.size(), @@ -1283,7 +1282,7 @@ XlaOp XlaBuilder::CustomCall(const string& call_target_name, const Shape& shape) { return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { HloInstructionProto instr; - if (tensorflow::str_util::StartsWith(call_target_name, "$")) { + if 
(absl::StartsWith(call_target_name, "$")) { return InvalidArgument( "Invalid custom_call_target \"%s\": Call targets that start with '$' " "are reserved for internal use.", diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index 3dbf1e5bee..baa2ae5184 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -21,6 +21,7 @@ limitations under the License. #include <type_traits> #include <utility> +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/client/padding.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal.h" @@ -32,7 +33,6 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/macros.h" diff --git a/tensorflow/compiler/xla/device_util.h b/tensorflow/compiler/xla/device_util.h index 1a51fdee68..6d51126d88 100644 --- a/tensorflow/compiler/xla/device_util.h +++ b/tensorflow/compiler/xla/device_util.h @@ -21,8 +21,8 @@ limitations under the License. #include <string> +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace xla { @@ -30,8 +30,8 @@ namespace xla { // Returns a string that represents the device in terms of platform and ordinal; // e.g. 
the first CUDA device will be "cuda:0" string DeviceIdentifier(se::StreamExecutor* stream_exec) { - return tensorflow::strings::StrCat(stream_exec->platform()->Name(), ":", - stream_exec->device_ordinal()); + return absl::StrCat(stream_exec->platform()->Name(), ":", + stream_exec->device_ordinal()); } } // namespace xla diff --git a/tensorflow/compiler/xla/index_util.cc b/tensorflow/compiler/xla/index_util.cc index ffd1fb79e9..693dcb3a3e 100644 --- a/tensorflow/compiler/xla/index_util.cc +++ b/tensorflow/compiler/xla/index_util.cc @@ -18,10 +18,10 @@ limitations under the License. #include <algorithm> #include <string> +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -36,7 +36,7 @@ namespace xla { DCHECK_GE(multi_index[i], 0); DCHECK_LT(multi_index[i], shape.dimensions(i)) << "indexing beyond extent in dimension " << i << ":" - << "\n\tindex: " << tensorflow::str_util::Join(multi_index, ",") + << "\n\tindex: " << absl::StrJoin(multi_index, ",") << "\n\tshape: " << ShapeUtil::HumanString(shape); } diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc index b72d190d54..61c26434b1 100644 --- a/tensorflow/compiler/xla/layout_util.cc +++ b/tensorflow/compiler/xla/layout_util.cc @@ -23,6 +23,8 @@ limitations under the License. #include <unordered_map> #include <vector> +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -31,8 +33,6 @@ limitations under the License. 
#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/strings/numbers.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" @@ -211,7 +211,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { "layout minor_to_major field contains %d elements, " "but shape is rank %lld: {%s}; shape: %s", layout.minor_to_major_size(), ShapeUtil::Rank(shape), - tensorflow::str_util::Join(layout.minor_to_major(), ", ").c_str(), + absl::StrJoin(layout.minor_to_major(), ", ").c_str(), shape.ShortDebugString().c_str()); } @@ -403,12 +403,10 @@ Layout CreateDefaultLayoutForRank(int64 rank) { /* static */ string LayoutUtil::HumanString(const Layout& layout) { if (IsSparse(layout)) { - return tensorflow::strings::StrCat("sparse{", layout.max_sparse_elements(), - "}"); + return absl::StrCat("sparse{", layout.max_sparse_elements(), "}"); } CHECK(IsDense(layout)); - return tensorflow::strings::StrCat( - "{", tensorflow::str_util::Join(layout.minor_to_major(), ","), "}"); + return absl::StrCat("{", absl::StrJoin(layout.minor_to_major(), ","), "}"); } namespace { diff --git a/tensorflow/compiler/xla/legacy_flags/BUILD b/tensorflow/compiler/xla/legacy_flags/BUILD index 89353448e2..989035896b 100644 --- a/tensorflow/compiler/xla/legacy_flags/BUILD +++ b/tensorflow/compiler/xla/legacy_flags/BUILD @@ -56,6 +56,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -73,5 +74,6 @@ tf_cc_test( "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc index 5d27e4a46b..0d3136b0cc 100644 
--- a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc +++ b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc @@ -17,9 +17,9 @@ limitations under the License. #include <mutex> // NOLINT(build/c++11): only using std::call_once, not mutex. #include <vector> +#include "absl/strings/str_split.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_parsers.h" #include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" -#include "tensorflow/core/lib/strings/str_util.h" namespace xla { namespace legacy_flags { @@ -87,7 +87,7 @@ void AllocateFlags() { // Custom "sub-parser" lambda for xla_disable_hlo_passes. auto setter_for_xla_disable_hlo_passes = [](string comma_separated_values) { std::vector<string> disabled_passes = - tensorflow::str_util::Split(comma_separated_values, ','); + absl::StrSplit(comma_separated_values, ','); for (const auto& passname : disabled_passes) { flag_values->add_xla_disable_hlo_passes(passname); } diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_parsers.h b/tensorflow/compiler/xla/legacy_flags/debug_options_parsers.h index e9cf435d83..acda438395 100644 --- a/tensorflow/compiler/xla/legacy_flags/debug_options_parsers.h +++ b/tensorflow/compiler/xla/legacy_flags/debug_options_parsers.h @@ -17,9 +17,10 @@ limitations under the License. 
#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_PARSERS_H_ #include <vector> +#include "absl/strings/numbers.h" +#include "absl/strings/str_split.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/xla.pb.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/stringprintf.h" namespace xla { @@ -30,7 +31,7 @@ template <typename T> void parse_xla_backend_extra_options(T* extra_options_map, string comma_separated_values) { std::vector<string> extra_options_parts = - tensorflow::str_util::Split(comma_separated_values, ','); + absl::StrSplit(comma_separated_values, ','); // The flag contains a comma-separated list of options; some options // have arguments following "=", some don't. @@ -59,8 +60,7 @@ void parse_xla_backend_extra_options(T* extra_options_map, inline bool parse_xla_reduce_precision_option( HloReducePrecisionOptions* options, string option_string) { // Split off "LOCATION" from remainder of string. - std::vector<string> eq_split = - tensorflow::str_util::Split(option_string, '='); + std::vector<string> eq_split = absl::StrSplit(option_string, '='); if (eq_split.size() != 2) { return false; } @@ -80,26 +80,25 @@ inline bool parse_xla_reduce_precision_option( } // Split off "E,M" from remainder of string. - std::vector<string> colon_split = - tensorflow::str_util::Split(eq_split[1], ':'); + std::vector<string> colon_split = absl::StrSplit(eq_split[1], ':'); if (colon_split.size() != 2) { return false; } // Split E and M, and parse. 
std::vector<int32> bitsizes; - if (!tensorflow::str_util::SplitAndParseAsInts(colon_split[0], ',', - &bitsizes) || - bitsizes.size() != 2) { - return false; + for (const auto& s : absl::StrSplit(colon_split[0], ',')) { + bitsizes.emplace_back(); + if (!absl::SimpleAtoi(s, &bitsizes.back())) { + return false; + } } options->set_exponent_bits(bitsizes[0]); options->set_mantissa_bits(bitsizes[1]); // Split off OPS comma-separated list from remainder of string, if the // remainder exists. - std::vector<string> semicolon_split = - tensorflow::str_util::Split(colon_split[1], ';'); + std::vector<string> semicolon_split = absl::StrSplit(colon_split[1], ';'); if (semicolon_split.size() > 2) { return false; } @@ -113,8 +112,7 @@ inline bool parse_xla_reduce_precision_option( options->add_opcodes_to_suffix(i); } } else { - std::vector<string> opcodes = - tensorflow::str_util::Split(opcode_string, ','); + std::vector<string> opcodes = absl::StrSplit(opcode_string, ','); for (const string& opcode : opcodes) { bool found = false; for (int i = 0; i < HloOpcodeCount(); i++) { @@ -132,8 +130,7 @@ inline bool parse_xla_reduce_precision_option( // Process the NAMES string, if it exists. if (semicolon_split.size() == 2) { - std::vector<string> opnames = - tensorflow::str_util::Split(semicolon_split[1], ','); + std::vector<string> opnames = absl::StrSplit(semicolon_split[1], ','); for (const string& opname : opnames) { if (opname.length() > 0) { options->add_opname_substrings_to_suffix(opname); diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_parsers_test.cc b/tensorflow/compiler/xla/legacy_flags/debug_options_parsers_test.cc index 0ed788a967..6f197aec53 100644 --- a/tensorflow/compiler/xla/legacy_flags/debug_options_parsers_test.cc +++ b/tensorflow/compiler/xla/legacy_flags/debug_options_parsers_test.cc @@ -20,7 +20,6 @@ limitations under the License. 
#include <unordered_map> #include <vector> -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" namespace xla { diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc index d54f051a1a..30b890737b 100644 --- a/tensorflow/compiler/xla/literal.cc +++ b/tensorflow/compiler/xla/literal.cc @@ -23,6 +23,8 @@ limitations under the License. #include <vector> #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/index_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -31,19 +33,16 @@ limitations under the License. #include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/hash/hash.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" -using tensorflow::strings::Printf; -using tensorflow::strings::StrCat; - namespace xla { - namespace { +using absl::StrCat; +using tensorflow::strings::Printf; + constexpr bool kLittleEndian = __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__; // Converts between little and big endian. 
@@ -1030,9 +1029,9 @@ void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index, element_index.push_back(i); std::vector<string> element_pieces; ToStringHelper(literal, element_index, print_layout, &element_pieces); - tuple_pieces.push_back(tensorflow::str_util::Join(element_pieces, "")); + tuple_pieces.push_back(absl::StrJoin(element_pieces, "")); } - pieces->push_back(tensorflow::str_util::Join(tuple_pieces, ",\n")); + pieces->push_back(absl::StrJoin(tuple_pieces, ",\n")); pieces->push_back("\n)"); return; } @@ -1056,8 +1055,7 @@ void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index, pieces->push_back(": "); } else { pieces->push_back("["); - pieces->push_back( - tensorflow::str_util::Join(literal.GetSparseIndex(i), ", ")); + pieces->push_back(absl::StrJoin(literal.GetSparseIndex(i), ", ")); pieces->push_back("]: "); } pieces->push_back(literal.GetSparseElementAsString(i)); @@ -1183,7 +1181,7 @@ string LiteralBase::ToString(bool print_layout) const { std::vector<string> pieces; CHECK(LayoutUtil::HasLayout(this->shape())); ToStringHelper(*this, {}, print_layout, &pieces); - return tensorflow::str_util::Join(pieces, ""); + return absl::StrJoin(pieces, ""); } void LiteralBase::EachCellAsString( diff --git a/tensorflow/compiler/xla/literal.h b/tensorflow/compiler/xla/literal.h index ed9de65299..aad435ed5b 100644 --- a/tensorflow/compiler/xla/literal.h +++ b/tensorflow/compiler/xla/literal.h @@ -26,6 +26,7 @@ limitations under the License. #include <vector> #include "absl/memory/memory.h" +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" @@ -40,7 +41,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/bitmap.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" diff --git a/tensorflow/compiler/xla/literal_comparison.cc b/tensorflow/compiler/xla/literal_comparison.cc index 6883a6bbab..67a69c2403 100644 --- a/tensorflow/compiler/xla/literal_comparison.cc +++ b/tensorflow/compiler/xla/literal_comparison.cc @@ -19,16 +19,16 @@ limitations under the License. #include <cmath> #include <vector> +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/casts.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/env.h" +using absl::StrAppend; +using absl::StrCat; using tensorflow::strings::Appendf; using tensorflow::strings::Printf; -using tensorflow::strings::StrAppend; -using tensorflow::strings::StrCat; namespace xla { namespace literal_comparison { @@ -47,9 +47,9 @@ Status CompareFloatsBitwiseEqual( if (ulhs != urhs) { return InvalidArgument( "floating values are not bitwise-equal; and equality testing " - "was requested: %s=%g=%a vs %s=%g=%a at index %s", - StrCat(tensorflow::strings::Hex(ulhs)).c_str(), lhs_double, lhs_double, - StrCat(tensorflow::strings::Hex(urhs)).c_str(), rhs_double, rhs_double, + "was requested: %s=%g=%a vs %s=%g=%a at array index %s", + StrCat(absl::Hex(ulhs)).c_str(), lhs_double, lhs_double, + StrCat(absl::Hex(urhs)).c_str(), rhs_double, rhs_double, LiteralUtil::MultiIndexAsString(multi_index).c_str()); } return Status::OK(); @@ -65,9 +65,10 @@ Status CompareEqual(NativeT lhs, NativeT rhs, return Status::OK(); } return InvalidArgument( - "Expected equality of these values:\n %s\n %s\nat index %s", - StrCat(lhs).c_str(), StrCat(rhs).c_str(), - 
LiteralUtil::MultiIndexAsString(multi_index).c_str()); + "first mismatch at array index %s:\n expected value: %s\n actual " + "value: %s", + LiteralUtil::MultiIndexAsString(multi_index).c_str(), StrCat(lhs).c_str(), + StrCat(rhs).c_str()); } // Specializations for floating types that do bitwise comparisons when equality @@ -119,7 +120,8 @@ Status Equal(LiteralSlice expected, LiteralSlice actual, Status result; for (int64 i = 0; i < expected.shape().dimensions(dimension); ++i) { multi_index[dimension] = i; - result.Update(Equal<NativeT>(expected, actual, multi_index, dimension + 1)); + TF_RETURN_IF_ERROR( + Equal<NativeT>(expected, actual, multi_index, dimension + 1)); } return result; } @@ -251,11 +253,6 @@ class NearComparator { // Runs the comparison between expected and actual literals. Status Run() { - VLOG(1) << "expected:"; - XLA_VLOG_LINES(1, ToStringTruncated(expected_)); - VLOG(1) << "actual:"; - XLA_VLOG_LINES(1, ToStringTruncated(actual_)); - // If the shapes mismatch, we simply fail the expectation instead of // printing out data, as it's a type error rather than a value error. 
TF_RETURN_IF_ERROR(EqualShapes(expected_.shape(), actual_.shape())); @@ -539,6 +536,62 @@ constexpr std::array<float, 7> NearComparator<NativeT>::kAbsValueBucketBounds; template <typename NativeT> constexpr std::array<float, 5> NearComparator<NativeT>::kErrorBucketBounds; +Status EqualHelper(const LiteralSlice& expected, const LiteralSlice& actual) { + TF_RETURN_IF_ERROR(EqualShapes(expected.shape(), actual.shape())); + std::vector<int64> multi_index(expected.shape().dimensions_size(), 0); + Status result; + switch (expected.shape().element_type()) { + case PRED: + result = Equal<bool>(expected, actual, &multi_index, 0); + break; + case U8: + result = Equal<uint8>(expected, actual, &multi_index, 0); + break; + case S32: + result = Equal<int32>(expected, actual, &multi_index, 0); + break; + case S64: + result = Equal<int64>(expected, actual, &multi_index, 0); + break; + case U32: + result = Equal<uint32>(expected, actual, &multi_index, 0); + break; + case U64: + result = Equal<uint64>(expected, actual, &multi_index, 0); + break; + case BF16: + result = Equal<bfloat16>(expected, actual, &multi_index, 0); + break; + case F16: + result = Equal<half>(expected, actual, &multi_index, 0); + break; + case F32: + result = Equal<float>(expected, actual, &multi_index, 0); + break; + case F64: + result = Equal<double>(expected, actual, &multi_index, 0); + break; + case C64: + result = Equal<complex64>(expected, actual, &multi_index, 0); + break; + case TUPLE: { + for (int i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) { + result.Update(EqualHelper(LiteralSlice(expected, {i}), + LiteralSlice(actual, {i}))); + } + break; + } + case TOKEN: + // Tokens have no on-device representation and are trivially equal. + return Status::OK(); + default: + LOG(FATAL) << "Unsupported primitive type: " + << PrimitiveType_Name(expected.shape().element_type()); + } + + return result; +} + // Helper function for comparing two literals for nearness. 
Handles tuple-shapes // via recursion. shape_index is the ShapeIndex of expected (or actual) // currently being compared. @@ -555,17 +608,18 @@ Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual, const auto actual_element = LiteralSlice(actual, {i}); ShapeIndex element_index = shape_index; element_index.push_back(i); - Status res = + Status element_result = NearHelper(expected_element, actual_element, error, detailed_message, miscompare_callback, element_index); - if (!res.ok()) { - string err_message = Printf("\nArray at shape index %s%s", - element_index.ToString().c_str(), - res.error_message().c_str()); + if (!element_result.ok()) { + element_result = InvalidArgument( + "Array at shape index %s, %s", element_index.ToString().c_str(), + element_result.error_message().c_str()); if (return_status.ok()) { - return_status = res; + return_status = element_result; } else { - return_status = AppendStatus(return_status, res.error_message()); + return_status = + AppendStatus(return_status, element_result.error_message()); } } } @@ -611,8 +665,8 @@ Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual, } } - // Non-floating point literal. - return literal_comparison::Equal(expected, actual); + // Non-floating point, non-tuple literal. + return EqualHelper(expected, actual); } } // namespace @@ -668,81 +722,44 @@ Status EqualShapes(const Shape& expected, const Shape& actual) { return Status::OK(); } +namespace { + +// If result is an error, extend the error message with the expected and actual +// literals. 
+Status EmitLiteralsInErrorMessage(const Status& result, + const LiteralSlice& expected, + const LiteralSlice& actual) { + if (result.ok()) { + return result; + } + return InvalidArgument("%s\n\nExpected literal:\n%s\n\nActual literal:\n%s", + result.error_message().c_str(), + ToStringTruncated(expected).c_str(), + ToStringTruncated(actual).c_str()); +} + +} // namespace + Status Equal(const LiteralSlice& expected, const LiteralSlice& actual) { VLOG(1) << "expected:"; XLA_VLOG_LINES(1, expected.ToString()); VLOG(1) << "actual:"; XLA_VLOG_LINES(1, actual.ToString()); - - TF_RETURN_IF_ERROR(EqualShapes(expected.shape(), actual.shape())); - std::vector<int64> multi_index(expected.shape().dimensions_size(), 0); - Status result; - switch (expected.shape().element_type()) { - case PRED: - result = Equal<bool>(expected, actual, &multi_index, 0); - break; - case U8: - result = Equal<uint8>(expected, actual, &multi_index, 0); - break; - case S32: - result = Equal<int32>(expected, actual, &multi_index, 0); - break; - case S64: - result = Equal<int64>(expected, actual, &multi_index, 0); - break; - case U32: - result = Equal<uint32>(expected, actual, &multi_index, 0); - break; - case U64: - result = Equal<uint64>(expected, actual, &multi_index, 0); - break; - case BF16: - result = Equal<bfloat16>(expected, actual, &multi_index, 0); - break; - case F16: - result = Equal<half>(expected, actual, &multi_index, 0); - break; - case F32: - result = Equal<float>(expected, actual, &multi_index, 0); - break; - case F64: - result = Equal<double>(expected, actual, &multi_index, 0); - break; - case C64: - result = Equal<complex64>(expected, actual, &multi_index, 0); - break; - case TUPLE: { - for (int i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) { - result.Update( - Equal(LiteralSlice(expected, {i}), LiteralSlice(actual, {i}))); - } - break; - } - case TOKEN: - // Tokens have no on-device representation and are trivially equal. 
- return Status::OK(); - default: - LOG(FATAL) - << "Unsupported primitive type in LiteralTestUtil::ExpectEqual: " - << PrimitiveType_Name(expected.shape().element_type()); - } - - if (result.ok()) { - return Status::OK(); - } - - return AppendStatus( - result, tensorflow::strings::Printf("\nexpected: %s\nactual: %s", - ToStringTruncated(expected).c_str(), - ToStringTruncated(actual).c_str())); + Status result = EqualHelper(expected, actual); + return EmitLiteralsInErrorMessage(result, expected, actual); } Status Near(const LiteralSlice& expected, const LiteralSlice& actual, const ErrorSpec& error, bool detailed_message, const MiscompareCallback& miscompare_callback) { - return NearHelper(expected, actual, error, detailed_message, - miscompare_callback, - /*shape_index=*/{}); + VLOG(1) << "Expected literal:"; + XLA_VLOG_LINES(1, expected.ToString()); + VLOG(1) << "Actual literal:"; + XLA_VLOG_LINES(1, actual.ToString()); + Status result = + NearHelper(expected, actual, error, detailed_message, miscompare_callback, + /*shape_index=*/{}); + return EmitLiteralsInErrorMessage(result, expected, actual); } string ToStringTruncated(const LiteralSlice& literal) { diff --git a/tensorflow/compiler/xla/literal_test.cc b/tensorflow/compiler/xla/literal_test.cc index c5d0c2c267..aef87e46d8 100644 --- a/tensorflow/compiler/xla/literal_test.cc +++ b/tensorflow/compiler/xla/literal_test.cc @@ -18,6 +18,8 @@ limitations under the License. 
#include <vector> #include "absl/memory/memory.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" @@ -1324,8 +1326,8 @@ TEST_F(LiteralUtilTest, BitcastConvertBetweenInvalidTypes) { auto literal = LiteralUtil::CreateR0<uint32>(1234); Status status = literal->BitcastConvert(F64).status(); EXPECT_NE(Status::OK(), status); - EXPECT_TRUE(tensorflow::str_util::StrContains(status.error_message(), - "bit widths are different")); + EXPECT_TRUE( + absl::StrContains(status.error_message(), "bit widths are different")); } TEST_F(LiteralUtilTest, CopyFromProto_Bool) { @@ -1819,21 +1821,20 @@ TEST_F(LiteralUtilTest, GetSparseElementAsString) { "false"); ASSERT_EQ(LiteralUtil::CreateSparse<int64>(dimensions, indices, {1, 2, 3}) ->GetSparseElementAsString(1), - tensorflow::strings::StrCat(int64{2})); + absl::StrCat(int64{2})); ASSERT_EQ( LiteralUtil::CreateSparse<double>(dimensions, indices, {1.0, 2.0, 3.0}) ->GetSparseElementAsString(1), - tensorflow::strings::StrCat(double{2.0})); + absl::StrCat(double{2.0})); ASSERT_EQ(LiteralUtil::CreateSparse<half>(dimensions, indices, {half{1.0}, half{2.0}, half{3.0}}) ->GetSparseElementAsString(1), - tensorflow::strings::StrCat(static_cast<float>(half{2.0}))); - ASSERT_EQ( - LiteralUtil::CreateSparse<complex64>( - dimensions, indices, - std::vector<complex64>{{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}) - ->GetSparseElementAsString(1), - tensorflow::strings::StrCat("(", float{3.0}, ", ", float{4.0}, ")")); + absl::StrCat(static_cast<float>(half{2.0}))); + ASSERT_EQ(LiteralUtil::CreateSparse<complex64>( + dimensions, indices, + std::vector<complex64>{{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}) + ->GetSparseElementAsString(1), + absl::StrCat("(", float{3.0}, ", ", float{4.0}, ")")); } TEST_F(LiteralUtilTest, BroadcastVectorToMatrix0) { diff --git a/tensorflow/compiler/xla/literal_util.cc 
b/tensorflow/compiler/xla/literal_util.cc index d4c7b76b28..95d93acfe8 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -23,6 +23,8 @@ limitations under the License. #include <vector> #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/index_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -31,19 +33,16 @@ limitations under the License. #include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/hash/hash.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/mem.h" #include "tensorflow/core/platform/types.h" -using tensorflow::strings::StrCat; - namespace xla { - namespace { +using absl::StrCat; + // Return a literal with all arrays of type FromNativeT converted to type // ToNativeT in the given literal. 
template <typename FromNativeT, typename ToNativeT> @@ -287,7 +286,7 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) { } /* static */ std::unique_ptr<Literal> LiteralUtil::CreateR1U8( - tensorflow::StringPiece value) { + absl::string_view value) { auto literal = absl::make_unique<Literal>( ShapeUtil::MakeShape(U8, {static_cast<int64>(value.size())})); for (int i = 0; i < value.size(); ++i) { @@ -477,7 +476,7 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) { /* static */ string LiteralUtil::MultiIndexAsString( tensorflow::gtl::ArraySlice<int64> multi_index) { - return StrCat("{", tensorflow::str_util::Join(multi_index, ","), "}"); + return StrCat("{", absl::StrJoin(multi_index, ","), "}"); } } // namespace xla diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h index 1109021ea8..3d28c070f2 100644 --- a/tensorflow/compiler/xla/literal_util.h +++ b/tensorflow/compiler/xla/literal_util.h @@ -28,6 +28,7 @@ limitations under the License. #include <vector> #include "absl/memory/memory.h" +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" @@ -43,7 +44,6 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/bitmap.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" @@ -187,7 +187,7 @@ class LiteralUtil { const Array4D<NativeT>& values, const Layout& layout); // Creates a new vector of U8s literal value from a string. 
- static std::unique_ptr<Literal> CreateR1U8(tensorflow::StringPiece value); + static std::unique_ptr<Literal> CreateR1U8(absl::string_view value); // Creates a linspace-populated literal with the given number of rows and // columns. diff --git a/tensorflow/compiler/xla/metric_table_report.cc b/tensorflow/compiler/xla/metric_table_report.cc index 69ef4f7a2f..2f22e02c3e 100644 --- a/tensorflow/compiler/xla/metric_table_report.cc +++ b/tensorflow/compiler/xla/metric_table_report.cc @@ -18,6 +18,7 @@ limitations under the License. #include <cctype> #include <unordered_map> +#include "absl/strings/str_cat.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -84,7 +85,7 @@ void MetricTableReport::WriteReportToInfoLog(double expected_metric_sum) { if (end_of_line == string::npos) { end_of_line = report.size(); } - tensorflow::StringPiece line(report.data() + pos, end_of_line - pos); + absl::string_view line(report.data() + pos, end_of_line - pos); // TODO(b/34779244): Figure out how to do this without the verbose log-line // prefix. The usual way didn't compile on open source. @@ -152,8 +153,8 @@ void MetricTableReport::AppendCategoryTable() { if (text.empty()) { text = "[no category]"; } - tensorflow::strings::StrAppend(&text, " (", category.entries.size(), " ", - entry_name_, ")"); + absl::StrAppend(&text, " (", category.entries.size(), " ", entry_name_, + ")"); AppendTableRow(text, category.metric_sum, metric_sum); // Show the top entries in the category. @@ -177,9 +178,9 @@ void MetricTableReport::AppendCategoryTable() { } const int64 remaining_categories = categories.size() - categories_shown; if (remaining_categories > 0) { - AppendTableRow(tensorflow::strings::StrCat("... (", remaining_categories, - " more categories)"), - expected_metric_sum_ - metric_sum, expected_metric_sum_); + AppendTableRow( + absl::StrCat("... 
(", remaining_categories, " more categories)"), + expected_metric_sum_ - metric_sum, expected_metric_sum_); } } @@ -206,9 +207,9 @@ void MetricTableReport::AppendEntryTable() { } const int64 remaining_entries = entries_.size() - entries_shown; if (remaining_entries > 0) { - AppendTableRow(tensorflow::strings::StrCat("... (", remaining_entries, - " more ", entry_name_, ")"), - expected_metric_sum_ - metric_sum, expected_metric_sum_); + AppendTableRow( + absl::StrCat("... (", remaining_entries, " more ", entry_name_, ")"), + expected_metric_sum_ - metric_sum, expected_metric_sum_); } } @@ -241,10 +242,10 @@ double MetricTableReport::UnaccountedMetric() { string MetricTableReport::MetricString(double metric) { // Round to integer and stringify. - string s1 = tensorflow::strings::StrCat(std::llround(metric)); + string s1 = absl::StrCat(std::llround(metric)); // Code below commafies the string, e.g. "1234" becomes "1,234". - tensorflow::StringPiece sp1(s1); + absl::string_view sp1(s1); string output; // Copy leading non-digit characters unconditionally. // This picks up the leading sign. diff --git a/tensorflow/compiler/xla/metric_table_report.h b/tensorflow/compiler/xla/metric_table_report.h index 818fb1d3fe..062d8ed99b 100644 --- a/tensorflow/compiler/xla/metric_table_report.h +++ b/tensorflow/compiler/xla/metric_table_report.h @@ -18,9 +18,8 @@ limitations under the License. #include <vector> +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace xla { @@ -108,7 +107,7 @@ class MetricTableReport { // Append all parameters to the report. template <typename... Args> void AppendLine(Args... args) { - tensorflow::strings::StrAppend(&report_, std::forward<Args>(args)..., "\n"); + absl::StrAppend(&report_, std::forward<Args>(args)..., "\n"); } // Represents a set of entries with the same category_text. 
diff --git a/tensorflow/compiler/xla/packed_literal_reader.cc b/tensorflow/compiler/xla/packed_literal_reader.cc index 55c4a80e29..012df87551 100644 --- a/tensorflow/compiler/xla/packed_literal_reader.cc +++ b/tensorflow/compiler/xla/packed_literal_reader.cc @@ -64,7 +64,7 @@ StatusOr<std::unique_ptr<Literal>> PackedLiteralReader::Read( tensorflow::gtl::ArraySlice<float> field = result->data<float>(); char* data = tensorflow::bit_cast<char*>(field.data()); uint64 bytes = elements * sizeof(float); - tensorflow::StringPiece sp; + tensorflow::StringPiece sp; // non-absl OK auto s = file_->Read(offset_, bytes, &sp, data); offset_ += sp.size(); if (!s.ok()) { @@ -85,7 +85,7 @@ bool PackedLiteralReader::IsExhausted() const { // Try to read a single byte from offset_. If we can't, we've // exhausted the data. char single_byte[1]; - tensorflow::StringPiece sp; + tensorflow::StringPiece sp; // non-absl OK auto s = file_->Read(offset_, sizeof(single_byte), &sp, single_byte); return !s.ok(); } diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD index a91336c3ac..2d8fe434b0 100644 --- a/tensorflow/compiler/xla/python/BUILD +++ b/tensorflow/compiler/xla/python/BUILD @@ -39,6 +39,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/python:numpy_lib", + "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i index e1060d54e2..08dccb3ee1 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.i +++ b/tensorflow/compiler/xla/python/local_computation_builder.i @@ -109,6 +109,7 @@ limitations under the License. 
// Must be included first #include "tensorflow/python/lib/core/numpy.h" +#include "third_party/absl/strings/str_cat.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -896,7 +897,7 @@ tensorflow::ImportNumpy(); if (o != Py_None) { StatusOr<Shape> statusor = numpy::XlaShapeFromPyShape(o); if (!statusor.ok()) { - PyErr_SetString(PyExc_TypeError, tensorflow::strings::StrCat("ExecutableBuildOptions.result_shape could not be created from Python shape value: ", statusor.status().ToString()).c_str()); + PyErr_SetString(PyExc_TypeError, absl::StrCat("ExecutableBuildOptions.result_shape could not be created from Python shape value: ", statusor.status().ToString()).c_str()); Py_DECREF(o); SWIG_fail; } diff --git a/tensorflow/compiler/xla/python/numpy_bridge.cc b/tensorflow/compiler/xla/python/numpy_bridge.cc index 4b9970eadc..f2f99c1745 100644 --- a/tensorflow/compiler/xla/python/numpy_bridge.cc +++ b/tensorflow/compiler/xla/python/numpy_bridge.cc @@ -14,6 +14,7 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/xla/python/numpy_bridge.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/core/platform/logging.h" @@ -191,8 +192,8 @@ StatusOr<Shape> XlaShapeFromPyShape(PyObject* o) { PyObject* result = PyObject_CallMethod(o, const_cast<char*>(method.c_str()), nullptr); if (result == nullptr) { - return error(tensorflow::strings::StrCat( - "Failed to call method of shape object:", method)); + return error( + absl::StrCat("Failed to call method of shape object:", method)); } return result; }; diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 32723849a6..aa826aa770 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -99,6 +99,7 @@ cc_library( ":bfloat16_support", ":hlo", ":hlo_pass", + "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", @@ -176,6 +177,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", ], ) @@ -241,6 +243,7 @@ cc_library( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", ], ) @@ -320,6 +323,7 @@ cc_library( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -346,7 +350,7 @@ cc_library( deps = [ ":hlo", "//tensorflow/compiler/xla:shape_util", - "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -398,7 +402,7 @@ cc_library( "//tensorflow/compiler/xla:test", 
"//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/core:lib", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", ], ) @@ -460,6 +464,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -564,6 +569,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/strings", ], ) @@ -587,6 +593,7 @@ cc_library( "//tensorflow/core:stream_executor_no_cuda", "//third_party/eigen3", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -629,6 +636,7 @@ cc_library( "//tensorflow/core:ptr_util", "//tensorflow/core:stream_executor_no_cuda", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], alwayslink = 1, ) @@ -662,6 +670,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -684,6 +693,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/strings", ], ) @@ -735,6 +745,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -832,6 +843,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -851,6 +863,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -886,6 +899,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -896,6 +910,7 @@ 
cc_library( deps = [ "//tensorflow/compiler/xla:types", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -930,6 +945,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -939,6 +955,7 @@ tf_cc_test( deps = [ ":buffer_liveness", ":hlo", + ":hlo_dataflow_analysis", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", @@ -974,6 +991,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -1021,6 +1039,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -1113,6 +1132,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -1199,6 +1219,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -1215,6 +1236,7 @@ cc_library( "//tensorflow/compiler/xla:util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -1330,6 +1352,7 @@ cc_library( "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", ], ) @@ -1355,6 +1378,7 @@ tf_cc_test( "//tensorflow/core:lib", "//tensorflow/core:test", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -1419,6 +1443,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -1457,6 +1482,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", 
"//tensorflow/core:lib", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -1497,6 +1523,7 @@ cc_library( ":while_loop_analysis", "//tensorflow/compiler/xla:statusor", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", ], ) @@ -1511,6 +1538,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/strings", ], ) @@ -1701,6 +1729,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], alwayslink = True, # Contains per-platform computation placer registration ) @@ -1714,6 +1743,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -1807,6 +1837,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -1839,6 +1870,7 @@ tf_cc_binary( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -1871,6 +1903,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/strings", ], ) @@ -1898,6 +1931,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/strings", ], ) @@ -1916,6 +1950,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -1936,6 +1971,7 @@ cc_library( "//tensorflow/core:lib", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/memory", + 
"@com_google_absl//absl/strings", ], ) @@ -1978,6 +2014,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -2014,6 +2051,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -2034,6 +2072,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -2093,6 +2132,7 @@ cc_library( "//tensorflow/core:lib", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -2144,6 +2184,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -2166,6 +2207,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -2235,6 +2277,7 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/core:lib", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -2278,6 +2321,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/strings", ], ) @@ -2400,6 +2444,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -2633,6 +2678,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", "@llvm//:core", "@llvm//:transform_utils", ], @@ -2666,8 +2712,8 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla:xla_proto", - 
"//tensorflow/core:lib", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", ], ) @@ -2681,6 +2727,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -2717,8 +2764,8 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_proto", "//tensorflow/core:framework", - "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/strings", ], ) @@ -2752,6 +2799,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:regexp_internal", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", ], alwayslink = 1, @@ -2769,6 +2817,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -2997,8 +3046,8 @@ cc_library( ":hlo_creation_utils", ":tuple_util", "//tensorflow/compiler/xla:literal_util", - "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", ], ) @@ -3114,6 +3163,7 @@ cc_library( "//tensorflow/core:ptr_util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", ], ) @@ -3150,6 +3200,7 @@ cc_library( "//tensorflow/core:lib_internal", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -3164,6 +3215,7 @@ tf_cc_test( "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", # fixdeps: keep + "@com_google_absl//absl/strings", ], ) @@ -3182,6 +3234,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:regexp_internal", + 
"@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", ], ) diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index b86b7d2e71..c236453fc7 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "absl/types/optional.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" @@ -1989,9 +1990,9 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow( VLOG(10) << "Considering folding Pad: " << pad->ToString() << "\ninto reduce-window: " << reduce_window->ToString() - << (convert != nullptr ? tensorflow::strings::StrCat( - "\nvia convert: ", convert->ToString()) - : ""); + << (convert != nullptr + ? absl::StrCat("\nvia convert: ", convert->ToString()) + : ""); // Do not fold interior padding into ReduceWindow since the backends do not // support it. diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.h b/tensorflow/compiler/xla/service/algebraic_simplifier.h index c48196e861..b864c372fa 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.h +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.h @@ -47,7 +47,7 @@ class AlgebraicSimplifier : public HloPassInterface { enable_dot_strength_reduction_(enable_dot_strength_reduction), enable_conv_simplification_(enable_conv_simplification) {} ~AlgebraicSimplifier() override = default; - tensorflow::StringPiece name() const override { return "algsimp"; } + absl::string_view name() const override { return "algsimp"; } // Run algebraic simplification on the given computation. Returns whether the // computation was changed. 
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 427069af5f..bb63ea26d4 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -19,6 +19,8 @@ limitations under the License. #include <utility> #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -34,13 +36,12 @@ limitations under the License. #include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/strings/str_util.h" - -using ::testing::ElementsAre; namespace xla { namespace { +using ::testing::ElementsAre; + namespace op = xla::testing::opcode_matchers; AlgebraicSimplifier::ValidBitcastCallback bitcasting_callback() { @@ -51,7 +52,12 @@ AlgebraicSimplifier::ValidBitcastCallback non_bitcasting_callback() { return [](const Shape&, const Shape&) { return false; }; } -class AlgebraicSimplifierTest : public HloVerifiedTestBase {}; +class AlgebraicSimplifierTest : public HloVerifiedTestBase { + public: + AlgebraicSimplifierTest() + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false) {} +}; // Test that A + 0 is simplified to A TEST_F(AlgebraicSimplifierTest, AddZero) { @@ -2143,9 +2149,8 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) { root->operand(0)->opcode() == HloOpcode::kDot) { auto lhs_shape = root->operand(0)->operand(0)->shape(); auto rhs_shape = root->operand(0)->operand(1)->shape(); - return tensorflow::strings::StrCat( - tensorflow::str_util::Join(lhs_shape.dimensions(), "x"), " DOT ", - tensorflow::str_util::Join(rhs_shape.dimensions(), "x")); + return 
absl::StrCat(absl::StrJoin(lhs_shape.dimensions(), "x"), " DOT ", + absl::StrJoin(rhs_shape.dimensions(), "x")); } return "UNEXPECTED CHANGE"; }; @@ -2660,11 +2665,10 @@ struct PadReduceWindowEffectiveBroadcastCase { bool should_become_broadcast; string ToTestCaseName() const { - return tensorflow::strings::StrCat( - tensorflow::str_util::Join(input_spatials, ","), ";", - tensorflow::str_util::Join(symmetric_pad_spatials, ","), ";", - tensorflow::str_util::Join(reduce_window_spatials, ","), ";", prepend_a, - ";", should_become_broadcast); + return absl::StrCat(absl::StrJoin(input_spatials, ","), ";", + absl::StrJoin(symmetric_pad_spatials, ","), ";", + absl::StrJoin(reduce_window_spatials, ","), ";", + prepend_a, ";", should_become_broadcast); } }; @@ -2852,7 +2856,12 @@ struct DotOfConcatTestSpec { class DotOfConcatSimplificationTest : public HloVerifiedTestBase, - public ::testing::WithParamInterface<DotOfConcatTestSpec> {}; + public ::testing::WithParamInterface<DotOfConcatTestSpec> { + public: + DotOfConcatSimplificationTest() + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false) {} +}; // Test that we transform // dot(const, concat(A, B, C)) @@ -3025,7 +3034,12 @@ struct DotOfGatherTestSpec { class DotOfGatherSimplificationTest : public HloVerifiedTestBase, - public ::testing::WithParamInterface<DotOfGatherTestSpec> {}; + public ::testing::WithParamInterface<DotOfGatherTestSpec> { + public: + DotOfGatherSimplificationTest() + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false) {} +}; // input: dot(DS(ctA), ctB)) // where DS(ctA) = DS({M x K}, {s, 0}, {1, K}) and ctB = {K x N}. 
diff --git a/tensorflow/compiler/xla/service/allocation_tracker.cc b/tensorflow/compiler/xla/service/allocation_tracker.cc index d0806d24a2..5115a14df0 100644 --- a/tensorflow/compiler/xla/service/allocation_tracker.cc +++ b/tensorflow/compiler/xla/service/allocation_tracker.cc @@ -18,6 +18,7 @@ limitations under the License. #include <utility> #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" @@ -25,7 +26,6 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/backend.h b/tensorflow/compiler/xla/service/backend.h index 1bc3796fa4..4a6a78daf0 100644 --- a/tensorflow/compiler/xla/service/backend.h +++ b/tensorflow/compiler/xla/service/backend.h @@ -21,6 +21,7 @@ limitations under the License. #include <string> #include <vector> +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/computation_placer.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" @@ -29,7 +30,6 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/thread_annotations.h" @@ -130,7 +130,7 @@ class Backend { // Return a string identifier for the given device, eg: "GPU:3". 
string device_name(int device_ordinal) const { - return tensorflow::strings::StrCat(platform_->Name(), ":", device_ordinal); + return absl::StrCat(platform_->Name(), ":", device_ordinal); } // Returns true if the devices with the given ordinals are equivalent from diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification.cc b/tensorflow/compiler/xla/service/batch_dot_simplification.cc index be6fbcc9e3..a16b85a0a5 100644 --- a/tensorflow/compiler/xla/service/batch_dot_simplification.cc +++ b/tensorflow/compiler/xla/service/batch_dot_simplification.cc @@ -78,7 +78,7 @@ BatchDotSimplification::ElideDegenerateBatchDimensionFromBatchDot( return true; } -tensorflow::StringPiece BatchDotSimplification::name() const { +absl::string_view BatchDotSimplification::name() const { return "batch-dot-simplification"; } diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification.h b/tensorflow/compiler/xla/service/batch_dot_simplification.h index c0ca8d8eba..79d37f08d3 100644 --- a/tensorflow/compiler/xla/service/batch_dot_simplification.h +++ b/tensorflow/compiler/xla/service/batch_dot_simplification.h @@ -28,7 +28,7 @@ namespace xla { class BatchDotSimplification : public HloPassInterface { public: StatusOr<bool> Run(HloModule* module) override; - tensorflow::StringPiece name() const override; + absl::string_view name() const override; private: StatusOr<bool> ElideDegenerateBatchDimensionFromBatchDot( diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc b/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc index 38f1a5d3a6..b342acb025 100644 --- a/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc +++ b/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc @@ -24,7 +24,12 @@ namespace { namespace op = xla::testing::opcode_matchers; -class BatchDotSimplificationTest : public HloVerifiedTestBase {}; +class BatchDotSimplificationTest : public HloVerifiedTestBase { + public: + 
BatchDotSimplificationTest() + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false) {} +}; TEST_F(BatchDotSimplificationTest, ElideSingleDegenerateBatchDotDim_VectorVector) { diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.h b/tensorflow/compiler/xla/service/batchnorm_expander.h index 7ae202c583..76e32174f3 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander.h +++ b/tensorflow/compiler/xla/service/batchnorm_expander.h @@ -36,7 +36,7 @@ class BatchNormExpander : public HloPassInterface { rewrite_inference_op_(rewrite_inference_op), rewrite_grad_op_(rewrite_grad_op) {} ~BatchNormExpander() = default; - tensorflow::StringPiece name() const override { return "batchnorm_expander"; } + absl::string_view name() const override { return "batchnorm_expander"; } // Run operation expander on the given computation. Returns whether the // computation was changed. diff --git a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc index f62ab12319..aba0d9bb5b 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc +++ b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc @@ -32,7 +32,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/strings/str_util.h" namespace xla { namespace { diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h index c939838709..5dcd31b83d 100644 --- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h +++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h @@ -37,7 +37,7 @@ class BFloat16ConversionFolding : public HloPassInterface { : bfloat16_support_(bfloat16_support) {} ~BFloat16ConversionFolding() override = default; - tensorflow::StringPiece name() const override { return "bfloat16-fold"; } + absl::string_view name() const override { return "bfloat16-fold"; } // Run BF16 conversion folding on the given computation. Returns whether the // computation was changed. diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization.cc b/tensorflow/compiler/xla/service/bfloat16_normalization.cc index 16e99b5722..32573ed355 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization.cc +++ b/tensorflow/compiler/xla/service/bfloat16_normalization.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -34,11 +35,6 @@ class BFloat16NormalizationVisitor : public DfsHloVisitorWithDefault { Status DefaultAction(HloInstruction* hlo) override; - // Special handling for cross-replica-sum and sort which can have a tuple - // output. 
- Status HandleCrossReplicaSum(HloInstruction* crs) override; - Status HandleSort(HloInstruction* sort) override; - static bool Run(HloComputation* computation, const BFloat16Support* bfloat16_support) { BFloat16NormalizationVisitor visitor(computation, bfloat16_support); @@ -150,23 +146,6 @@ Status BFloat16NormalizationVisitor::ConvertCalledComputations( return Status::OK(); } -Status BFloat16NormalizationVisitor::HandleCrossReplicaSum( - HloInstruction* crs) { - if (!ShapeUtil::IsTuple(crs->shape())) { - return HandleInstruction(crs); - } else { - return HandleMultipleOutputs(crs); - } -} - -Status BFloat16NormalizationVisitor::HandleSort(HloInstruction* sort) { - if (!ShapeUtil::IsTuple(sort->shape())) { - return HandleInstruction(sort); - } else { - return HandleMultipleOutputs(sort); - } -} - Status BFloat16NormalizationVisitor::HandleMultipleOutputs( HloInstruction* hlo) { std::vector<PrimitiveType> operand_types(hlo->operand_count()); @@ -380,6 +359,11 @@ Status BFloat16NormalizationVisitor::DefaultAction(HloInstruction* hlo) { hlo->opcode() == HloOpcode::kConditional) { return Status::OK(); } + if ((hlo->opcode() == HloOpcode::kSort || + hlo->opcode() == HloOpcode::kCrossReplicaSum) && + ShapeUtil::IsTuple(hlo->shape())) { + return HandleMultipleOutputs(hlo); + } return HandleInstruction(hlo); } diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization.h b/tensorflow/compiler/xla/service/bfloat16_normalization.h index 2a60fe0af3..30b6346312 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization.h +++ b/tensorflow/compiler/xla/service/bfloat16_normalization.h @@ -31,7 +31,7 @@ class BFloat16Normalization : public HloPassInterface { : bfloat16_support_(bfloat16_support) {} ~BFloat16Normalization() override = default; - tensorflow::StringPiece name() const override { return "bf16-normalization"; } + absl::string_view name() const override { return "bf16-normalization"; } // Run BF16 normalization on the given computation. 
Returns whether the // computation was changed. @@ -54,7 +54,7 @@ class BFloat16MixedPrecisionRemoval : public HloPassInterface { ~BFloat16MixedPrecisionRemoval() override = default; - tensorflow::StringPiece name() const override { + absl::string_view name() const override { return "bf16-mixed-precision-removal"; } diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc index 49ae5320b0..b08705d4c2 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc @@ -76,7 +76,8 @@ class BFloat16NormalizationTest : public HloTestBase { StatusOr<bool> result = normalization.Run(module); EXPECT_IS_OK(result.status()); - HloVerifier verifier(/*allow_mixed_precision=*/true); + HloVerifier verifier(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/true); EXPECT_IS_OK(verifier.Run(module).status()); return result.ValueOrDie(); diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.h b/tensorflow/compiler/xla/service/bfloat16_propagation.h index 02b8cad089..1ee64971ab 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation.h +++ b/tensorflow/compiler/xla/service/bfloat16_propagation.h @@ -64,9 +64,7 @@ class BFloat16Propagation : public HloPassInterface { ~BFloat16Propagation() override = default; - tensorflow::StringPiece name() const override { - return "bfloat16-propagation"; - } + absl::string_view name() const override { return "bfloat16-propagation"; } // Runs the pass on the given module. Returns whether the module was changed // (precision reductions were added). 
diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index 0f08e7c52b..c8c36ae60e 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -23,6 +23,7 @@ limitations under the License. #include <utility> #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/buffer_value_containers.h" #include "tensorflow/compiler/xla/service/heap_simulator.h" @@ -36,20 +37,17 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/strings/numbers.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/stringprintf.h" namespace xla { +namespace { +using absl::StrAppend; using ::tensorflow::gtl::FlatMap; using ::tensorflow::gtl::FlatSet; using ::tensorflow::strings::Appendf; using ::tensorflow::strings::HumanReadableNumBytes; using ::tensorflow::strings::Printf; -using ::tensorflow::strings::StrAppend; - -namespace { template <typename T> string ColocatedBufferSetsToString(const T& container, const char* title) { @@ -236,8 +234,8 @@ size_t BufferAllocation::Slice::Hasher::operator()(Slice s) const { } string BufferAllocation::Slice::ToString() const { - return tensorflow::strings::StrCat("{index:", index(), ", offset:", offset_, - ", size:", size_, "}"); + return absl::StrCat("{index:", index(), ", offset:", offset_, + ", size:", size_, "}"); } BufferAllocation::Slice BufferAllocation::GetSlice( @@ -678,9 +676,9 @@ string BufferAssignment::Stats::ToString() const { string BufferAssignment::ToString() const { string output; - tensorflow::strings::StrAppend(&output, "BufferAssignment:\n"); + absl::StrAppend(&output, "BufferAssignment:\n"); for (auto& allocation : 
allocations_) { - tensorflow::strings::StrAppend(&output, allocation.ToString()); + absl::StrAppend(&output, allocation.ToString()); } return output; } diff --git a/tensorflow/compiler/xla/service/buffer_liveness.cc b/tensorflow/compiler/xla/service/buffer_liveness.cc index 810d597e73..8d0ac3b84a 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness.cc +++ b/tensorflow/compiler/xla/service/buffer_liveness.cc @@ -20,6 +20,7 @@ limitations under the License. #include <utility> #include <vector> +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -28,7 +29,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" @@ -89,13 +89,13 @@ string BufferLiveness::ToString() const { pieces.push_back( tensorflow::strings::Printf(" %s", buffer->ToString().c_str())); } - return tensorflow::str_util::Join(pieces, "\n"); + return absl::StrJoin(pieces, "\n"); } bool BufferLiveness::live_range_strictly_before(const LogicalBuffer& a, const LogicalBuffer& b) const { - TF_CHECK_OK(points_to_analysis_->VerifyBuffer(a)); - TF_CHECK_OK(points_to_analysis_->VerifyBuffer(b)); + TF_DCHECK_OK(points_to_analysis_->VerifyBuffer(a)); + TF_DCHECK_OK(points_to_analysis_->VerifyBuffer(b)); if (!hlo_ordering_->ExecutesBefore(a.instruction(), b.instruction())) { return false; diff --git a/tensorflow/compiler/xla/service/buffer_liveness_test.cc b/tensorflow/compiler/xla/service/buffer_liveness_test.cc index 3ffb7de65f..26e26e316d 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness_test.cc +++ b/tensorflow/compiler/xla/service/buffer_liveness_test.cc @@ -20,6 +20,7 @@ limitations 
under the License. #include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -610,11 +611,8 @@ TEST_F(BufferLivenessTest, DependentTupleElements) { class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest { protected: // Builds and runs a computation (see test case computation graphs below). - // Runs BufferLiveness on this computation. - // Returns whether buffer interference is detected between tuple-shaped - // parameter and root instructions at tuple element 1. - bool Run(const bool update_uses_tuple_element1, - const bool fuse_gte0 = false) { + std::unique_ptr<HloModule> BuildModule(const bool update_uses_tuple_element1, + const bool fuse_gte0) { auto builder = HloComputation::Builder(TestName()); // Create param0 Tuple. Shape data_shape = ShapeUtil::MakeShape(F32, {8}); @@ -645,12 +643,12 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest { builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( data_shape, gte1, update, starts)); // Create output tuple. - auto tuple_root = builder.AddInstruction( + builder.AddInstruction( HloInstruction::CreateTuple({gte0, dynamic_update_slice})); // Build module and get reference to entry computation. auto module = CreateNewModule(); - module->AddEntryComputation(BuildDummyComputation()); - auto* computation = module->AddEmbeddedComputation(builder.Build()); + module->AddEntryComputation(builder.Build()); + auto* computation = module->entry_computation(); // Create fusion instruction based on number of tuple element 1 users. 
if (update_uses_tuple_element1) { computation->CreateFusionInstruction( @@ -666,7 +664,14 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest { computation->CreateFusionInstruction({gte0}, HloInstruction::FusionKind::kLoop); } + return module; + } + // Returns whether buffer interference is detected between tuple-shaped + // parameter and root instructions at tuple element 1. + bool Run(const bool update_uses_tuple_element1, + const bool fuse_gte0 = false) { + auto module = BuildModule(update_uses_tuple_element1, fuse_gte0); // Run BufferLiveness on 'module'. auto liveness = BufferLiveness::Run( module.get(), @@ -674,8 +679,24 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest { .ConsumeValueOrDie(); // Return whether or not buffers interference is detected between // 'tuple_param0' and 'tuple_root' at shape index '{1}'. + auto tuple_param0 = FindInstruction(module.get(), "param0"); + auto tuple_root = module->entry_computation()->root_instruction(); return TupleElementsMayInterfere(*liveness, tuple_param0, tuple_root, {1}); } + bool RunWithHloDataflowAnalysis(const bool update_uses_tuple_element1, + const bool fuse_gte0 = false) { + auto module = BuildModule(update_uses_tuple_element1, fuse_gte0); + // Run BufferLiveness on 'module'. + auto dataflow = HloDataflowAnalysis::Run(*module).ConsumeValueOrDie(); + auto hlo_ordering = absl::make_unique<DependencyHloOrdering>(module.get()); + // Return whether or not buffers interference is detected between + // 'tuple_param0' and 'tuple_root' at shape index '{1}'. 
+ auto tuple_param0 = FindInstruction(module.get(), "param0"); + auto tuple_root = module->entry_computation()->root_instruction(); + return hlo_ordering->MayInterfere( + dataflow->GetUniqueValueAt(tuple_param0, {1}), + dataflow->GetUniqueValueAt(tuple_root, {1}), *dataflow); + } }; // Tests that live ranges of buffers Param0[1] and Tuple[1] (which alias fusion) @@ -693,6 +714,8 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest { // TEST_F(FusedDynamicUpdateSliceLivenessTest, NoInterference) { EXPECT_FALSE(Run(/*update_uses_tuple_element1=*/false)); + EXPECT_FALSE( + RunWithHloDataflowAnalysis(/*update_uses_tuple_element1=*/false)); } // Tests that live ranges of buffers Param0[1] and Tuple[1] (which aliases @@ -712,6 +735,8 @@ TEST_F(FusedDynamicUpdateSliceLivenessTest, NoInterference) { // TEST_F(FusedDynamicUpdateSliceLivenessTest, NoInterferenceWithUnrelatedFusion) { EXPECT_FALSE(Run(/*update_uses_tuple_element1=*/false, /*fuse_gte0=*/true)); + EXPECT_FALSE(RunWithHloDataflowAnalysis(/*update_uses_tuple_element1=*/false, + /*fuse_gte0=*/true)); } // Tests that live ranges of buffers Param0[1] and Tuple[1] (which alias fusion) @@ -736,6 +761,7 @@ TEST_F(FusedDynamicUpdateSliceLivenessTest, NoInterferenceWithUnrelatedFusion) { // TEST_F(FusedDynamicUpdateSliceLivenessTest, WithInterference) { EXPECT_TRUE(Run(/*update_uses_tuple_element1=*/true)); + EXPECT_TRUE(RunWithHloDataflowAnalysis(/*update_uses_tuple_element1=*/true)); } class DynamicUpdateSliceLivenessTest : public BufferLivenessTest { diff --git a/tensorflow/compiler/xla/service/buffer_value.cc b/tensorflow/compiler/xla/service/buffer_value.cc index 2bc556a9e2..fdf822c666 100644 --- a/tensorflow/compiler/xla/service/buffer_value.cc +++ b/tensorflow/compiler/xla/service/buffer_value.cc @@ -17,11 +17,10 @@ limitations under the License. 
#include <iosfwd> +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/types.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/call_graph.cc b/tensorflow/compiler/xla/service/call_graph.cc index d6efef5f12..37523a73ff 100644 --- a/tensorflow/compiler/xla/service/call_graph.cc +++ b/tensorflow/compiler/xla/service/call_graph.cc @@ -18,20 +18,20 @@ limitations under the License. #include <queue> #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/types.h" namespace xla { +using absl::StrCat; using ::tensorflow::strings::Appendf; -using ::tensorflow::strings::StrCat; string CallContextToString(CallContext context) { switch (context) { @@ -71,10 +71,10 @@ CallContext GetInstructionCallContext(HloOpcode opcode) { } string CallSite::ToString() const { - return StrCat(instruction()->name(), " calls in context ", - CallContextToString(context()), ": ", - tensorflow::str_util::Join( - called_computations(), ", ", + return StrCat( + instruction()->name(), " calls in context ", + CallContextToString(context()), ": ", + absl::StrJoin(called_computations(), ", ", [](string* out, const HloComputation* computation) { out->append(computation->name()); })); diff --git a/tensorflow/compiler/xla/service/call_inliner.h 
b/tensorflow/compiler/xla/service/call_inliner.h index c0e95e1578..c5cd88b9ea 100644 --- a/tensorflow/compiler/xla/service/call_inliner.h +++ b/tensorflow/compiler/xla/service/call_inliner.h @@ -35,7 +35,7 @@ class CallInliner : public HloPassInterface { static StatusOr<InlinedInstructionMap> Inline(HloInstruction* call); ~CallInliner() override = default; - tensorflow::StringPiece name() const override { return "CallInliner"; } + absl::string_view name() const override { return "CallInliner"; } StatusOr<bool> Run(HloModule* module) override; }; diff --git a/tensorflow/compiler/xla/service/call_inliner_test.cc b/tensorflow/compiler/xla/service/call_inliner_test.cc index e75f6f146d..5d85a3f173 100644 --- a/tensorflow/compiler/xla/service/call_inliner_test.cc +++ b/tensorflow/compiler/xla/service/call_inliner_test.cc @@ -32,7 +32,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/strings/str_util.h" namespace op = xla::testing::opcode_matchers; diff --git a/tensorflow/compiler/xla/service/channel_tracker.cc b/tensorflow/compiler/xla/service/channel_tracker.cc index 9c9e373821..601a3e9a01 100644 --- a/tensorflow/compiler/xla/service/channel_tracker.cc +++ b/tensorflow/compiler/xla/service/channel_tracker.cc @@ -16,13 +16,13 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/channel_tracker.h" #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc index 7426672a7a..3079695e96 100644 --- a/tensorflow/compiler/xla/service/compile_only_service.cc +++ b/tensorflow/compiler/xla/service/compile_only_service.cc @@ -19,6 +19,7 @@ limitations under the License. #include <utility> #include <vector> +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/computation_layout.h" @@ -28,7 +29,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/host_info.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -76,9 +76,9 @@ CompileOnlyService::CompileAheadOfTime( if (!directory_path.empty()) { HloSnapshot hlo_snapshot; *hlo_snapshot.mutable_hlo()->mutable_hlo_module() = instance.computation; - string filename = tensorflow::strings::StrCat( - "computation_", instance.computation.id(), "__", - instance.computation.entry_computation_name()); + string filename = + absl::StrCat("computation_", instance.computation.id(), "__", + instance.computation.entry_computation_name()); const string& per_host_path = tensorflow::io::JoinPath( directory_path, tensorflow::port::Hostname()); diff --git a/tensorflow/compiler/xla/service/computation_layout.cc b/tensorflow/compiler/xla/service/computation_layout.cc index cb61f3da39..af8f7f1027 100644 --- a/tensorflow/compiler/xla/service/computation_layout.cc +++ b/tensorflow/compiler/xla/service/computation_layout.cc @@ -17,9 +17,9 @@ limitations under the License. 
#include <algorithm> +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace xla { @@ -52,9 +52,8 @@ string ComputationLayout::ToString() const { for (auto& param_layout : parameter_layouts_) { params.push_back(param_layout.ToString()); } - return tensorflow::strings::StrCat("(", - tensorflow::str_util::Join(params, ", "), - ") => ", result_layout_.ToString()); + return absl::StrCat("(", absl::StrJoin(params, ", "), ") => ", + result_layout_.ToString()); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/computation_placer.cc b/tensorflow/compiler/xla/service/computation_placer.cc index afbbea35b8..61b1dba6c9 100644 --- a/tensorflow/compiler/xla/service/computation_placer.cc +++ b/tensorflow/compiler/xla/service/computation_placer.cc @@ -20,6 +20,7 @@ limitations under the License. #include <vector> #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status.h" @@ -29,12 +30,11 @@ limitations under the License. 
#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" -using tensorflow::strings::StrAppend; -using tensorflow::strings::StrCat; +using absl::StrAppend; +using absl::StrCat; namespace xla { diff --git a/tensorflow/compiler/xla/service/conditional_simplifier.cc b/tensorflow/compiler/xla/service/conditional_simplifier.cc index b7be3ba605..4ea3a13f28 100644 --- a/tensorflow/compiler/xla/service/conditional_simplifier.cc +++ b/tensorflow/compiler/xla/service/conditional_simplifier.cc @@ -19,6 +19,7 @@ limitations under the License. #include <utility> #include <vector> +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/call_inliner.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -28,8 +29,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/conditional_simplifier.h b/tensorflow/compiler/xla/service/conditional_simplifier.h index 063261e26d..3de50cbd7f 100644 --- a/tensorflow/compiler/xla/service/conditional_simplifier.h +++ b/tensorflow/compiler/xla/service/conditional_simplifier.h @@ -16,10 +16,10 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CONDITIONAL_SIMPLIFIER_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_CONDITIONAL_SIMPLIFIER_H_ +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" #include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/core/lib/core/stringpiece.h" namespace xla { @@ -27,9 +27,7 @@ namespace xla { // with their true or false computation as appropriate. class ConditionalSimplifier : public HloPassInterface { public: - tensorflow::StringPiece name() const override { - return "simplify-conditional"; - } + absl::string_view name() const override { return "simplify-conditional"; } StatusOr<bool> Run(HloModule* module) override; }; diff --git a/tensorflow/compiler/xla/service/conditional_simplifier_test.cc b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc index c43a31b167..6c477da038 100644 --- a/tensorflow/compiler/xla/service/conditional_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc @@ -39,6 +39,10 @@ namespace op = xla::testing::opcode_matchers; class ConditionalSimplifierTest : public HloVerifiedTestBase { public: + ConditionalSimplifierTest() + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false) {} + // Makes a computation that contains a conditional with constant predicate. HloComputation* MakeConditional(HloModule* module); }; diff --git a/tensorflow/compiler/xla/service/convolution_feature_group_converter.h b/tensorflow/compiler/xla/service/convolution_feature_group_converter.h index f213cc8709..498894737f 100644 --- a/tensorflow/compiler/xla/service/convolution_feature_group_converter.h +++ b/tensorflow/compiler/xla/service/convolution_feature_group_converter.h @@ -16,10 +16,10 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CONVOLUTION_FEATURE_GROUP_CONVERTER_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_CONVOLUTION_FEATURE_GROUP_CONVERTER_H_ +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" #include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/core/lib/core/stringpiece.h" namespace xla { @@ -29,7 +29,7 @@ class ConvolutionFeatureGroupConverter : public HloPassInterface { public: ConvolutionFeatureGroupConverter() {} - tensorflow::StringPiece name() const override { + absl::string_view name() const override { return "convolution-feature-group-converter"; } diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc index 3e39c1bab1..231d31d960 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.cc +++ b/tensorflow/compiler/xla/service/copy_insertion.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/copy_insertion.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/service/hlo_alias_analysis.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_dce.h" @@ -31,18 +33,13 @@ limitations under the License. 
#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/flatset.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" namespace xla { - -using ::tensorflow::str_util::Join; -using ::tensorflow::strings::StrAppend; -using ::tensorflow::strings::StrCat; - namespace { +using absl::StrAppend; + bool IsEntryParameterValue(const HloValue& value) { const HloComputation* computation = value.defining_instruction()->parent(); return value.defining_instruction()->opcode() == HloOpcode::kParameter && @@ -381,7 +378,7 @@ class CopyRemover { } string ToString() const { - string out = StrCat("CopyRemover, module ", module_->name(), "\n"); + string out = absl::StrCat("CopyRemover, module ", module_->name(), "\n"); StrAppend(&out, " Buffer values, in dependency order:\n"); for (const HloBuffer& buffer : alias_analysis_.buffers()) { StrAppend(&out, " HloBuffer ", buffer.id(), ":\n"); @@ -863,16 +860,16 @@ class CopyRemover { for (const ValueNode* p = head; p != nullptr; p = Next(*p)) { values.push_back(p->value); } - return StrCat("{", - Join(values, ", ", - [](string* s, const HloValue* value) { - StrAppend(s, value->ToShortString()); - }), - "}"); + return absl::StrCat("{", + absl::StrJoin(values, ", ", + [](string* s, const HloValue* value) { + StrAppend(s, value->ToShortString()); + }), + "}"); } string ToString() const { - string out = StrCat("BufferValueTracker:\n"); + string out = absl::StrCat("BufferValueTracker:\n"); StrAppend(&out, " Def-use chains in each buffer:\n"); for (const ValueNode* head : value_lists_) { StrAppend(&out, " Buffer defined by ", head->value->ToShortString(), @@ -880,10 +877,10 @@ class CopyRemover { const ValueNode* p = head; do { StrAppend(&out, " ", p->value->ToShortString(), ", uses: ", - Join(p->uses, "; ", - [](string* s, const HloUse* use) { - StrAppend(s, use->ToString()); - }), + 
absl::StrJoin(p->uses, "; ", + [](string* s, const HloUse* use) { + StrAppend(s, use->ToString()); + }), "\n"); p = p->next; diff --git a/tensorflow/compiler/xla/service/copy_insertion.h b/tensorflow/compiler/xla/service/copy_insertion.h index 5ba64b78a3..f797ee7e4d 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.h +++ b/tensorflow/compiler/xla/service/copy_insertion.h @@ -45,7 +45,7 @@ namespace xla { // InstructionAliasSet::IsDistinct return true. class CopyInsertion : public HloPassInterface { public: - tensorflow::StringPiece name() const override { return "copy-insertion"; } + absl::string_view name() const override { return "copy-insertion"; } // fusion_can_share_buffer: backend specific function that decides whether a // fusion can share buffer with its operand. diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 850948b54b..e01fecffd0 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -87,6 +87,8 @@ cc_library( ":parallel_task_assignment", ":simple_orc_jit", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + ":target_machine_features", "//tensorflow/compiler/tf2xla:cpu_function_runtime", "//tensorflow/compiler/xla/service:scatter_expander", "//tensorflow/compiler/xla:literal", @@ -232,6 +234,7 @@ cc_library( "//tensorflow/compiler/xla/service:tuple_points_to_analysis", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/strings", "@llvm//:orc_jit", ], ) @@ -279,6 +282,7 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:loop_emitter", "//tensorflow/compiler/xla/service/llvm_ir:tuple_ops", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", "@llvm//:code_gen", "@llvm//:core", "@llvm//:support", @@ -323,6 +327,7 @@ cc_library( "//tensorflow/compiler/xla/service/cpu:cpu_runtime", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", + 
"@com_google_absl//absl/strings", "@llvm//:core", ], ) @@ -365,6 +370,7 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:llvm_loop", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", "@llvm//:core", ], ) @@ -653,6 +659,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -816,6 +823,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_cost_analysis", "//tensorflow/compiler/xla/service:hlo_pass", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -852,6 +860,7 @@ cc_library( deps = [ "//tensorflow/compiler/xla/service:hlo_module_config", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h index e6fd1499ed..59437e88af 100644 --- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h @@ -38,7 +38,7 @@ class ConvCanonicalization : public HloPassInterface { : target_machine_features_(*target_machine_features) {} ~ConvCanonicalization() override {} - tensorflow::StringPiece name() const override { + absl::string_view name() const override { return "convolution-canonicalization"; } diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 5116f926f5..279aa42fe2 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -27,6 +27,7 @@ limitations under the License. 
// IWYU pragma: no_include "llvm/Config/Disassemblers.def.inc" // IWYU pragma: no_include "llvm/Config/Targets.def.inc" #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/Triple.h" #include "llvm/IR/Function.h" @@ -101,8 +102,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace xla { namespace cpu { @@ -235,15 +234,15 @@ class CollectProfileCandidates : public DfsHloVisitorWithDefault { std::unordered_map<const HloInstruction*, int64>* hlo_to_profile_idx_; const std::unordered_map<const HloInstruction*, int64>& assigned_indices_; }; -} // namespace -Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile, - llvm::TargetMachine* target_machine) { - LLVMTargetMachineFeatures target_machine_features(target_machine); +} // namespace - // Optimization pipeline. 
- HloPassPipeline pipeline("CPU"); - pipeline.AddInvariantChecker<HloVerifier>(); +Status CpuCompiler::RunHloPassesThroughLayoutAssn( + HloModule* module, bool /*is_aot_compile*/, + LLVMTargetMachineFeatures* target_machine_features) { + HloPassPipeline pipeline("HLO passes through layout assignment"); + pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false); pipeline.AddPass<CpuHloSupportChecker>(); ReducePrecisionInsertion::AddPasses( @@ -260,11 +259,12 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile, pipeline.AddPass<BatchDotSimplification>(); pipeline.AddPass<DotDecomposer>(); pipeline.AddPass<ConvolutionFeatureGroupConverter>(); - pipeline.AddPass<ConvCanonicalization>(&target_machine_features); + pipeline.AddPass<ConvCanonicalization>(target_machine_features); { auto& pass = pipeline.AddPass<HloPassFix<HloPassPipeline>>("simplification"); - pass.AddInvariantChecker<HloVerifier>(); + pass.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false); pass.AddPass<BatchNormExpander>( /*rewrite_training_op=*/true, @@ -291,10 +291,9 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile, } pipeline.AddPass<IndexedArrayAnalysisPrinterPass>(); pipeline.AddPass<TransposeFolding>( - [&target_machine_features]( - const HloInstruction& dot, + [&](const HloInstruction& dot, const TransposeFolding::OperandIndices& candidate_operands) { - return PotentiallyImplementedAsEigenDot(dot, target_machine_features) + return PotentiallyImplementedAsEigenDot(dot, *target_machine_features) ? 
candidate_operands : TransposeFolding::OperandIndices{}; }, @@ -309,12 +308,28 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile, ReducePrecisionInsertion::PassTiming::AFTER_FUSION); pipeline.AddPass<CpuLayoutAssignment>( - module->mutable_entry_computation_layout(), &target_machine_features); + module->mutable_entry_computation_layout(), target_machine_features); + return pipeline.Run(module).status(); +} + +Status CpuCompiler::RunHloPassesAfterLayoutAssn( + HloModule* module, bool is_aot_compile, + LLVMTargetMachineFeatures* target_machine_features) { + HloPassPipeline pipeline("HLO passes after layout assignment"); + // After layout assignment, use a layout-sensitive verifier. + auto& after_layout_assn = + pipeline.AddPass<HloPassPipeline>("after layout assignment"); + after_layout_assn.AddInvariantChecker<HloVerifier>( + /*layout_sensitive=*/true, + /*allow_mixed_precision=*/false); + // The LayoutAssignment pass may leave behind kCopy instructions which are // duplicate or NOPs, so remove them with algebraic simplification and CSE. { auto& pass = pipeline.AddPass<HloPassFix<HloPassPipeline>>( - "after layout assignement"); + "simplification after layout assignement"); + pass.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/true, + /*allow_mixed_precision=*/false); pass.AddPass<HloPassFix<AlgebraicSimplifier>>( /*is_layout_sensitive=*/true, [](const Shape&, const Shape&) { return true; }, @@ -322,7 +337,9 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile, pass.AddPass<HloDCE>(); pass.AddPass<HloCSE>(/*is_layout_sensitive=*/true); } + pipeline.AddPass<HloElementTypeConverter>(BF16, F32); + // Outline ops in the entry computation into calls to subcomputations. const int max_parallelism = module->config().intra_op_parallelism_threads() > 0 @@ -335,14 +352,14 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile, // binary size (and most AOT applications are single-threaded). 
// TODO(b/29630486) Support multi-threaded AOT. pipeline.AddPass<ParallelTaskAssigner>( - max_parallelism, ShapeSizeBytesFunction(), &target_machine_features); + max_parallelism, ShapeSizeBytesFunction(), target_machine_features); } - // Copy insertion should be performed immediately before IR emission to avoid - // inserting unnecessary copies (later pass adds an instruction which - // materializes the value) or missing a necessary copy (later pass removes an - // instruction which materializes a value). DCE must be run immediately before - // (and sometime after) copy insertion, to avoid dead code from interfering - // with the rewrites. + // Copy insertion should be performed immediately before IR emission to + // avoid inserting unnecessary copies (later pass adds an instruction which + // materializes the value) or missing a necessary copy (later pass removes + // an instruction which materializes a value). DCE must be run immediately + // before (and sometime after) copy insertion, to avoid dead code from + // interfering with the rewrites. pipeline.AddPass<HloDCE>(); pipeline.AddPass<FlattenCallGraph>(); pipeline.AddPass<CpuCopyInsertion>(); @@ -350,6 +367,15 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile, return pipeline.Run(module).status(); } +Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile, + llvm::TargetMachine* target_machine) { + LLVMTargetMachineFeatures target_machine_features(target_machine); + TF_RETURN_IF_ERROR(RunHloPassesThroughLayoutAssn(module, is_aot_compile, + &target_machine_features)); + return RunHloPassesAfterLayoutAssn(module, is_aot_compile, + &target_machine_features); +} + namespace { // Align buffers to 16-byte boundaries. 
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h index 04e1c48872..47b5edabff 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h @@ -20,6 +20,7 @@ limitations under the License. #include "llvm/Target/TargetMachine.h" #include "tensorflow/compiler/tf2xla/cpu_function_runtime.h" +#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/llvm_compiler.h" @@ -157,6 +158,16 @@ class CpuCompiler : public LLVMCompiler { Status RunHloPasses(HloModule* module, bool is_aot_compile, llvm::TargetMachine* target_machine); + // Runs HLO passes up to and including layout assignment. + Status RunHloPassesThroughLayoutAssn( + HloModule* module, bool /*is_aot_compile*/, + LLVMTargetMachineFeatures* target_machine_features); + + // Runs HLO passes after layout assignment. + Status RunHloPassesAfterLayoutAssn( + HloModule* module, bool is_aot_compile, + LLVMTargetMachineFeatures* target_machine_features); + TF_DISALLOW_COPY_AND_ASSIGN(CpuCompiler); }; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h index 6398d8c98d..d49f7d7cc2 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h @@ -32,7 +32,7 @@ namespace xla { // (module-scoped). 
class CpuCopyInsertion : public HloPassInterface { public: - tensorflow::StringPiece name() const override { return "copy-insertion"; } + absl::string_view name() const override { return "copy-insertion"; } StatusOr<bool> Run(HloModule* module) override; }; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc index c376864c3e..fbcbbbd200 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc @@ -22,6 +22,8 @@ limitations under the License. #include <utility> #include <vector> +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "llvm/ExecutionEngine/Orc/IRCompileLayer.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/computation_layout.h" @@ -35,8 +37,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" @@ -177,12 +177,12 @@ Status CpuExecutable::ExecuteComputeFunction( buffer_pointers.size(), profile_counters_size); VLOG(3) << tensorflow::strings::Printf(" result = %p", result_buffer); auto ptr_printer = [](string* out, const void* p) { - tensorflow::strings::StrAppend(out, tensorflow::strings::Printf("%p", p)); + absl::StrAppend(out, tensorflow::strings::Printf("%p", p)); }; VLOG(3) << " params = nullptr"; VLOG(3) << tensorflow::strings::Printf( " temps = [%s]", - tensorflow::str_util::Join(buffer_pointers, ", ", ptr_printer).c_str()); + absl::StrJoin(buffer_pointers, ", ", ptr_printer).c_str()); VLOG(3) << tensorflow::strings::Printf(" profile_counters = %p", profile_counters); } diff 
--git a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h index 2924b63659..6af724b2a5 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h @@ -28,9 +28,7 @@ class CpuHloSupportChecker : public HloPassInterface { CpuHloSupportChecker() = default; ~CpuHloSupportChecker() override = default; - tensorflow::StringPiece name() const override { - return "cpu_hlo_support_checker"; - } + absl::string_view name() const override { return "cpu_hlo_support_checker"; } // Note: always returns false (no instructions are ever modified by this // pass). diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc index e6130c7d76..c3e03056f0 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include <algorithm> #include <set> +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/transpose_folding.h" @@ -773,8 +774,8 @@ class GatherLoopFusionTest TEST_P(GatherLoopFusionTest, GatherLoopFusion) { const GatherLoopFusionTestSpec& spec = GetParam(); - string hlo_string = tensorflow::strings::StrCat( - "HloModule ", spec.test_name, "\n\n", spec.hlo_computation_text); + string hlo_string = absl::StrCat("HloModule ", spec.test_name, "\n\n", + spec.hlo_computation_text); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, ParseHloString(hlo_string)); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc index 69acca86bf..bfecbd6e01 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc @@ -34,8 +34,8 @@ namespace cpu { // instruction stream. namespace { -using ::absl::nullopt; -using ::absl::optional; +using absl::nullopt; +using absl::optional; using ShouldMakeOperandColMajorCache = tensorflow::gtl::FlatMap<const HloInstruction*, bool>; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_options.cc b/tensorflow/compiler/xla/service/cpu/cpu_options.cc index b6039b465e..b8ace57026 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_options.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_options.cc @@ -15,8 +15,9 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/cpu/cpu_options.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_split.h" #include "tensorflow/core/lib/strings/numbers.h" -#include "tensorflow/core/lib/strings/str_util.h" namespace { @@ -51,7 +52,7 @@ absl::optional<int64> LlvmIrGemvTilingFactor(const HloModuleConfig& config) { auto it = extra_options_map.find(kLlvmIrDotTilingFactor); int64 tiling_factor; if (it != extra_options_map.end() && - tensorflow::strings::safe_strto64(it->second, &tiling_factor)) { + absl::SimpleAtoi(it->second, &tiling_factor)) { return tiling_factor; } return absl::nullopt; @@ -63,8 +64,8 @@ bool EnableExperimentalLlvmIrGemm(const HloModuleConfig& config) { return extra_options_map.count(kXlaEnableExperimentalLlvmIrGemm) > 0; } -static tensorflow::StringPiece RemoveSuffix(tensorflow::StringPiece str, - tensorflow::StringPiece suffix) { +static absl::string_view RemoveSuffix(absl::string_view str, + absl::string_view suffix) { CHECK_GE(str.size(), suffix.size()); CHECK_EQ(str.substr(str.size() - suffix.size()), suffix); return str.substr(0, str.size() - suffix.size()); @@ -79,22 +80,21 @@ absl::optional<std::tuple<int64, int64, int64>> LlvmIrGemmTileSize( return absl::nullopt; } - std::vector<string> tile_components = - tensorflow::str_util::Split(it->second, ':'); + std::vector<string> tile_components = absl::StrSplit(it->second, ':'); CHECK_EQ(tile_components.size(), 3); int64 tile_size_m; int64 tile_size_k; int64 tile_size_n_in_vector_width; - CHECK(tensorflow::strings::safe_strto64(tile_components[0], &tile_size_m)); - CHECK(tensorflow::strings::safe_strto64(tile_components[1], &tile_size_k)); + CHECK(absl::SimpleAtoi(tile_components[0], &tile_size_m)); + CHECK(absl::SimpleAtoi(tile_components[1], &tile_size_k)); - tensorflow::StringPiece tile_size_n_in_vector_width_str = + absl::string_view tile_size_n_in_vector_width_str = RemoveSuffix(tile_components[2], "*vectwidth"); - 
CHECK(tensorflow::strings::safe_strto64(tile_size_n_in_vector_width_str, - &tile_size_n_in_vector_width)); + CHECK(absl::SimpleAtoi(tile_size_n_in_vector_width_str, + &tile_size_n_in_vector_width)); return std::tuple<int64, int64, int64>(tile_size_m, tile_size_k, tile_size_n_in_vector_width); diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc index 797392c265..4af16f4fa0 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc @@ -18,6 +18,7 @@ limitations under the License. #include <memory> #include <vector> +#include "absl/strings/str_cat.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Module.h" @@ -146,9 +147,9 @@ class GemvConfig { bool has_addend() const { return has_addend_; } string GetCacheKey() const { - return tensorflow::strings::StrCat( - name_, "_", PrimitiveType_Name(scalar_type()), "_", tile_rows(), "_", - tile_cols(), "_", m(), "_", k(), has_addend() ? "_with_addend" : ""); + return absl::StrCat(name_, "_", PrimitiveType_Name(scalar_type()), "_", + tile_rows(), "_", tile_cols(), "_", m(), "_", k(), + has_addend() ? 
"_with_addend" : ""); } protected: @@ -642,9 +643,7 @@ class TiledSmallGemmEmitter { int64 k() const { return k_; } int64 n() const { return n_; } - string ToString() const { - return tensorflow::strings::StrCat(m(), "x", k(), "x", n()); - } + string ToString() const { return absl::StrCat(m(), "x", k(), "x", n()); } private: const int64 m_; @@ -687,10 +686,10 @@ class TiledSmallGemmEmitter { tile_size_k_(tile_size_k) {} string GetCacheKey() const { - return tensorflow::strings::StrCat( - "gemm_", PrimitiveType_Name(scalar_type()), "_", dims().ToString(), - "_", max_vectorization_width(), "_", min_vectorization_width(), "_", - tile_size_m(), "_", tile_size_k()); + return absl::StrCat("gemm_", PrimitiveType_Name(scalar_type()), "_", + dims().ToString(), "_", max_vectorization_width(), + "_", min_vectorization_width(), "_", tile_size_m(), + "_", tile_size_k()); } PrimitiveType scalar_type() const { return scalar_type_; } diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h index 05322faa75..4c2041b556 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_DOT_OP_EMITTER_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_DOT_OP_EMITTER_H_ +#include "absl/strings/string_view.h" #include "llvm/IR/IRBuilder.h" #include "tensorflow/compiler/xla/service/cpu/cpu_options.h" #include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" @@ -25,7 +26,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/types.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 6f433b4f30..417a1dba1f 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/core/lib/math/math_util.h" #include "tensorflow/core/platform/logging.h" // IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc" +#include "absl/strings/str_cat.h" #include "llvm/CodeGen/TargetRegisterInfo.h" #include "llvm/CodeGen/TargetSubtargetInfo.h" #include "llvm/IR/BasicBlock.h" @@ -67,7 +68,6 @@ limitations under the License. #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/flatset.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/stringprintf.h" namespace xla { @@ -502,7 +502,7 @@ Status IrEmitter::HandleTuple(HloInstruction* tuple) { llvm::Value* IrEmitter::EmitElementalMap( const HloMapInstruction& map_instr, tensorflow::gtl::ArraySlice<llvm::Value*> elemental_operands, - tensorflow::StringPiece name) { + absl::string_view name) { return EmitThreadLocalCall(*map_instr.to_apply(), elemental_operands, name); } @@ -846,7 +846,7 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForConvolution( loops .AddLoop( 0, rhs->shape().dimensions(dnums.kernel_spatial_dimensions(i)), - tensorflow::strings::StrCat("k", i)) + absl::StrCat("k", i)) ->GetIndVarValue(); } llvm::Value* input_feature = @@ -2118,7 +2118,7 @@ Status IrEmitter::HandleCall(HloInstruction* call) { Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) { 
gtl::ArraySlice<HloInstruction*> operands(custom_call->operands()); - tensorflow::StringPiece custom_call_target(custom_call->custom_call_target()); + absl::string_view custom_call_target(custom_call->custom_call_target()); llvm::Type* i8_ptr_type = b_.getInt8PtrTy(); llvm::AllocaInst* operands_alloca = llvm_ir::EmitAllocaAtFunctionEntryWithCount( @@ -2687,9 +2687,8 @@ llvm::Value* IrEmitter::EmitThreadLocalTempBufferPointer( auto buf_it = thread_local_buffers_.find(key); if (buf_it == thread_local_buffers_.end()) { llvm::Value* buffer = llvm_ir::EmitAllocaAtFunctionEntry( - IrShapeType(shape), - tensorflow::strings::StrCat("thread_local", slice.ToString()), &b_, - MinimumAlignmentForShape(target_shape)); + IrShapeType(shape), absl::StrCat("thread_local", slice.ToString()), + &b_, MinimumAlignmentForShape(target_shape)); auto it_inserted_pair = thread_local_buffers_.insert({key, buffer}); CHECK(it_inserted_pair.second); buf_it = it_inserted_pair.first; @@ -2753,7 +2752,7 @@ Status IrEmitter::EmitTargetElementLoop( } Status IrEmitter::EmitTargetElementLoop( - HloInstruction* target_op, tensorflow::StringPiece desc, + HloInstruction* target_op, absl::string_view desc, const llvm_ir::ElementGenerator& element_generator) { VLOG(2) << "EmitTargetElementLoop: " << target_op->ToString(); @@ -2848,7 +2847,7 @@ Status IrEmitter::DefaultAction(HloInstruction* hlo) { llvm::Value* IrEmitter::EmitThreadLocalCall( const HloComputation& callee, tensorflow::gtl::ArraySlice<llvm::Value*> parameters, - tensorflow::StringPiece name) { + absl::string_view name) { const Shape& return_shape = callee.root_instruction()->shape(); // Lifting this restriction to allow "small" arrays should be easy. 
Allowing @@ -2869,7 +2868,7 @@ llvm::Value* IrEmitter::EmitThreadLocalCall( llvm::Value* return_value_buffer = llvm_ir::EmitAllocaAtFunctionEntry( llvm_ir::PrimitiveTypeToIrType(return_type, module_), - tensorflow::strings::StrCat(name, "_retval_addr"), &b_, + absl::StrCat(name, "_retval_addr"), &b_, MinimumAlignmentForPrimitiveType(return_type)); b_.CreateCall( @@ -2886,7 +2885,7 @@ llvm::Value* IrEmitter::EmitThreadLocalCall( } void IrEmitter::EmitGlobalCall(const HloComputation& callee, - tensorflow::StringPiece name) { + absl::string_view name) { b_.CreateCall(FindOrDie(emitted_functions_, &callee), GetArrayFunctionCallArguments( /*parameter_addresses=*/{}, &b_, name, diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index c9a1dab62d..99c080b3db 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -23,6 +23,7 @@ limitations under the License. #include <unordered_map> #include <vector> +#include "absl/strings/string_view.h" #include "llvm/ADT/Triple.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" @@ -44,7 +45,6 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/macros.h" @@ -107,7 +107,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { llvm::Value* EmitElementalMap( const HloMapInstruction& map_instr, tensorflow::gtl::ArraySlice<llvm::Value*> elemental_operands, - tensorflow::StringPiece name); + absl::string_view name); protected: // @@ -239,7 +239,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { // function that a map operation applies. 
StatusOr<llvm::Function*> EmitFunction( HloComputation* function, // The function to emit. - tensorflow::StringPiece + absl::string_view function_name_suffix); // Used for LLVM IR register names. // Emits a call to a thread local function (e.g. to the computation nested @@ -251,14 +251,13 @@ class IrEmitter : public DfsHloVisitorWithDefault { llvm::Value* EmitThreadLocalCall( const HloComputation& callee, tensorflow::gtl::ArraySlice<llvm::Value*> parameters, - tensorflow::StringPiece name); + absl::string_view name); // Emits a call to a "global" function (e.g. to the computation nested within // a kWhile or a kCall). Buffer assignment unabiguously assignes buffers to // the parameters and return values for these computations so there is no need // to explicitly pass parameters or return results. - void EmitGlobalCall(const HloComputation& callee, - tensorflow::StringPiece name); + void EmitGlobalCall(const HloComputation& callee, absl::string_view name); // Returns the buffer to which a global call to `callee` would have written // its result. @@ -285,7 +284,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { HloInstruction* target_op, const llvm_ir::ElementGenerator& element_generator); Status EmitTargetElementLoop( - HloInstruction* target_op, tensorflow::StringPiece desc, + HloInstruction* target_op, absl::string_view desc, const llvm_ir::ElementGenerator& element_generator); // Emits a memcpy from the source instruction's result value to the diff --git a/tensorflow/compiler/xla/service/cpu/ir_function.cc b/tensorflow/compiler/xla/service/cpu/ir_function.cc index 2db4d000f5..784045313d 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_function.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_function.cc @@ -17,6 +17,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/cpu/ir_function.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" #include "tensorflow/compiler/xla/service/cpu/shape_partition.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" @@ -189,7 +190,7 @@ void IrFunction::Initialize(const string& function_name, llvm::Value* IrFunction::GetDynamicLoopBound(const int64 offset) { CHECK_GT(num_dynamic_loop_bounds_, 0); CHECK_LT(offset, num_dynamic_loop_bounds_ * 2); - string name = tensorflow::strings::StrCat("dynamic_loop_bound_", offset); + string name = absl::StrCat("dynamic_loop_bound_", offset); return b_->CreateLoad(b_->CreateGEP(CHECK_NOTNULL(dynamic_loop_bounds_arg_), b_->getInt64(offset), AsStringRef(name))); } @@ -200,7 +201,7 @@ llvm::Value* IrFunction::GetDynamicLoopBound(const int64 offset) { // address buffer). std::vector<llvm::Value*> GetArrayFunctionCallArguments( tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses, - llvm::IRBuilder<>* b, tensorflow::StringPiece name, + llvm::IRBuilder<>* b, absl::string_view name, llvm::Value* return_value_buffer, llvm::Value* exec_run_options_arg, llvm::Value* temp_buffers_arg, llvm::Value* profile_counters_arg) { llvm::Value* parameter_addresses_buffer; @@ -211,13 +212,13 @@ std::vector<llvm::Value*> GetArrayFunctionCallArguments( } else { parameter_addresses_buffer = llvm_ir::EmitAllocaAtFunctionEntryWithCount( b->getInt8PtrTy(), b->getInt32(parameter_addresses.size()), - tensorflow::strings::StrCat(name, "_parameter_addresses"), b); + absl::StrCat(name, "_parameter_addresses"), b); for (size_t i = 0; i < parameter_addresses.size(); ++i) { llvm::Value* parameter_as_i8ptr = b->CreateBitCast(parameter_addresses[i], b->getInt8PtrTy(), - AsStringRef(tensorflow::strings::StrCat( - name, "_parameter_", i, "_address_as_i8ptr"))); + AsStringRef(absl::StrCat(name, "_parameter_", i, + "_address_as_i8ptr"))); llvm::Value* slot_in_param_addresses = 
b->CreateInBoundsGEP(parameter_addresses_buffer, {b->getInt64(i)}); b->CreateStore(parameter_as_i8ptr, slot_in_param_addresses); @@ -320,8 +321,7 @@ Status EmitCallToParallelForkJoin( /*Linkage=*/llvm::GlobalValue::PrivateLinkage, /*Initializer=*/partitions_array, /*Name=*/ - AsStringRef( - tensorflow::strings::StrCat(name, "_parallel_dimension_partitions"))); + AsStringRef(absl::StrCat(name, "_parallel_dimension_partitions"))); // Add argument specifying parallel dimension partitions. fork_join_arguments.push_back( diff --git a/tensorflow/compiler/xla/service/cpu/ir_function.h b/tensorflow/compiler/xla/service/cpu/ir_function.h index a41cbb64cd..ee7595f6e9 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_function.h +++ b/tensorflow/compiler/xla/service/cpu/ir_function.h @@ -116,7 +116,7 @@ class IrFunction { // Returns an array of compute function call argument ir values. std::vector<llvm::Value*> GetArrayFunctionCallArguments( tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses, - llvm::IRBuilder<>* b, tensorflow::StringPiece name, + llvm::IRBuilder<>* b, absl::string_view name, llvm::Value* return_value_buffer, llvm::Value* exec_run_options_arg, llvm::Value* temp_buffers_arg, llvm::Value* profile_counters_arg); diff --git a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc index 8560e4296a..aedb069dce 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc @@ -30,8 +30,8 @@ ParallelLoopEmitter::ParallelLoopEmitter( dynamic_loop_bounds_(dynamic_loop_bounds) {} std::vector<llvm_ir::IrArray::Index> -ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock( - tensorflow::StringPiece loop_name, llvm::Type* index_type) { +ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name, + llvm::Type* index_type) { CHECK_NE(index_type, nullptr); CHECK(!ShapeUtil::IsTuple(shape_)); diff --git 
a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h index 076c683ca5..a604e1db22 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h @@ -61,7 +61,7 @@ class ParallelLoopEmitter : public llvm_ir::LoopEmitter { ~ParallelLoopEmitter() override = default; std::vector<llvm_ir::IrArray::Index> EmitIndexAndSetExitBasicBlock( - tensorflow::StringPiece loop_name, llvm::Type* index_type) override; + absl::string_view loop_name, llvm::Type* index_type) override; private: const DynamicLoopBounds* dynamic_loop_bounds_; diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc index 286d407ca6..b4c0c09ec0 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h" #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h" #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/cpu/shape_partition.h" @@ -217,8 +218,7 @@ bool ParallelTaskAssigner::AssignParallelTasksHelper( // Outline 'instruction' in 'computation' for parallel task assignment. auto* call = module->OutlineExpressionFromComputation( - {instruction}, - tensorflow::strings::StrCat("parallel_", instruction->name()), + {instruction}, absl::StrCat("parallel_", instruction->name()), computation); // Set assigned dimension partitioning to 'instruction'. 
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h index 8becc8fa23..a99cd99c14 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h @@ -73,7 +73,7 @@ class ParallelTaskAssigner : public HloPassInterface { target_machine_features_(*target_machine_features) {} ~ParallelTaskAssigner() override {} - tensorflow::StringPiece name() const override { + absl::string_view name() const override { return "cpu-parallel-task-assigner"; } diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc index ee272b5f4f..a84ee78b19 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc @@ -19,7 +19,6 @@ limitations under the License. #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/strings/str_util.h" namespace xla { namespace { @@ -36,7 +35,9 @@ class ParallelTaskAssignmentTest : public HloVerifiedTestBase { cpu::TargetMachineFeaturesWithFakeAlignmentLogic target_machine_features_; ParallelTaskAssignmentTest() - : target_machine_features_([](int64 shape_size) { + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false), + target_machine_features_([](int64 shape_size) { return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment; }) {} diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc index b026aef3fe..bf98064647 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc @@ -170,15 +170,14 @@ namespace 
{ bool RegisterKnownJITSymbols() { CustomCallTargetRegistry* registry = CustomCallTargetRegistry::Global(); -#define REGISTER_CPU_RUNTIME_SYMBOL(base_name) \ - do { \ - auto* function_address = \ - reinterpret_cast<void*>(__xla_cpu_runtime_##base_name); \ - registry->Register(xla::cpu::runtime::k##base_name##SymbolName, \ - function_address); \ - CHECK_EQ( \ - tensorflow::StringPiece(xla::cpu::runtime::k##base_name##SymbolName), \ - "__xla_cpu_runtime_" #base_name); \ +#define REGISTER_CPU_RUNTIME_SYMBOL(base_name) \ + do { \ + auto* function_address = \ + reinterpret_cast<void*>(__xla_cpu_runtime_##base_name); \ + registry->Register(xla::cpu::runtime::k##base_name##SymbolName, \ + function_address); \ + CHECK_EQ(absl::string_view(xla::cpu::runtime::k##base_name##SymbolName), \ + "__xla_cpu_runtime_" #base_name); \ } while (false) REGISTER_CPU_RUNTIME_SYMBOL(AcquireInfeedBufferForDequeue); diff --git a/tensorflow/compiler/xla/service/cpu/tests/BUILD b/tensorflow/compiler/xla/service/cpu/tests/BUILD index 4635fa5d74..2384166fd2 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/cpu/tests/BUILD @@ -110,6 +110,7 @@ tf_cc_test( "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/strings", ], ) @@ -123,6 +124,7 @@ tf_cc_test( "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc index 6fcce42eaa..fcd87b36b3 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc @@ -19,10 +19,10 @@ limitations under the License. 
#include <cctype> #include <string> +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" #include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/test.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc index 973aac8766..9457e57d7b 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc @@ -17,10 +17,10 @@ limitations under the License. #include <cctype> #include <string> +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" #include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/test.h" namespace xla { @@ -32,9 +32,9 @@ const char* const kTriple_android_arm = "armv7-none-android"; struct IntrinsicTestSpec { HloOpcode opcode; - tensorflow::StringPiece triple; - tensorflow::StringPiece features; - tensorflow::StringPiece check_lines; + absl::string_view triple; + absl::string_view features; + absl::string_view check_lines; }; // Tests that unary functions get lowered using intrinsic calls. @@ -65,9 +65,8 @@ class CpuUnaryIntrinsicTest features = ""; } - return tensorflow::strings::StrCat(opcode.c_str(), "_On_", triple.c_str(), - features.empty() ? "" : "_With", - features.c_str()); + return absl::StrCat(opcode.c_str(), "_On_", triple.c_str(), + features.empty() ? 
"" : "_With", features.c_str()); } }; diff --git a/tensorflow/compiler/xla/service/defuser.h b/tensorflow/compiler/xla/service/defuser.h index 56b28fd22d..c326beb899 100644 --- a/tensorflow/compiler/xla/service/defuser.h +++ b/tensorflow/compiler/xla/service/defuser.h @@ -29,7 +29,7 @@ class Defuser : public HloPassInterface { public: Defuser() {} ~Defuser() override {} - tensorflow::StringPiece name() const override { return "defuser"; } + absl::string_view name() const override { return "defuser"; } // Run defusion on the given module. Returns whether the module was // changed. diff --git a/tensorflow/compiler/xla/service/defuser_test.cc b/tensorflow/compiler/xla/service/defuser_test.cc index e727ba49cb..37d1895d41 100644 --- a/tensorflow/compiler/xla/service/defuser_test.cc +++ b/tensorflow/compiler/xla/service/defuser_test.cc @@ -26,6 +26,11 @@ namespace xla { namespace { class DefuserTest : public HloVerifiedTestBase { + public: + DefuserTest() + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false) {} + protected: // Returns the number of fusion instructions in the module. 
int FusionCount() { diff --git a/tensorflow/compiler/xla/service/despecializer.cc b/tensorflow/compiler/xla/service/despecializer.cc index 48e4471499..ba2a674d9a 100644 --- a/tensorflow/compiler/xla/service/despecializer.cc +++ b/tensorflow/compiler/xla/service/despecializer.cc @@ -27,9 +27,7 @@ namespace { class ControlDepRemover : public HloPassInterface { public: ControlDepRemover() = default; - tensorflow::StringPiece name() const override { - return "control-dep-remover"; - } + absl::string_view name() const override { return "control-dep-remover"; } StatusOr<bool> Run(HloModule* module) override { bool changed = false; diff --git a/tensorflow/compiler/xla/service/despecializer.h b/tensorflow/compiler/xla/service/despecializer.h index cc1695b7f8..7be70add2f 100644 --- a/tensorflow/compiler/xla/service/despecializer.h +++ b/tensorflow/compiler/xla/service/despecializer.h @@ -33,7 +33,7 @@ namespace xla { class Despecializer : public HloPassInterface { public: Despecializer(); - tensorflow::StringPiece name() const override { return "despecializer"; } + absl::string_view name() const override { return "despecializer"; } StatusOr<bool> Run(HloModule* module) override; private: diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index 690b5df514..275e6cc61d 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -19,13 +19,13 @@ limitations under the License. 
#include <type_traits> #include <vector> +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/macros.h" diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h index 20c6bafe7c..6ec4893f7a 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h @@ -16,13 +16,13 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_ +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/compiler/xla/service/dot_decomposer.h b/tensorflow/compiler/xla/service/dot_decomposer.h index 1959b687f1..fc38e31700 100644 --- a/tensorflow/compiler/xla/service/dot_decomposer.h +++ b/tensorflow/compiler/xla/service/dot_decomposer.h @@ -29,7 +29,7 @@ class DotDecomposer : public HloPassInterface { DotDecomposer(bool decompose_batch_dot = true) : 
decompose_batch_dot_(decompose_batch_dot) {} ~DotDecomposer() = default; - tensorflow::StringPiece name() const override { return "dot_decomposer"; } + absl::string_view name() const override { return "dot_decomposer"; } // Run DotDecomposer pass on computations in 'module'. // Returns whether the 'module' was changed. diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index cc7a87f9e8..26af67cc1c 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -22,6 +22,7 @@ limitations under the License. // IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc" #include "absl/algorithm/container.h" +#include "absl/strings/str_cat.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Intrinsics.h" @@ -39,17 +40,16 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/random/random.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" namespace xla { +using absl::StrCat; using llvm_ir::AsStringRef; using llvm_ir::IrArray; using llvm_ir::IrName; using llvm_ir::SetToFirstInsertPoint; -using tensorflow::strings::StrCat; namespace { @@ -306,18 +306,13 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerUnaryOp( {operand_value->getType()}, b_); } case HloOpcode::kSign: { - bool is_signed = - primitive_util::IsSignedIntegralType(op->shape().element_type()); + CHECK(primitive_util::IsSignedIntegralType(op->shape().element_type())) + << op->shape().element_type(); auto type = llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_); auto cmp = b_->CreateICmpEQ(operand_value, GetZero(type)); - if (is_signed) { - auto ashr = - b_->CreateAShr(operand_value, type->getIntegerBitWidth() - 1); - return 
Select(cmp, GetZero(type), b_->CreateOr(ashr, 1)); - } else { - return Select(cmp, GetZero(type), GetOne(type)); - } + auto ashr = b_->CreateAShr(operand_value, type->getIntegerBitWidth() - 1); + return Select(cmp, GetZero(type), b_->CreateOr(ashr, 1)); } case HloOpcode::kNegate: return b_->CreateNeg(operand_value); diff --git a/tensorflow/compiler/xla/service/flatten_call_graph.h b/tensorflow/compiler/xla/service/flatten_call_graph.h index d3efab3614..3cccec9862 100644 --- a/tensorflow/compiler/xla/service/flatten_call_graph.h +++ b/tensorflow/compiler/xla/service/flatten_call_graph.h @@ -28,7 +28,7 @@ namespace xla { // points-to analysis (see b/36865746 for details). class FlattenCallGraph : public HloPassInterface { public: - tensorflow::StringPiece name() const override { return "flatten-call-graph"; } + absl::string_view name() const override { return "flatten-call-graph"; } // Duplicates computations called from multiple call- or while-nodes to // flatten the call graph. diff --git a/tensorflow/compiler/xla/service/gather_expander.h b/tensorflow/compiler/xla/service/gather_expander.h index c1fc8574da..7bd9ea5984 100644 --- a/tensorflow/compiler/xla/service/gather_expander.h +++ b/tensorflow/compiler/xla/service/gather_expander.h @@ -25,7 +25,7 @@ namespace xla { // nevertheless have a minimum level of support. 
class GatherExpander : public HloPassInterface { public: - tensorflow::StringPiece name() const override { return "gather_expander"; } + absl::string_view name() const override { return "gather_expander"; } StatusOr<bool> Run(HloModule* module) override; private: diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index fbef487ac8..e53f525517 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -129,6 +129,7 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/compiler/xla/service/llvm_ir:tuple_ops", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", "@llvm//:core", ], ) @@ -186,6 +187,7 @@ cc_library( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", "@llvm//:core", "@llvm//:support", @@ -231,6 +233,7 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:loop_emitter", "//tensorflow/compiler/xla/service/llvm_ir:math_ops", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", "@llvm//:core", "@llvm//:support", ], @@ -347,6 +350,7 @@ cc_library( "//tensorflow/core/platform/default/build_config:stream_executor_cuda", # build_cleaner: keep "//tensorflow/stream_executor", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", ], ) @@ -384,6 +388,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", ], ) @@ -402,6 +407,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/strings", ], ) @@ -496,6 +502,7 @@ tf_cc_test( 
"//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -527,6 +534,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", ], ) @@ -687,6 +695,7 @@ cc_library( "//tensorflow/core:regexp_internal", "//tensorflow/core:stream_executor_no_cuda", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", "@llvm//:core", ], @@ -775,6 +784,7 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # build_cleaner: keep + "@com_google_absl//absl/strings", ], ) @@ -888,9 +898,8 @@ cc_library( "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/compiler/xla/service:hlo_parser", - "//tensorflow/compiler/xla/service:hlo_runner", - "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc b/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc index 6a285a6b98..f22c2a8add 100644 --- a/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc +++ b/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc @@ -16,9 +16,9 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h" #include <cmath> +#include "absl/strings/str_replace.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/core/lib/strings/str_util.h" namespace xla { namespace gpu { @@ -74,9 +74,8 @@ ENTRY MaxDifference { %error = f32[SIZE] divide(%sub_abs, %denominator) ROOT %max_diff = f32[] reduce(%error, %zero_constant), dimensions={0}, to_apply=MaxF32 })"; - auto size_string = std::to_string(num_elements); - return tensorflow::str_util::StringReplace( - kF16CompHloText, "SIZE", {size_string.data(), size_string.size()}, true); + return absl::StrReplaceAll(kF16CompHloText, + {{"SIZE", absl::StrCat(num_elements)}}); } StatusOr<F16BufferComparator> F16BufferComparator::Create( diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc index 7833a4077e..854a2f50b2 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc @@ -17,11 +17,11 @@ limitations under the License. 
#include <string> +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h index e09cde9abf..6e2e330edd 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h @@ -54,9 +54,7 @@ namespace gpu { // BatchNormRewriter. class CudnnBatchNormRewriter : public HloPassInterface { public: - tensorflow::StringPiece name() const override { - return "cudnn_batchnorm_rewriter"; - } + absl::string_view name() const override { return "cudnn_batchnorm_rewriter"; } StatusOr<bool> Run(HloModule* module) override; }; diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc index 7b172812c3..18a76e8c26 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc @@ -17,11 +17,11 @@ limitations under the License. 
#include <string> +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc index 5a8fc76e85..3d421ebb69 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h" +#include "absl/strings/str_cat.h" #include "absl/types/optional.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" @@ -21,7 +22,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/core/lib/strings/numbers.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/mutex.h" namespace xla { @@ -128,14 +128,14 @@ std::vector<AlgorithmDesc> GetAlgorithms(CudnnConvKind kind, string AlgorithmToString(const AlgorithmDesc& algo) { if (algo.tensor_ops_enabled()) { - return tensorflow::strings::StrCat(algo.algo_id(), "+TC"); + return absl::StrCat(algo.algo_id(), "+TC"); } - return tensorflow::strings::StrCat(algo.algo_id()); + return absl::StrCat(algo.algo_id()); } string NumBytesToString(int64 bytes) { - return tensorflow::strings::StrCat( - tensorflow::strings::HumanReadableNumBytes(bytes), " (", bytes, "B)"); + return absl::StrCat(tensorflow::strings::HumanReadableNumBytes(bytes), " (", + bytes, "B)"); } // Acquires a process-global lock on the device pointed to by the given diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h index 472de2ff0f..f76d273e8c 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h @@ -39,7 +39,7 @@ class CudnnConvolutionAlgorithmPicker : public HloPassInterface { Compiler* compiler) : stream_exec_(stream_exec), allocator_(allocator), compiler_(compiler) {} - tensorflow::StringPiece name() const override { + absl::string_view name() const override { return "cudnn-convolution-algorithm-picker"; } diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h index 0c0578d888..fbe7e98494 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h @@ -26,7 +26,7 
@@ namespace gpu { // backwards-input convolutions into CustomCall HLOs that call into cuDNN. class CudnnConvolutionRewriter : public HloPassInterface { public: - tensorflow::StringPiece name() const override { + absl::string_view name() const override { return "cudnn-convolution-rewriter"; } diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc index 7b0d9e53d6..68086c86e9 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -56,7 +57,7 @@ class ScratchBufAllocator : public se::ScratchAllocator { "Can't allocate twice from a ScratchBufAllocator."); } if (byte_size > scratch_.size()) { - return se::port::InternalError(tensorflow::strings::StrCat( + return se::port::InternalError(absl::StrCat( "Can't allocate ", byte_size, " bytes from a ScratchBufAllocator of size ", scratch_.size())); } diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc index 9b6de115ad..2460d951bd 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc @@ -23,6 +23,8 @@ limitations under the License. 
#include "tensorflow/core/platform/types.h" // IWYU pragma: no_include "llvm/IR/Attributes.gen.inc" // IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "llvm/ADT/APInt.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Instructions.h" @@ -43,16 +45,14 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace xla { namespace gpu { +using absl::StrAppend; using llvm_ir::IrArray; using llvm_ir::IrName; using llvm_ir::SetToFirstInsertPoint; -using tensorflow::strings::StrAppend; namespace { // Returns whether operand is a floating-point literal with the given value. diff --git a/tensorflow/compiler/xla/service/gpu/fft_thunk.cc b/tensorflow/compiler/xla/service/gpu/fft_thunk.cc index 0cdddf8bcf..def595d217 100644 --- a/tensorflow/compiler/xla/service/gpu/fft_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/fft_thunk.cc @@ -17,10 +17,10 @@ limitations under the License. #include <string> +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc index 9b86e5315b..1bd88233e1 100644 --- a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc +++ b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc @@ -19,12 +19,12 @@ limitations under the License. 
#include <vector> #include "absl/algorithm/container.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/str_util.h" namespace xla { namespace gpu { @@ -289,11 +289,10 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) { << " flops_to_bytes_ratio: " << CalculateFlopsToBytesRatio(fusion) << " merged_to_current_bytes_ratio: " << merged_to_current_bytes_ratio << " into users { " - << tensorflow::str_util::Join(users, ", ", - [](string* out, HloInstruction* user) { - tensorflow::strings::StrAppend( - out, user->name()); - }) + << absl::StrJoin(users, ", ", + [](string* out, HloInstruction* user) { + absl::StrAppend(out, user->name()); + }) << " }"; // Remove 'fusion' instruction. CHECK_EQ(0, fusion->user_count()); diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger.h b/tensorflow/compiler/xla/service/gpu/fusion_merger.h index 4c523a66de..7e3f5775b8 100644 --- a/tensorflow/compiler/xla/service/gpu/fusion_merger.h +++ b/tensorflow/compiler/xla/service/gpu/fusion_merger.h @@ -34,7 +34,7 @@ namespace gpu { // class FusionMerger : public HloPassInterface { public: - tensorflow::StringPiece name() const override { return "fusion merger"; } + absl::string_view name() const override { return "fusion merger"; } StatusOr<bool> Run(HloModule* module) override; diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc index 74282c568c..2c02ec2584 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc @@ -17,8 +17,8 @@ limitations under the License. 
#include <functional> +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h index 0c6f9b511f..8ffae18fe8 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h @@ -27,7 +27,7 @@ namespace gpu { // inserting kCopy instructions. class GpuCopyInsertion : public HloPassInterface { public: - tensorflow::StringPiece name() const override { return "copy-insertion"; } + absl::string_view name() const override { return "copy-insertion"; } StatusOr<bool> Run(HloModule* module) override; diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.h b/tensorflow/compiler/xla/service/gpu/gpu_executable.h index 09a1d9c12b..627a05e240 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.h @@ -19,6 +19,7 @@ limitations under the License. #include <memory> #include <string> +#include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" @@ -33,7 +34,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/macros.h" diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h index d63e213d2b..bbb3340760 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h @@ -28,9 +28,7 @@ class GpuHloSupportChecker : public HloPassInterface { GpuHloSupportChecker() = default; ~GpuHloSupportChecker() override = default; - tensorflow::StringPiece name() const override { - return "gpu_hlo_support_checker"; - } + absl::string_view name() const override { return "gpu_hlo_support_checker"; } // Note: always returns false (no instructions are ever modified by this // pass). diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc index 286547ebae..fbc8ddf599 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" @@ -119,7 +120,7 @@ TEST_F(LayoutAssignmentTest, BatchNormInference) { for (const Shape& input_shape : AllLayoutsOf(shape)) { for (const Shape& result_shape : AllLayoutsOf(shape)) { - SCOPED_TRACE(tensorflow::strings::StrCat( + SCOPED_TRACE(absl::StrCat( "input_shape=", ShapeUtil::HumanStringWithLayout(input_shape), ", result_shape=", ShapeUtil::HumanStringWithLayout(result_shape))); @@ -192,7 +193,7 @@ TEST_F(LayoutAssignmentTest, BatchNormTraining) { // Enumerate all combinations of shapes. for (const Shape& input_shape : AllLayoutsOf(shape)) { for (const Shape& result_shape : AllLayoutsOf(shape)) { - SCOPED_TRACE(tensorflow::strings::StrCat( + SCOPED_TRACE(absl::StrCat( "input_shape=", ShapeUtil::HumanStringWithLayout(input_shape), ", result_shape=", ShapeUtil::HumanStringWithLayout(result_shape))); @@ -265,7 +266,7 @@ TEST_F(LayoutAssignmentTest, BatchNormGrad) { for (const Shape& input_shape : AllLayoutsOf(shape)) { for (const Shape& result_shape : AllLayoutsOf(shape)) { for (int constrained_param_no : {0, 4}) { - SCOPED_TRACE(tensorflow::strings::StrCat( + SCOPED_TRACE(absl::StrCat( "input_shape=", ShapeUtil::HumanStringWithLayout(input_shape), ", result_shape=", ShapeUtil::HumanStringWithLayout(result_shape))); diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc index 8c11cd0541..0e205b9c02 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h" +#include "absl/strings/str_cat.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Function.h" #include "llvm/IR/Instructions.h" @@ -24,16 +25,14 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" namespace xla { namespace gpu { -using tensorflow::strings::StrAppend; -using tensorflow::strings::StrCat; +using absl::StrAppend; +using absl::StrCat; void HloToIrBindings::EmitBasePointersForHlos( tensorflow::gtl::ArraySlice<const HloInstruction*> io_hlos, diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index c349063c71..f544bcc919 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -215,7 +215,7 @@ bool IsReductionToVector(const HloInstruction& reduce) { // This emits a device-side call to // "i32 vprintf(i8* fmt, arguments_type* arguments)" in the driver; see // http://docs.nvidia.com/cuda/ptx-writers-guide-to-interoperability/index.html#system-calls -llvm::Value* EmitPrintf(tensorflow::StringPiece fmt, +llvm::Value* EmitPrintf(absl::string_view fmt, tensorflow::gtl::ArraySlice<llvm::Value*> arguments, llvm::IRBuilder<>* builder) { std::vector<llvm::Type*> argument_types; diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h index 5d23a3d018..a35e250101 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h @@ -126,7 +126,7 
@@ bool ImplementedAsLibraryCall(const HloInstruction& hlo); bool IsReductionToVector(const HloInstruction& reduce); // Emits call to "vprintf" with given format and arguments. -llvm::Value* EmitPrintf(tensorflow::StringPiece fmt, +llvm::Value* EmitPrintf(absl::string_view fmt, tensorflow::gtl::ArraySlice<llvm::Value*> arguments, llvm::IRBuilder<>* builder); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h index 561c683879..76e069fc41 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h @@ -22,6 +22,7 @@ limitations under the License. #include <utility> #include <vector> +#include "absl/strings/string_view.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" @@ -40,7 +41,6 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index bda2986202..84043689bd 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -24,6 +24,7 @@ limitations under the License. 
#include "absl/algorithm/container.h" #include "absl/container/inlined_vector.h" #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "absl/types/optional.h" #include "llvm/ADT/StringRef.h" #include "llvm/IR/BasicBlock.h" @@ -90,10 +91,10 @@ namespace { using absl::InlinedVector; using absl::nullopt; using absl::optional; +using absl::StrCat; using llvm_ir::IrArray; using llvm_ir::IrName; using tensorflow::gtl::ArraySlice; -using tensorflow::strings::StrCat; // If a dimensions is smaller than this, untiled transposition may be more // efficient. @@ -801,8 +802,7 @@ Status IrEmitterUnnested::EmitReductionToScalar( // // RoundUpToNextMultipleOf(Ceil(num_elems / kTileSize), warpSize), // // // // and threads_per_block is a multiple of warpSize. - // reduce_kernel<<<num_blocks, threads_per_block>>>(); - // + // reduce_kernel // auto loop_body_emitter = [=](const IrArray::Index& tile_index) -> Status { const int num_reduces = reducers.size(); llvm::Type* element_ir_type = diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc index 6305396635..d856299889 100644 --- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc @@ -16,11 +16,11 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h" #include "absl/memory/memory.h" +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -41,8 +41,8 @@ Status KernelThunk::Initialize(const GpuExecutable& executable, tensorflow::mutex_lock lock(mutex_); if (!loader_spec_) { loader_spec_.reset(new se::MultiKernelLoaderSpec(args_.size())); - tensorflow::StringPiece ptx = executable.ptx(); - // Convert tensorflow::StringPiece to se::port::StringPiece because + absl::string_view ptx = executable.ptx(); + // Convert absl::string_view to se::port::StringPiece because // StreamExecutor uses the latter. loader_spec_->AddCudaPtxInMemory( se::port::StringPiece(ptx.data(), ptx.size()), kernel_name_); diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD index 6bd9c58f83..ccf082c4c6 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD @@ -35,6 +35,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", "@llvm//:amdgpu_code_gen", "@llvm//:analysis", "@llvm//:bit_reader", diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.cc index 12a8a59488..a3c74507dd 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.cc @@ -15,12 +15,12 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.h" +#include "absl/strings/string_view.h" #include "llvm/IR/Module.h" #include "llvm/Support/FileSystem.h" #include "llvm/Support/raw_ostream.h" #include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" @@ -86,7 +86,7 @@ void IrDumpingPassManager::run(llvm::Module &module) { const llvm::PassInfo *PI = llvm::PassRegistry::getPassRegistry()->getPassInfo(P->getPassID()); const string basename = ReplaceFilenameExtension( - tensorflow::io::Basename(input_filename_), + absl::string_view(tensorflow::io::Basename(input_filename_)), tensorflow::strings::Printf( "pass-%02d.before.%s.ll", i, (PI == nullptr ? "unknown" : PI->getPassArgument().data()))); diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc index cce6e48141..e18d7e764a 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc @@ -27,6 +27,8 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringMap.h" #include "llvm/ADT/StringSet.h" @@ -54,9 +56,7 @@ limitations under the License. 
#include "llvm/Transforms/IPO/PassManagerBuilder.h" #include "llvm/Transforms/Scalar.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" @@ -107,8 +107,7 @@ static string GetLibdeviceFilename(const string& libdevice_dir_path, << ", " << compute_capability.second << ") ." << "Defaulting to libdevice for compute_" << libdevice_version; } - return tensorflow::strings::StrCat("libdevice.compute_", libdevice_version, - ".10.bc"); + return absl::StrCat("libdevice.compute_", libdevice_version, ".10.bc"); } // Gets the GPU name as it's known to LLVM for a given compute capability. If @@ -138,15 +137,16 @@ static string GetSmName(std::pair<int, int> compute_capability) { << "Defaulting to telling LLVM that we're compiling for sm_" << sm_version; } - return tensorflow::strings::StrCat("sm_", sm_version); + return absl::StrCat("sm_", sm_version); } // Convenience function for producing a name of a temporary compilation product // from the input filename. string MakeNameForTempProduct(const std::string& input_filename, - tensorflow::StringPiece extension) { - return ReplaceFilenameExtension( - tensorflow::io::Basename(llvm_ir::AsString(input_filename)), extension); + absl::string_view extension) { + return ReplaceFilenameExtension(absl::string_view(tensorflow::io::Basename( + llvm_ir::AsString(input_filename))), + extension); } // Initializes LLVM passes. Uses the PassRegistry mechanism. @@ -167,7 +167,7 @@ void InitializePasses(llvm::PassRegistry* pass_registry) { // Returns the TargetMachine, given a triple. 
std::unique_ptr<llvm::TargetMachine> GetTargetMachine( - llvm::Triple triple, tensorflow::StringPiece cpu_name, + llvm::Triple triple, absl::string_view cpu_name, const HloModuleConfig& hlo_module_config) { std::string error; const llvm::Target* target = TargetRegistry::lookupTarget("", triple, error); @@ -243,9 +243,9 @@ void AddOptimizationPasses(unsigned opt_level, unsigned size_level, } // Emits the given module to a bit code file. -void EmitBitcodeToFile(const Module& module, tensorflow::StringPiece filename) { +void EmitBitcodeToFile(const Module& module, absl::string_view filename) { std::error_code error_code; - llvm::ToolOutputFile outfile(filename.ToString().c_str(), error_code, + llvm::ToolOutputFile outfile(string(filename).c_str(), error_code, llvm::sys::fs::F_None); if (error_code) { LOG(FATAL) << "opening bitcode file for writing: " << error_code.message(); @@ -266,8 +266,9 @@ string EmitModuleToPTX(Module* module, llvm::TargetMachine* target_machine) { // get creative to add a suffix. 
string module_id(llvm_ir::AsString(module->getModuleIdentifier())); IrDumpingPassManager codegen_passes( - ReplaceFilenameExtension(tensorflow::io::Basename(module_id), - "-nvptx.dummy"), + ReplaceFilenameExtension( + absl::string_view(tensorflow::io::Basename(module_id)), + "-nvptx.dummy"), "", false); codegen_passes.add(new llvm::TargetLibraryInfoWrapperPass( llvm::Triple(module->getTargetTriple()))); @@ -332,8 +333,8 @@ Status LinkLibdeviceIfNecessary(llvm::Module* module, return !GV.hasName() || (GVS.count(GV.getName()) == 0); }); })) { - return tensorflow::errors::Internal(tensorflow::strings::StrCat( - "Error linking libdevice from ", libdevice_path)); + return tensorflow::errors::Internal( + absl::StrCat("Error linking libdevice from ", libdevice_path)); } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.h b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.h index 54e0e140de..9654175bfa 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.h +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.h @@ -20,11 +20,11 @@ limitations under the License. #include <string> #include <utility> +#include "absl/strings/string_view.h" #include "llvm/IR/Module.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/core/stringpiece.h" namespace xla { namespace gpu { diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.cc index 9ef9bc3a50..3b2c3591d9 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.cc @@ -17,13 +17,13 @@ limitations under the License. 
#include "tensorflow/core/platform/logging.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" #include "llvm/IRReader/IRReader.h" #include "llvm/Support/SourceMgr.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace { @@ -52,14 +52,13 @@ std::unique_ptr<llvm::Module> LoadIRModule(const string& filename, return module; } -string ReplaceFilenameExtension(tensorflow::StringPiece filename, - tensorflow::StringPiece new_extension) { +string ReplaceFilenameExtension(absl::string_view filename, + absl::string_view new_extension) { auto pos = filename.rfind('.'); - tensorflow::StringPiece stem = - pos == tensorflow::StringPiece::npos - ? filename - : tensorflow::StringPiece(filename.data(), pos); - return tensorflow::strings::StrCat(stem, ".", new_extension); + absl::string_view stem = pos == absl::string_view::npos + ? filename + : absl::string_view(filename.data(), pos); + return absl::StrCat(stem, ".", new_extension); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.h b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.h index a6daeca95a..60f4926849 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.h +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.h @@ -18,8 +18,8 @@ limitations under the License. 
#include <memory> #include <string> +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/core/stringpiece.h" namespace llvm { class LLVMContext; @@ -41,8 +41,8 @@ std::unique_ptr<llvm::Module> LoadIRModule(const string& filename, // // For example: // ReplaceFilenameExtension("/foo/baz.txt", "cc") --> "/foo/baz.cc" -string ReplaceFilenameExtension(tensorflow::StringPiece filename, - tensorflow::StringPiece new_extension); +string ReplaceFilenameExtension(absl::string_view filename, + absl::string_view new_extension); } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc index 5575f6c0c6..9fb6f569ae 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc @@ -49,7 +49,7 @@ bool GpuMultiOutputFusion::ShapesCompatibleForFusion(HloInstruction* instr1, // If possible, we want to pick a reduce operand of the fusion root, // because it has the most constraints. for (const auto* inst : fused_expression_root->operands()) { - if (inst->opcode() == HloOpcode::kReduce) { + if (IsReductionToVector(*inst)) { return inst; } } @@ -64,7 +64,7 @@ bool GpuMultiOutputFusion::ShapesCompatibleForFusion(HloInstruction* instr1, auto get_element_shape = [&](const HloInstruction* element_instr) { // Special handling of kReduce instructions -- the fusion // applies to the first operand. - if (element_instr->opcode() == HloOpcode::kReduce) { + if (IsReductionToVector(*element_instr)) { return element_instr->operand(0)->shape(); } return element_instr->shape(); @@ -141,10 +141,15 @@ bool ReduceFriendlyInputLayouts(HloInstruction* instr) { } // namespace bool GpuMultiOutputFusion::IsFusible(HloInstruction* instr) { - // We can fuse reduces and loop fusions. 
- return IsInputFusibleReduction(instr) || - (instr->opcode() == HloOpcode::kFusion && - instr->fusion_kind() == HloInstruction::FusionKind::kLoop); + // We can fuse reduces and loop fusions. Elementwise instructions can be fused + // with any other instruction. + // TODO(b/112957171): This should use the same isFusible logic as + // instruction_fusion. + return instr->IsFusable() && + (IsInputFusibleReduction(instr) || + (instr->opcode() == HloOpcode::kFusion && + instr->fusion_kind() == HloInstruction::FusionKind::kLoop) || + instr->IsElementwise()); } int64 GpuMultiOutputFusion::GetProfit(HloInstruction* instr1, @@ -178,28 +183,16 @@ bool GpuMultiOutputFusion::LegalToFuse(HloInstruction* instr1, // merge into bigger loop fusions and input (reduce) fusions become fusions // with multiple reduce outputs. We could fuse reduce and loop fusions // together too (the result being an input fusion) if we find cases where this - // improves things. + // improves things. Also disable fusing standalone input-fusible reduces into + // loop fusions. CHECK(instr1->opcode() == HloOpcode::kFusion); if ((instr2->opcode() == HloOpcode::kFusion && instr1->fusion_kind() != instr2->fusion_kind()) || - (instr2->opcode() != HloOpcode::kFusion && + (IsReductionToVector(*instr2) && instr1->fusion_kind() == HloInstruction::FusionKind::kLoop)) { return false; } - // Multi-output loop fusions must have equal output shapes to be lowered. - if (instr1->fusion_kind() == HloInstruction::FusionKind::kLoop) { - Shape shape1 = instr1->IsMultiOutputFusion() - ? instr1->shape().tuple_shapes(0) - : instr1->shape(); - Shape shape2 = instr2->IsMultiOutputFusion() - ? instr2->shape().tuple_shapes(0) - : instr2->shape(); - if (!ShapeUtil::Equal(shape1, shape2)) { - return false; - } - } - // Do this check last, as it may be expensive. 
return !GpuInstructionFusion::FusionWouldBeTooLarge(instr1, instr2); } diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc index 072f885bc1..c822c94f1b 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc @@ -15,19 +15,19 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/multi_output_fusion.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/strings/str_util.h" - -namespace op = xla::testing::opcode_matchers; namespace xla { namespace gpu { +namespace op = xla::testing::opcode_matchers; + using MultiOutputFusionTest = HloTestBase; const char kModulePrefix[] = R"( @@ -47,7 +47,7 @@ const char kModulePrefix[] = R"( TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingReduceAndReduceFusion) { // Fusion with reduce instruction root and a sibling reduce instruction // sharing the same input param. 
- auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_computation { p1.1 = f32[128,512,28,28]{3,2,1,0} parameter(1) mul = f32[128,512,28,28]{3,2,1,0} multiply(p1.1, p1.1) @@ -74,7 +74,7 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingReduceAndReduceFusion) { } TEST_F(MultiOutputFusionTest, MultiOutputFusionDifferentReduceInputShapes) { - auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_computation_1 { p1.1 = f32[6400]{0} parameter(1) mul = f32[6400]{0} multiply(p1.1, p1.1) @@ -101,7 +101,7 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionDifferentReduceInputShapes) { } TEST_F(MultiOutputFusionTest, MultiOutputFusionDifferentReduceOutputShapes) { - auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_computation_1 { p1.1 = f32[10,10]{1,0} parameter(1) mul = f32[10,10]{1,0} multiply(p1.1, p1.1) @@ -130,7 +130,7 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionDifferentReduceOutputShapes) { TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingReduceFusions) { // Two sibling fusions with reduce instruction roots sharing the same input // param. - auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_computation_1 { p1.1 = f32[128,512,28,28]{3,2,1,0} parameter(1) mul = f32[128,512,28,28]{3,2,1,0} multiply(p1.1, p1.1) @@ -165,7 +165,7 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingReduceAndReduceMultiOutputFusion) { // Multi-output fusion with two reduce instructions root and a sibling reduce // instruction sharing the same input param. 
- auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_computation (p0: f32[128,512,28,28]) -> (f32[512], f32[512]) { const.1 = f32[] constant(1) p0.1 = f32[128,512,28,28]{3,2,1,0} parameter(0) @@ -198,7 +198,7 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingFusionCheckAgainstReduceOperand) { // Verify that if we already have a multi-output fusion that we prefer to pick // a reduce op from its operands for checking shape compatibility. - auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_computation_1 { p1.1 = f32[10,10]{1,0} parameter(1) mul = f32[10,10]{1,0} multiply(p1.1, p1.1) @@ -228,7 +228,7 @@ TEST_F(MultiOutputFusionTest, } TEST_F(MultiOutputFusionTest, MultiOutputFusionTwoLoops) { - auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_computation_1 { p0.1 = f32[6400]{0} parameter(0) ROOT mul = f32[6400]{0} multiply(p0.1, p0.1) @@ -256,6 +256,50 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionTwoLoops) { op::Tuple(op::Multiply(), op::Divide())); } +TEST_F(MultiOutputFusionTest, MultiOutputFusionLoopReduceToInputFusion) { + // Fusing a reduce into a loop fusion would require changing the fusion kind. + // That's not supported yet. 
+ auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( + fused_computation_1 { + p0.1 = f32[6400]{0} parameter(0) + ROOT mul = f32[6400]{0} multiply(p0.1, p0.1) + } + + ENTRY entry { + p0 = f32[6400]{0} parameter(0) + fusion.1 = f32[6400]{0} fusion(p0), kind=kLoop, calls=fused_computation_1 + const.2 = f32[] constant(0) + reduce = f32[] reduce(p0, const.2), dimensions={0}, to_apply=scalar_add_computation + ROOT root = (f32[6400]{0}, f32[]) tuple(fusion.1, reduce) + })")) + .ValueOrDie(); + ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); +} + +TEST_F(MultiOutputFusionTest, MultiOutputFusionLoopElementwise) { + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( + fused_computation_1 { + p0.1 = f32[6400]{0} parameter(0) + ROOT mul = f32[6400]{0} multiply(p0.1, p0.1) + } + + ENTRY entry { + p0 = f32[6400]{0} parameter(0) + fusion.1 = f32[6400]{0} fusion(p0), kind=kLoop, calls=fused_computation_1 + const.2 = f32[] constant(1) + div = f32[6400]{0} divide(p0, const.2) + ROOT root = (f32[6400]{0}, f32[6400]{0}) tuple(fusion.1, div) + })")) + .ValueOrDie(); + ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); + SCOPED_TRACE(module->ToString()); + const HloInstruction* fusion = + module->entry_computation()->root_instruction()->operand(0)->operand(0); + ASSERT_TRUE(fusion->IsMultiOutputFusion()); + EXPECT_THAT(fusion->fused_expression_root(), + op::Tuple(op::Multiply(), op::Divide())); +} + TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingLoopsDifferentShapes) { auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( fused_computation_1 { @@ -341,7 +385,7 @@ TEST_F(MultiOutputFusionTest, } TEST_F(MultiOutputFusionTest, ProducerConsumerFusionElementwiseAndReduce) { - auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( ENTRY reduce { p0 = f32[2,2,2]{2,1,0} parameter(0) c0 = f32[] 
constant(0) @@ -361,7 +405,7 @@ TEST_F(MultiOutputFusionTest, ProducerConsumerFusionElementwiseAndReduce) { } TEST_F(MultiOutputFusionTest, ProducerConsumerFusionLoopFusionAndReduce) { - auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_add { p0.1 = f32[2,2,2]{2,1,0} parameter(0) p1.1 = f32[2,2,2]{2,1,0} parameter(1) @@ -388,7 +432,7 @@ TEST_F(MultiOutputFusionTest, ProducerConsumerFusionLoopFusionAndReduce) { } TEST_F(MultiOutputFusionTest, ProducerConsumerFusionLoopFusionAndReduceFusion) { - auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_select { p1.1 = f32[2,2,2]{2,1,0} parameter(1) c0 = f32[] constant(0) @@ -429,7 +473,7 @@ TEST_F(MultiOutputFusionTest, ProducerConsumerFusionLoopFusionAndReduceFusion) { } TEST_F(MultiOutputFusionTest, ProducerConsumerFusionDoNotFuseLoopReduceFusion) { - auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_element_wise { p0.1 = f32[2,2,2]{2,1,0} parameter(0) p1.1 = f32[2,2,2]{2,1,0} parameter(1) @@ -456,7 +500,7 @@ TEST_F(MultiOutputFusionTest, ProducerConsumerFusionDoNotFuseLoopReduceFusion) { TEST_F(MultiOutputFusionTest, ProducerConsumerFusionFp16LoopFusionAndReduceFusion) { - auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_select { p1.1 = f16[2,2,2]{2,1,0} parameter(1) c0 = f16[] constant(0) @@ -497,7 +541,7 @@ TEST_F(MultiOutputFusionTest, TEST_F(MultiOutputFusionTest, ProducerConsumerFusionReduceUnfriendlyLoopFusion) { - auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( mixed_input_layouts_computation { p0.1 = f16[128,1024,32,32]{1,3,2,0} parameter(0) p1.1 = 
f16[128,1024,32,32]{3,2,1,0} parameter(1) diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc index 5868c1a42e..695feadb11 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc @@ -22,6 +22,8 @@ limitations under the License. #include <utility> #include "absl/memory/memory.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" #include "llvm/IR/DiagnosticInfo.h" #include "llvm/IR/DiagnosticPrinter.h" #include "llvm/IR/LLVMContext.h" @@ -85,7 +87,6 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/cuda_libdevice_path.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" @@ -140,7 +141,8 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, Compiler* compiler) { { HloPassPipeline pipeline("optimization"); - pipeline.AddInvariantChecker<HloVerifier>(); + pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false); pipeline.AddPass<GpuHloSupportChecker>(); ReducePrecisionInsertion::AddPasses( &pipeline, hlo_module->config().debug_options(), @@ -156,7 +158,8 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, { auto& pass = pipeline.AddPass<HloPassFix<HloPassPipeline>>("simplification"); - pass.AddInvariantChecker<HloVerifier>(); + pass.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false); // If cudnn batchnorms are enabled, rewrite batchnorm HLOs to cudnn calls // where possible. 
Not every batchnorm op can be implemented as a call to @@ -203,7 +206,8 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, // Convert convolutions into CustomCalls to cudnn, then canonicalize them // (PadInsertion). HloPassPipeline pipeline("conv_canonicalization"); - pipeline.AddInvariantChecker<HloVerifier>(); + pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false); // TODO(b/31709653): Directly use the grouped convolution support of Cudnn. pipeline.AddPass<ConvolutionFeatureGroupConverter>(); pipeline.AddPass<CudnnConvolutionRewriter>(); @@ -218,9 +222,22 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, } { - HloPassPipeline pipeline("layout_assignment"); + // Run layout assignment in a separate pipeline from + // "post-layout-assignment" because we want everything after layout + // assignment to have a layout-sensitive invariant-checker, but + // HloPassPipeline also runs its invariant checker before any passes are + // run, meaning, the pipeline that contains layout assignment cannot contain + // a layout-sensitive verifier! + HloPassPipeline pipeline("layout assignment"); pipeline.AddPass<GpuLayoutAssignment>( hlo_module->mutable_entry_computation_layout(), stream_exec); + TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); + } + + { + HloPassPipeline pipeline("post-layout_assignment"); + pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/true, + /*allow_mixed_precision=*/false); // The LayoutAssignment pass may leave behind kCopy instructions which are // duplicate or NOPs, so remove them with algebraic simplification and CSE. 
@@ -266,17 +283,20 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, { HloPassFix<HloPassPipeline> fusion("fusion"); - fusion.AddInvariantChecker<HloVerifier>(); + fusion.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/true, + /*allow_mixed_precision=*/false); fusion.AddPass<GpuInstructionFusion>(/*may_duplicate=*/false); fusion.AddPass<GpuInstructionFusion>(/*may_duplicate=*/true); fusion.AddPass<FusionMerger>(); fusion.AddPass<GpuMultiOutputFusion>(); fusion.AddPass<HloCSE>(/*is_layout_sensitive=*/true, /*only_fusion_computations=*/true); + fusion.AddPass<HloDCE>(); TF_RETURN_IF_ERROR(fusion.Run(hlo_module).status()); HloPassPipeline reduce_pipeline("reduce-precision"); - reduce_pipeline.AddInvariantChecker<HloVerifier>(); + reduce_pipeline.AddInvariantChecker<HloVerifier>( + /*layout_sensitive=*/true, /*allow_mixed_precision=*/false); ReducePrecisionInsertion::AddPasses( &reduce_pipeline, hlo_module->config().debug_options(), ReducePrecisionInsertion::PassTiming::AFTER_FUSION); @@ -302,7 +322,8 @@ Status PrepareHloModuleForIrEmitting(HloModule* hlo_module) { // (b/27180329). Therefore, in that case, we set the output to be a copy of // the parameter.
HloPassPipeline pipeline("GPU-ir-emit-prepare"); - pipeline.AddInvariantChecker<HloVerifier>(); + pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/true, + /*allow_mixed_precision=*/false); // Copy insertion should be performed immediately before IR emission to avoid // inserting unnecessary copies (later pass adds an instruction which @@ -352,9 +373,9 @@ void WarnIfBadPtxasVersion(const string& ptxas_path) { string vmaj_str, vmin_str, vdot_str; if (!RE2::PartialMatch(out, R"(\bV(\d+)\.(\d+)\.(\d+)\b)", &vmaj_str, &vmin_str, &vdot_str) || - !tensorflow::strings::safe_strto64(vmaj_str, &vmaj) || - !tensorflow::strings::safe_strto64(vmin_str, &vmin) || - !tensorflow::strings::safe_strto64(vdot_str, &vdot)) { + !absl::SimpleAtoi(vmaj_str, &vmaj) || + !absl::SimpleAtoi(vmin_str, &vmin) || + !absl::SimpleAtoi(vdot_str, &vdot)) { LOG(WARNING) << "Couldn't parse ptxas version in output of " << ptxas_path << " --version:\n" << out; @@ -466,7 +487,7 @@ StatusOr<std::vector<uint8>> CompilePtx(const string& ptx, int cc_major, tensorflow::SubProcess ptxas_info_dumper; std::vector<string> ptxas_args = { ptxas_path, ptx_path, "-o", cubin_path, - tensorflow::strings::StrCat("-arch=sm_", cc_major, cc_minor)}; + absl::StrCat("-arch=sm_", cc_major, cc_minor)}; if (VLOG_IS_ON(2)) { ptxas_args.push_back("-v"); } @@ -674,7 +695,7 @@ StatusOr<std::unique_ptr<Executable>> NVPTXCompiler::RunBackend( // Write PTX to IR dump directory, if IR dumping was requested. 
if (!ir_dump_directory.empty()) { const string ptx_outfile = tensorflow::io::JoinPath( - ir_dump_directory, tensorflow::strings::StrCat(module->name(), ".ptx")); + ir_dump_directory, absl::StrCat(module->name(), ".ptx")); auto status = [&] { auto* env = tensorflow::Env::Default(); TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(ir_dump_directory)); diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h index 192359f026..11dc56a64f 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h +++ b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h @@ -32,9 +32,7 @@ namespace gpu { // TODO(jlebar): Also pad dots. class PadForTensorCores : public HloPassInterface { public: - tensorflow::StringPiece name() const override { - return "pad for tensor cores"; - } + absl::string_view name() const override { return "pad for tensor cores"; } StatusOr<bool> Run(HloModule* module) override; }; diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc index 99e7580b82..104af48c82 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc +++ b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc @@ -29,7 +29,12 @@ namespace { namespace op = xla::testing::opcode_matchers; using ::testing::_; -using PadForTensorCoresTest = HloVerifiedTestBase; +class PadForTensorCoresTest : public HloVerifiedTestBase { + public: + PadForTensorCoresTest() + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false) {} +}; TEST_F(PadForTensorCoresTest, PadF16ForwardConvInputChannels) { ParseAndVerifyModule(R"( diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.h b/tensorflow/compiler/xla/service/gpu/pad_insertion.h index 67e51509e4..a622e894ed 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_insertion.h +++ 
b/tensorflow/compiler/xla/service/gpu/pad_insertion.h @@ -26,7 +26,7 @@ namespace gpu { // padding, so that they can be lowered to cuDNN convolution. class PadInsertion : public HloPassInterface { public: - tensorflow::StringPiece name() const override { return "pad insertion"; } + absl::string_view name() const override { return "pad insertion"; } StatusOr<bool> Run(HloModule* module) override; diff --git a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc index 3838fee674..ca57cacb98 100644 --- a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc @@ -57,8 +57,8 @@ ParallelLoopEmitter::ParallelLoopEmitter( unroll_factor_(unroll_factor) {} std::vector<llvm_ir::IrArray::Index> -ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock( - tensorflow::StringPiece loop_name, llvm::Type* index_type) { +ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name, + llvm::Type* index_type) { // Emit the following code in LLVM IR: // linear_index = blockIdx.x * blockDim.x + threadIdx.x; // if (linear_index < num_elements) { diff --git a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h index b82a23419d..cc7da2e73b 100644 --- a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h @@ -58,7 +58,7 @@ class ParallelLoopEmitter : public llvm_ir::LoopEmitter { ~ParallelLoopEmitter() override = default; std::vector<llvm_ir::IrArray::Index> EmitIndexAndSetExitBasicBlock( - tensorflow::StringPiece loop_name, llvm::Type* index_type) override; + absl::string_view loop_name, llvm::Type* index_type) override; private: // The thread and block dimension to parallelize the loop on. 
diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc index cca35316f0..15d1e269cc 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc @@ -27,13 +27,22 @@ namespace { class GpuKernelTilingTest : public GpuCodegenTest { protected: - GpuKernelTilingTest() { + GpuKernelTilingTest() {} + + // Most tests in this file want to skip layout assignment, but a few need it + // enabled. + HloModuleConfig ConfigWithLayoutAssignment() { + return GetModuleConfigForTest(); + } + + HloModuleConfig ConfigWithoutLayoutAssignment() { + HloModuleConfig config; auto debug_options = HloTestBase::GetDebugOptionsForTest(); - config_.set_debug_options(debug_options); // Disable layout_assignment to use the preassigned layouts. - debug_options.add_xla_disable_hlo_passes("layout_assignment"); + debug_options.add_xla_disable_hlo_passes("layout-assignment"); + config.set_debug_options(debug_options); + return config; } - HloModuleConfig config_; }; TEST_F(GpuKernelTilingTest, UnnestedTransposeWithProperDimensionsTiled) { @@ -46,7 +55,13 @@ TEST_F(GpuKernelTilingTest, UnnestedTransposeWithProperDimensionsTiled) { })"; // Check that a call to llvm.nvvm.barrier0 is generated. - auto hlo_module = ParseHloString(kHloString, config_).ValueOrDie(); + // + // We must enable layout assignment in order for this test to work correctly. + // AlgebraicSimplifier removes copy1; it's added back by layout assignment, + // which respects the module's entry computation layout. But if we don't run + // layout assignment...well, nobody else adds the copy back. 
+ auto hlo_module = + ParseHloString(kHloString, ConfigWithLayoutAssignment()).ValueOrDie(); CompileAndVerifyIr(std::move(hlo_module), R"( ; CHECK-LABEL: define void @copy @@ -68,8 +83,11 @@ TEST_F(GpuKernelTilingTest, UnnestedTransposeWithSmallDimensionsNotTiled) { ROOT copy1 = f16[2,3,64]{1,0,2} copy(para0) })"; - // Check that a call to llvm.nvvm.barrier0 is not generated. - auto hlo_module = ParseHloString(kHloString, config_).ValueOrDie(); + // Check that a call to llvm.nvvm.barrier0 is not generated. As in + // UnnestedTransposeWithProperDimensionsTiled, we must run layout assignment + // here. + auto hlo_module = + ParseHloString(kHloString, ConfigWithLayoutAssignment()).ValueOrDie(); CompileAndVerifyIr(std::move(hlo_module), R"( ; CHECK-LABEL: define void @copy @@ -95,7 +113,8 @@ TEST_F(GpuKernelTilingTest, SimpleFusionWithTransposeTiled) { })"; // Check that a call to llvm.nvvm.barrier0 is generated. - auto hlo_module = ParseHloString(kHloString, config_).ValueOrDie(); + auto hlo_module = + ParseHloString(kHloString, ConfigWithoutLayoutAssignment()).ValueOrDie(); CompileAndVerifyIr(std::move(hlo_module), R"( ; CHECK-LABEL: define void @fusion @@ -128,7 +147,8 @@ TEST_F(GpuKernelTilingTest, MultipleOutputFusionWithOnePossibleTransposeTiled) { })"; // Check that a call to llvm.nvvm.barrier0 is generated. - auto hlo_module = ParseHloString(kHloString, config_).ValueOrDie(); + auto hlo_module = + ParseHloString(kHloString, ConfigWithoutLayoutAssignment()).ValueOrDie(); CompileAndVerifyIr(std::move(hlo_module), R"( ; CHECK-LABEL: define void @fusion @@ -162,7 +182,8 @@ TEST_F(GpuKernelTilingTest, })"; // Check that a call to llvm.nvvm.barrier0 is not generated. 
- auto hlo_module = ParseHloString(kHloString, config_).ValueOrDie(); + auto hlo_module = + ParseHloString(kHloString, ConfigWithoutLayoutAssignment()).ValueOrDie(); CompileAndVerifyIr(std::move(hlo_module), R"( ; CHECK-LABEL: define void @fusion diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc index 9622936306..0f2d5568ca 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc @@ -138,6 +138,9 @@ TEST_F(GpuUnrollingTest, UnrollMultiOutputFusion) { HloModuleConfig config; auto debug_options = HloTestBase::GetDebugOptionsForTest(); debug_options.set_xla_gpu_max_kernel_unroll_factor(2); + // Disable layout assignment for this test. Layout assignment does not expect + // fusions to be present, and so it does the wrong thing. + debug_options.add_xla_disable_hlo_passes("layout-assignment"); config.set_debug_options(debug_options); const char *const kMultiOutputFusionModule = R"( diff --git a/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc b/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc index bdb062837c..141f321938 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc +++ b/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc @@ -144,16 +144,15 @@ const std::list<const Thunk*>& ThunkSchedule::DependsOn( string ThunkSchedule::ToString() const { string result = "Total order:\n"; for (Thunk* thunk : thunk_total_order_) { - tensorflow::strings::StrAppend(&result, "\t", - thunk->hlo_instruction()->ToString(), "\n"); + absl::StrAppend(&result, "\t", thunk->hlo_instruction()->ToString(), "\n"); } - tensorflow::strings::StrAppend(&result, "Dependencies:\n"); + absl::StrAppend(&result, "Dependencies:\n"); for (const auto& entry : depends_on_) { const Thunk* dependent = entry.first; for (const Thunk* dependency : entry.second) { - tensorflow::strings::StrAppend( - 
&result, "\t", dependent->hlo_instruction()->name(), " depends on ", - dependency->hlo_instruction()->name(), "\n"); + absl::StrAppend(&result, "\t", dependent->hlo_instruction()->name(), + " depends on ", dependency->hlo_instruction()->name(), + "\n"); } } return result; diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc index c5f3906356..40183de96e 100644 --- a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc +++ b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc @@ -118,7 +118,8 @@ class WhileTransformerTest : public HloTestBase { } void RunCopyInsertionPass() { - HloVerifier verifier; + HloVerifier verifier(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false); TF_ASSERT_OK(verifier.Run(module_.get()).status()); CopyInsertion copy_insertion; TF_ASSERT_OK(copy_insertion.Run(module_.get()).status()); diff --git a/tensorflow/compiler/xla/service/graphviz_example.cc b/tensorflow/compiler/xla/service/graphviz_example.cc index 31431f115f..a2be89511b 100644 --- a/tensorflow/compiler/xla/service/graphviz_example.cc +++ b/tensorflow/compiler/xla/service/graphviz_example.cc @@ -23,6 +23,7 @@ limitations under the License. #include <string> #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -33,7 +34,6 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/types.h" @@ -43,8 +43,7 @@ namespace { // Adds a computation to the given HLO module which adds a scalar constant to // its parameter and returns the result. 
HloComputation* AddScalarConstantComputation(int64 addend, HloModule* module) { - auto builder = - HloComputation::Builder(tensorflow::strings::StrCat("add_", addend)); + auto builder = HloComputation::Builder(absl::StrCat("add_", addend)); auto x_value = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {}), "x_value")); auto half = builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc index 0ca489846e..0986da65cb 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc @@ -20,6 +20,8 @@ limitations under the License. #include <utility> #include <vector> +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/hlo_buffer.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -28,15 +30,11 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" namespace xla { -using ::tensorflow::str_util::Join; -using ::tensorflow::strings::StrAppend; -using ::tensorflow::strings::StrCat; +using absl::StrAppend; // Data structure used to construct the alias analysis. Thrown away after alias // analysis is complete. 
This data structure keeps track of which sets of @@ -414,7 +412,7 @@ Status HloAliasAnalysis::Verify() const { } string HloAliasAnalysis::ToString() const { - string out = StrCat("HloAliasAnalysis, module ", module_->name(), "\n"); + string out = absl::StrCat("HloAliasAnalysis, module ", module_->name(), "\n"); StrAppend(&out, " Buffers at each position:\n"); for (const HloComputation* computation : module_->computations()) { for (const HloInstruction* instruction : computation->instructions()) { @@ -537,10 +535,10 @@ bool HloAliasAnalysis::HasLiveRangeInterference( if (ordering.MayInterfere(*values[i - 1], *values[i], dataflow_analysis())) { VLOG(1) << "In buffer " << buffer.id() << " containing values:\n " - << Join(values, ", ", - [](string* out, const HloValue* value) { - StrAppend(out, value->ToShortString()); - }) + << absl::StrJoin(values, ", ", + [](string* out, const HloValue* value) { + StrAppend(out, value->ToShortString()); + }) << "\nValue " << values[i - 1]->ToShortString() << " may interfere with value " << values[i]->ToShortString(); diff --git a/tensorflow/compiler/xla/service/hlo_buffer.cc b/tensorflow/compiler/xla/service/hlo_buffer.cc index e16413f361..6c11a073b7 100644 --- a/tensorflow/compiler/xla/service/hlo_buffer.cc +++ b/tensorflow/compiler/xla/service/hlo_buffer.cc @@ -20,6 +20,8 @@ limitations under the License. #include <utility> #include <vector> +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -27,15 +29,10 @@ limitations under the License. 
#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/flatset.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" namespace xla { -using ::tensorflow::str_util::Join; -using ::tensorflow::strings::StrCat; - bool HloBuffer::operator==(const HloBuffer& other) const { bool equal = id() == other.id(); if (equal) { @@ -59,10 +56,11 @@ std::vector<HloPosition> HloBuffer::ComputePositions() const { } string HloBuffer::ToString() const { - return StrCat("HloBuffer ", id_, ", values: ", - Join(values_, ", ", [](string* result, const HloValue* value) { - result->append(value->ToShortString()); - })); + return absl::StrCat( + "HloBuffer ", id_, ", values: ", + absl::StrJoin(values_, ", ", [](string* result, const HloValue* value) { + result->append(value->ToShortString()); + })); } std::ostream& operator<<(std::ostream& out, const HloBuffer& buffer) { diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 4c036ea1bf..cf95b112d7 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -25,6 +25,9 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/memory/memory.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" @@ -37,13 +40,11 @@ limitations under the License. 
#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/flatset.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" namespace xla { -using ::tensorflow::strings::StrCat; +using absl::StrCat; std::unique_ptr<HloComputation> HloComputation::Builder::Build( HloInstruction* root_instruction) { @@ -136,7 +137,7 @@ string RenameFusionParameter(const string& original_name, int64 new_param_no) { } string after_param = original_name.substr(index + param_underscore.size()); int64 numeric_suffix; - if (tensorflow::strings::safe_strto64(after_param, &numeric_suffix)) { + if (absl::SimpleAtoi(after_param, &numeric_suffix)) { return StrCat(original_name.substr(0, index + param_underscore.size()), new_param_no); } @@ -805,11 +806,10 @@ std::vector<HloInstruction*> HloComputation::CollectUnreachableRoots() const { } } VLOG(3) << "Unreachable roots:" - << tensorflow::str_util::Join( - unreachable_roots, "\n\t", - [](string* out, const HloInstruction* hlo) { - tensorflow::strings::StrAppend(out, hlo->ToString()); - }); + << absl::StrJoin(unreachable_roots, "\n\t", + [](string* out, const HloInstruction* hlo) { + absl::StrAppend(out, hlo->ToString()); + }); return unreachable_roots; } @@ -980,8 +980,7 @@ void HloComputation::UniquifyName(NameUniquer* name_uniquer) { name_ = name_uniquer->GetUniqueName(name_); } -HloInstruction* HloComputation::GetInstructionWithName( - tensorflow::StringPiece name) { +HloInstruction* HloComputation::GetInstructionWithName(absl::string_view name) { auto instructions_in_computation = instructions(); auto it = absl::c_find_if( instructions_in_computation, diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index faa33f0f90..8d9b694977 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ 
b/tensorflow/compiler/xla/service/hlo_computation.h @@ -367,7 +367,7 @@ class HloComputation { // Returns the instruction in this computation that has name `name`. Returns // null if there is no such computation. - HloInstruction* GetInstructionWithName(tensorflow::StringPiece name); + HloInstruction* GetInstructionWithName(absl::string_view name); int64 unique_id() const { return unique_id_; } diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.h b/tensorflow/compiler/xla/service/hlo_constant_folding.h index 331480bd02..4557983a9c 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding.h +++ b/tensorflow/compiler/xla/service/hlo_constant_folding.h @@ -25,7 +25,7 @@ namespace xla { // computation on constants. class HloConstantFolding : public HloPassInterface { public: - tensorflow::StringPiece name() const override { return "constant_folding"; } + absl::string_view name() const override { return "constant_folding"; } // Run constant folding operations on the given module. Returns whether the // module was changed (constant expressions folded). diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc index c4e27dc558..0ceb6a2968 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc @@ -16,14 +16,15 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h" #include "absl/algorithm/container.h" #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/util.h" namespace xla { +using absl::StrCat; using tensorflow::gtl::ArraySlice; -using tensorflow::strings::StrCat; StatusOr<HloInstruction*> MakeBinaryHlo(HloOpcode opcode, HloInstruction* lhs, HloInstruction* rhs) { @@ -336,7 +337,7 @@ StatusOr<HloInstruction*> BroadcastZeros( StatusOr<std::unique_ptr<HloComputation>> CreateComputationWithSignature( ArraySlice<const Shape*> domain, const Shape& range, - tensorflow::StringPiece name) { + absl::string_view name) { HloComputation::Builder b{std::string(name)}; int64 param_idx = 0; for (const Shape* param_shape : domain) { diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.h b/tensorflow/compiler/xla/service/hlo_creation_utils.h index 5ff8946fb0..1bc6d09b45 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.h +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.h @@ -177,7 +177,7 @@ StatusOr<HloInstruction*> BroadcastZeros( // a value of type `range`. 
StatusOr<std::unique_ptr<HloComputation>> CreateComputationWithSignature( tensorflow::gtl::ArraySlice<const Shape*> domain, const Shape& range, - tensorflow::StringPiece name); + absl::string_view name); } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_cse.h b/tensorflow/compiler/xla/service/hlo_cse.h index 5e2b348bdd..a28c03599a 100644 --- a/tensorflow/compiler/xla/service/hlo_cse.h +++ b/tensorflow/compiler/xla/service/hlo_cse.h @@ -34,7 +34,7 @@ class HloCSE : public HloPassInterface { : is_layout_sensitive_(is_layout_sensitive), only_fusion_computations_(only_fusion_computations) {} ~HloCSE() override = default; - tensorflow::StringPiece name() const override { return "cse"; } + absl::string_view name() const override { return "cse"; } // Run CSE on the given module. Returns whether the module was changed (common // subexpressions were found and eliminated). diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index 01840a56e2..1d35757b42 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -21,6 +21,7 @@ limitations under the License. #include "absl/container/inlined_vector.h" #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -30,8 +31,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -79,8 +78,8 @@ bool MultiDynamicSliceUseShareSameIndices( } // namespace -using ::tensorflow::strings::StrAppend; -using ::tensorflow::strings::StrCat; +using absl::StrAppend; +using absl::StrCat; HloDataflowAnalysis::HloDataflowAnalysis( const HloModule& module, bool ssa_form, bool bitcast_defines_value, @@ -977,28 +976,22 @@ Status HloDataflowAnalysis::Verify() const { bool HloDataflowAnalysis::DoesNotUseOperandBuffer( const HloInstruction* operand, const ShapeIndex& index, const HloInstruction* user) const { - CHECK(user->IsUserOf(operand)) - << "user: " << user->ToString() << " operand: " << operand->ToString(); - if (user->opcode() == HloOpcode::kFusion && - user->fusion_kind() == HloInstruction::FusionKind::kLoop) { - // Find fusion parameter associated with 'operand'. - HloInstruction* fusion_param = - user->fused_parameter(user->operand_index(operand)); - // Iterate through all users of all uses of the fusion parameter value. - // Return false if any uses are detected, returns true otherwise. - const HloValue& value = GetValueDefinedAt(fusion_param, index); - return value.uses().empty(); - } else { - // Return false if no value at 'operand' and 'index' is used at 'user'. - for (const HloValue* value : GetValueSet(operand, index).values()) { - for (const HloUse& use : value->uses()) { - if (use.instruction == user) { - return false; + // Return false if no value at 'operand' and 'index' is used at 'user'. 
+ for (const HloValue* value : GetValueSet(operand, index).values()) { + for (const HloUse& use : value->uses()) { + if (use.instruction == user) { + if (user->opcode() == HloOpcode::kFusion && + user->fusion_kind() == HloInstruction::FusionKind::kLoop) { + HloInstruction* fusion_param = + user->fused_parameter(use.operand_number); + const HloValue& value = + GetValueDefinedAt(fusion_param, use.operand_index); + return value.uses().empty(); } + return false; } } } - return true; } diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h index f4abc7a7c7..a1678d4943 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h @@ -138,7 +138,8 @@ class HloDataflowAnalysis { // Returns true if 'user' cannot possibly use the buffer at 'index' in // 'operand'. Returns false otherwise. // - // REQUIRES: 'operand' is an operand of 'user'. + // 'operand' does not have to be an operand of 'user'. This can be the case + // with indirect uses. bool DoesNotUseOperandBuffer(const HloInstruction* operand, const ShapeIndex& index, const HloInstruction* user) const; diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index 4755c4a0cf..d1a96c10f8 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -1963,6 +1963,54 @@ TEST_F(DoesNotUseOperandBufferTest, FusedDynamicUpdateSlice) { EXPECT_FALSE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {1}, fusion)); } +// Similar to FusedDynamicUpdateSlice above, but tests indirect uses of the +// parameter tuple. 
+TEST_F(DoesNotUseOperandBufferTest, IndirectUses) { + auto builder = HloComputation::Builder(TestName()); + + Shape data_shape = ShapeUtil::MakeShape(F32, {8}); + auto tuple_param = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "tuple")); + auto t0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape, tuple_param, 0)); + auto t1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape, tuple_param, 1)); + // Swap the tuple elements. + auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({t1, t0})); + + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape, tuple, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape, tuple, 1)); + + // Create a DynamicUpdateSlice instruction of tuple element 1. + auto starts = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({2}))); + auto update = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f}))); + auto dynamic_update_slice = + builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + data_shape, gte1, update, starts)); + builder.AddInstruction( + HloInstruction::CreateTuple({gte0, dynamic_update_slice})); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {dynamic_update_slice, starts, update, gte1}, + HloInstruction::FusionKind::kLoop); + RunAnalysis(); + + // The fusion instruction never uses tuple element 0, but does use element 1. + EXPECT_TRUE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {0}, fusion)); + EXPECT_FALSE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {1}, fusion)); + // The same holds for the parameter tuple, except that the tuple elements are + // swapped in 'tuple'. 
+ EXPECT_TRUE( + dataflow_analysis_->DoesNotUseOperandBuffer(tuple_param, {1}, fusion)); + EXPECT_FALSE( + dataflow_analysis_->DoesNotUseOperandBuffer(tuple_param, {0}, fusion)); +} + class CanShareOperandBufferWithUserTest : public HloDataflowAnalysisTestBase {}; TEST_F(CanShareOperandBufferWithUserTest, ElementWiseSameShape) { diff --git a/tensorflow/compiler/xla/service/hlo_dce.h b/tensorflow/compiler/xla/service/hlo_dce.h index 4e244494d6..1fe69b1395 100644 --- a/tensorflow/compiler/xla/service/hlo_dce.h +++ b/tensorflow/compiler/xla/service/hlo_dce.h @@ -36,7 +36,7 @@ namespace xla { class HloDCE : public HloPassInterface { public: ~HloDCE() override {} - tensorflow::StringPiece name() const override { return "dce"; } + absl::string_view name() const override { return "dce"; } // Run the pass on the given module. Returns whether the module was changed // (instructions were removed). diff --git a/tensorflow/compiler/xla/service/hlo_domain_isolator.cc b/tensorflow/compiler/xla/service/hlo_domain_isolator.cc index af904647f8..72185698c9 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_isolator.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_isolator.cc @@ -31,31 +31,10 @@ class HloDomainIsolator::RunContext { StatusOr<bool> Run(); private: - // Inserts a kDomain instruction between operand and instruction in case - // the attribute (ie, sharding) values change between root and instruction. - // Returns the newly inserted kDomain instruction, or nullptr if no kDomain - // instruction was necessary. 
- StatusOr<HloInstruction*> CreateDomain(HloInstruction* instruction, - HloInstruction* root, - HloInstruction* operand); - HloModule* module_; HloDomainIsolator* isolator_; }; -StatusOr<HloInstruction*> HloDomainIsolator::RunContext::CreateDomain( - HloInstruction* instruction, HloInstruction* root, - HloInstruction* operand) { - HloInstruction* domain = nullptr; - std::unique_ptr<HloInstruction> domain_instruction = - isolator_->creator_(instruction, root, operand); - if (domain_instruction != nullptr) { - domain = operand->parent()->AddInstruction(std::move(domain_instruction)); - TF_RETURN_IF_ERROR(operand->ReplaceUseWith(instruction, domain)); - } - return domain; -} - StatusOr<bool> HloDomainIsolator::RunContext::Run() { hlo_graph_dumper::MaybeDumpHloModule(*module_, "Before Domain Isolator"); @@ -76,10 +55,11 @@ StatusOr<bool> HloDomainIsolator::RunContext::Run() { root = root->mutable_operand(0); } // Check whether a kDomain is necessary between instruction and operand. - TF_ASSIGN_OR_RETURN(HloInstruction * domain, - CreateDomain(instruction, root, operand)); + HloInstruction* domain = + isolator_->creator_(instruction, root, operand); if (domain != nullptr) { VLOG(4) << "New domain: " << domain->ToString(); + TF_RETURN_IF_ERROR(operand->ReplaceUseWith(instruction, domain)); ++added_domains; } } diff --git a/tensorflow/compiler/xla/service/hlo_domain_isolator.h b/tensorflow/compiler/xla/service/hlo_domain_isolator.h index bb1537766c..d36631fc2f 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_isolator.h +++ b/tensorflow/compiler/xla/service/hlo_domain_isolator.h @@ -38,12 +38,12 @@ class HloDomainIsolator : public HloPassInterface { // instruction differes from the attribute of the root (the second // HloInstruction argument). // Returns nullptr in case no domain separation is necessary. 
- using DomainCreator = std::function<std::unique_ptr<HloInstruction>( + using DomainCreator = std::function<HloInstruction*( HloInstruction*, HloInstruction*, HloInstruction*)>; explicit HloDomainIsolator(DomainCreator creator); - tensorflow::StringPiece name() const override { return "domain_isolator"; } + absl::string_view name() const override { return "domain_isolator"; } StatusOr<bool> Run(HloModule* module) override; diff --git a/tensorflow/compiler/xla/service/hlo_domain_metadata.h b/tensorflow/compiler/xla/service/hlo_domain_metadata.h index f855f2a1fc..575149c8b8 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_metadata.h +++ b/tensorflow/compiler/xla/service/hlo_domain_metadata.h @@ -20,10 +20,10 @@ limitations under the License. #include <string> #include <vector> +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/flatset.h" namespace xla { @@ -63,7 +63,7 @@ class DomainMetadata { // Returns the metadata type. A unique identifier which describes the real // metadata type. - virtual tensorflow::StringPiece Kind() const = 0; + virtual absl::string_view Kind() const = 0; // Compares the metadata object with another one and returns true if the // two matches. diff --git a/tensorflow/compiler/xla/service/hlo_domain_remover.h b/tensorflow/compiler/xla/service/hlo_domain_remover.h index c859e05f02..97bc8ef604 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_remover.h +++ b/tensorflow/compiler/xla/service/hlo_domain_remover.h @@ -35,13 +35,13 @@ class HloDomainRemover : public HloPassInterface { // instructions in it with the same attributes (ie, sharding), a normalizer // function is tasked at applying attribute normalization on the instructions // within such domain. 
- HloDomainRemover(tensorflow::StringPiece kind, + HloDomainRemover(absl::string_view kind, std::function<Status(const DomainMetadata::Domain&, const DomainMetadata* metadata)> normalizer) - : kind_(kind.ToString()), normalizer_(std::move(normalizer)) {} + : kind_(kind), normalizer_(std::move(normalizer)) {} - tensorflow::StringPiece name() const override { return "domain_remover"; } + absl::string_view name() const override { return "domain_remover"; } StatusOr<bool> Run(HloModule* module) override; diff --git a/tensorflow/compiler/xla/service/hlo_domain_test.cc b/tensorflow/compiler/xla/service/hlo_domain_test.cc index 2654929bf0..79e78ee2d0 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_test.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_test.cc @@ -29,6 +29,11 @@ namespace xla { namespace { class HloDomainTest : public HloVerifiedTestBase { + public: + HloDomainTest() + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false) {} + protected: bool FindUserViaDomainPath(HloInstruction* instruction, HloInstruction* operand) const { @@ -46,9 +51,8 @@ class HloDomainTest : public HloVerifiedTestBase { // Checks whether there is a kDomain instruction in the edge between the // instruction and the operand. 
- bool HasDomainEdge(HloModule* module, - tensorflow::StringPiece instruction_name, - tensorflow::StringPiece operand_name) { + bool HasDomainEdge(HloModule* module, absl::string_view instruction_name, + absl::string_view operand_name) { HloInstruction* instruction = FindInstruction(module, instruction_name); HloInstruction* operand = FindInstruction(module, operand_name); CHECK_NE(instruction, nullptr); @@ -66,7 +70,7 @@ class HloDomainTest : public HloVerifiedTestBase { return false; } - StatusOr<HloModule*> ParseModule(tensorflow::StringPiece hlo_string) { + StatusOr<HloModule*> ParseModule(absl::string_view hlo_string) { HloModuleConfig config; config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); ParseAndVerifyModule(hlo_string, config); @@ -84,7 +88,7 @@ class OpNameMetadata : public DomainMetadata { return absl::make_unique<OpNameMetadata>(opname_); } - tensorflow::StringPiece Kind() const override { return KindName(); } + absl::string_view Kind() const override { return KindName(); } bool Matches(const DomainMetadata& other) const override { const OpNameMetadata* other_ptr = @@ -98,16 +102,16 @@ class OpNameMetadata : public DomainMetadata { string ToString() const override { return opname_; } - static tensorflow::StringPiece KindName() { return "opname"; } + static absl::string_view KindName() { return "opname"; } private: string opname_; }; // Creator function for OpNameMetadata domains. 
-std::unique_ptr<HloInstruction> OpNameDomainCreator(HloInstruction* instruction, - HloInstruction* root, - HloInstruction* operand) { +HloInstruction* OpNameDomainCreator(HloInstruction* instruction, + HloInstruction* root, + HloInstruction* operand) { if (instruction->metadata().op_name() == root->metadata().op_name()) { return nullptr; } @@ -115,9 +119,9 @@ std::unique_ptr<HloInstruction> OpNameDomainCreator(HloInstruction* instruction, absl::make_unique<OpNameMetadata>(root->metadata().op_name()); std::unique_ptr<DomainMetadata> user_side_metadata = absl::make_unique<OpNameMetadata>(instruction->metadata().op_name()); - return HloInstruction::CreateDomain(operand->shape(), operand, - std::move(operand_side_metadata), - std::move(user_side_metadata)); + return operand->parent()->AddInstruction(HloInstruction::CreateDomain( + operand->shape(), operand, std::move(operand_side_metadata), + std::move(user_side_metadata))); } Status OpNameDomainNormalizer(const DomainMetadata::Domain& domain, @@ -144,7 +148,7 @@ ENTRY entry { TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); - HloDomainIsolator isolator(CreateShardingDomain); + HloDomainIsolator isolator(ShardingDomainCreator{}); TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); EXPECT_TRUE(isolator_changed); @@ -186,7 +190,7 @@ ENTRY entry { TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); - HloDomainIsolator isolator(CreateShardingDomain); + HloDomainIsolator isolator(ShardingDomainCreator{}); TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); EXPECT_TRUE(!isolator_changed); } @@ -213,7 +217,7 @@ ENTRY entry { TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); - HloDomainIsolator isolator(CreateShardingDomain); + HloDomainIsolator 
isolator(ShardingDomainCreator{}); TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); EXPECT_TRUE(isolator_changed); @@ -250,7 +254,7 @@ ENTRY entry { TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); - HloDomainIsolator isolator(CreateShardingDomain); + HloDomainIsolator isolator(ShardingDomainCreator{}); TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); EXPECT_FALSE(isolator_changed); } @@ -304,7 +308,7 @@ ENTRY entry { TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); - HloDomainIsolator sharding_isolator(CreateShardingDomain); + HloDomainIsolator sharding_isolator(ShardingDomainCreator{}); TF_ASSERT_OK_AND_ASSIGN(bool sharding_isolator_changed, sharding_isolator.Run(module)); EXPECT_TRUE(sharding_isolator_changed); @@ -358,7 +362,7 @@ ENTRY entry { TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); - HloDomainIsolator isolator(CreateShardingDomain); + HloDomainIsolator isolator(ShardingDomainCreator{}); TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); EXPECT_TRUE(isolator_changed); @@ -447,7 +451,7 @@ ENTRY entry { TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); - HloDomainIsolator isolator(CreateShardingDomain); + HloDomainIsolator isolator(ShardingDomainCreator{}); TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); EXPECT_TRUE(isolator_changed); @@ -506,7 +510,7 @@ ENTRY entry { TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); - HloDomainIsolator isolator(CreateShardingDomain); + HloDomainIsolator isolator(ShardingDomainCreator{}); TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); EXPECT_TRUE(isolator_changed); diff --git a/tensorflow/compiler/xla/service/hlo_domain_verifier.cc 
b/tensorflow/compiler/xla/service/hlo_domain_verifier.cc index 751fc677e2..dc514ae3e5 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_verifier.cc @@ -52,7 +52,7 @@ Status HloDomainVerifier::RunContext::PopulateDomainKinds() { TF_RET_CHECK(instruction->user_side_metadata().Kind() == instruction->operand_side_metadata().Kind()) << instruction->ToString(); - kinds.insert(instruction->user_side_metadata().Kind().ToString()); + kinds.insert(string(instruction->user_side_metadata().Kind())); } } } diff --git a/tensorflow/compiler/xla/service/hlo_domain_verifier.h b/tensorflow/compiler/xla/service/hlo_domain_verifier.h index 8e53cf97f8..81d6d69a8c 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_verifier.h +++ b/tensorflow/compiler/xla/service/hlo_domain_verifier.h @@ -33,7 +33,7 @@ class HloDomainVerifier : public HloPassInterface { public: HloDomainVerifier(std::vector<string> kinds) : kinds_(std::move(kinds)) {} - tensorflow::StringPiece name() const override { return "domain_verifier"; } + absl::string_view name() const override { return "domain_verifier"; } StatusOr<bool> Run(HloModule* module) override; diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter.h b/tensorflow/compiler/xla/service/hlo_element_type_converter.h index 2b109225d0..44ded2c2fa 100644 --- a/tensorflow/compiler/xla/service/hlo_element_type_converter.h +++ b/tensorflow/compiler/xla/service/hlo_element_type_converter.h @@ -32,9 +32,7 @@ class HloElementTypeConverter : public HloPassInterface { HloElementTypeConverter(PrimitiveType eliminate_type, PrimitiveType replace_with_type); - tensorflow::StringPiece name() const override { - return "element_type_converter"; - } + absl::string_view name() const override { return "element_type_converter"; } // Returns the pass on the module and returns whether the module was modified. 
StatusOr<bool> Run(HloModule* module) override; diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index fb90049491..ca1c4dd0e9 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/memory/memory.h" +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/index_util.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" @@ -44,7 +45,6 @@ limitations under the License. #include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index 4b8e6260ac..c3af15c6a8 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -52,7 +52,10 @@ static std::array<bool, 2> use_bf16_params{true, false}; class HloEvaluatorTest : public ::testing::WithParamInterface<bool>, public HloVerifiedTestBase { protected: - HloEvaluatorTest() : use_bfloat16_(GetParam()) { + HloEvaluatorTest() + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false), + use_bfloat16_(GetParam()) { evaluator_ = absl::make_unique<HloEvaluator>(); } @@ -1216,7 +1219,12 @@ TEST_P(HloEvaluatorTest, EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } -class HloEvaluatorPreciseReduceTest : public HloVerifiedTestBase {}; +class HloEvaluatorPreciseReduceTest : public HloVerifiedTestBase { + public: + HloEvaluatorPreciseReduceTest() + : 
HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false) {} +}; // Tests that Reduce doesn't lose precision when adding many numbers (because // it accumulates its result in a double). diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc b/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc index eba80c0f19..460ae2b5ec 100644 --- a/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc +++ b/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc @@ -14,15 +14,15 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace xla { namespace { -using tensorflow::strings::StrCat; +using absl::StrCat; using ::testing::AllOf; using ::testing::ContainsRegex; diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index f8ade39e8c..59c628e945 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -26,6 +26,10 @@ limitations under the License. #include <unordered_map> #include <vector> +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_replace.h" #include "absl/types/optional.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" @@ -40,27 +44,24 @@ limitations under the License. 
#include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/strings/numbers.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/regexp.h" -using ::absl::nullopt; -using ::absl::optional; -using ::tensorflow::Env; -using ::tensorflow::WriteStringToFile; -using ::tensorflow::io::JoinPath; -using ::tensorflow::str_util::Join; -using ::tensorflow::str_util::StringReplace; -using ::tensorflow::strings::StrAppend; -using ::tensorflow::strings::StrCat; - namespace xla { namespace hlo_graph_dumper { namespace { +using absl::nullopt; +using absl::optional; +using absl::StrAppend; +using absl::StrCat; +using absl::StrJoin; +using tensorflow::Env; +using tensorflow::WriteStringToFile; +using tensorflow::io::JoinPath; + // Helpers for Printf and Appendf. template <typename T> struct PrintfConvert { @@ -217,9 +218,8 @@ string NodeColorAttributes(ColorScheme color) { // Replaces <> with <>, so that this string is safe(er) for use in a // graphviz HTML-like string. -string HtmlLikeStringSanitize(tensorflow::StringPiece s) { - return StringReplace(StringReplace(s, "<", "<", /*replace_all=*/true), ">", - ">", /*replace_all=*/true); +string HtmlLikeStringSanitize(absl::string_view s) { + return absl::StrReplaceAll(s, {{"<", "<"}, {">", ">"}}); } // Tries to generates a human-readable one-word description of the given @@ -322,7 +322,7 @@ optional<string> MatchTrivialComputation(const HloComputation* computation) { // Encapsulates logic for dumping an HLO module to DOT (i.e. graphviz syntax). 
class HloDotDumper { public: - HloDotDumper(const HloComputation* computation, tensorflow::StringPiece label, + HloDotDumper(const HloComputation* computation, absl::string_view label, const DebugOptions& debug_options, bool show_backend_config, const HloExecutionProfile* profile, NodeFilter filter) : computation_(computation), @@ -457,7 +457,7 @@ labelloc = t; tooltip = " "; // DOT graphs accept a stylesheet as a URI. So naturally, an inline // stylesheet is a data URI! -stylesheet=" +stylesheet=< data:text/css, @import url(https://fonts.googleapis.com/css?family=Roboto:400,700); svg text { @@ -466,7 +466,7 @@ stylesheet=" } %s -" +> )"; @@ -559,10 +559,10 @@ stylesheet=" } } - return Printf(fmt, graph_label, Join(edge_css_rules, "\n")); + return Printf(fmt, graph_label, StrJoin(edge_css_rules, "\n")); } -string HloDotDumper::Footer() { return StrCat(Join(edges_, "\n"), "\n}"); } +string HloDotDumper::Footer() { return StrCat(StrJoin(edges_, "\n"), "\n}"); } bool HloDotDumper::ShouldShowFusionSubcomputation(const HloInstruction* instr) { CHECK_EQ(instr->opcode(), HloOpcode::kFusion); @@ -854,7 +854,7 @@ string HloDotDumper::GetInstructionNodeInlinedOperands( // Otherwise, print e.g. "%constant.42 (s32[100])". string constant_name; - if (tensorflow::str_util::StartsWith(constant->name(), "constant")) { + if (absl::StartsWith(constant->name(), "constant")) { constant_name = constant->name(); } else { constant_name = StrCat("constant ", constant->name()); @@ -896,7 +896,7 @@ string HloDotDumper::GetInstructionNodeInlinedOperands( } } } - return Join(lines, "<br/>"); + return StrJoin(lines, "<br/>"); } ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { @@ -1084,8 +1084,7 @@ string HloDotDumper::GetInstructionNodeLabel(const HloInstruction* instr) { // The HLO instruction name contains usually the opcode, e.g. "%add.42" is // an add instruction. In this case we render just the name. 
- if (tensorflow::str_util::StartsWith(instr->name(), - HloOpcodeString(instr->opcode()))) { + if (absl::StartsWith(instr->name(), HloOpcodeString(instr->opcode()))) { return Printf("<b>%s</b>", HtmlLikeStringSanitize(instr->name())); } string extended_opcode = @@ -1113,7 +1112,7 @@ string HloDotDumper::GetInstructionNodeMetadata(const HloInstruction* instr) { instr->metadata().source_line())); } - return Join(lines, "<br/>"); + return StrJoin(lines, "<br/>"); } string HloDotDumper::GetInstructionNodeBackendConfig( @@ -1160,8 +1159,7 @@ string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) { constexpr int kMaxShapeLen = 64; if (instr_shape.length() > kMaxShapeLen) { instr_shape = StrCat( - tensorflow::StringPiece(instr_shape).substr(0, kMaxShapeLen - 3), - "..."); + absl::string_view(instr_shape).substr(0, kMaxShapeLen - 3), "..."); } lines.push_back(instr_shape); } @@ -1178,7 +1176,7 @@ string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) { 100 * hlo_cycles_executed / total_cycles_executed)); } } - return Join(lines, "<br/>"); + return StrJoin(lines, "<br/>"); } // Gets the total number of array elements in the given shape. For tuples, this @@ -1271,7 +1269,7 @@ string HloDotDumper::GetInstructionTrivialComputationStr( HtmlLikeStringSanitize(*computation_type))); } } - return Join(lines, "<br/>"); + return StrJoin(lines, "<br/>"); } const HloInstruction* HloDotDumper::GetNodeForEdge( diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc index 1d7a062c55..064c53252c 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -23,12 +24,11 @@ limitations under the License. #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/xla.pb.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace xla { namespace { -using ::tensorflow::strings::StrCat; +using absl::StrCat; using ::testing::HasSubstr; string TestName() { diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 668ed9d6c3..2bb9de686f 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -24,6 +24,11 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/container/inlined_vector.h" #include "absl/memory/memory.h" +#include "absl/strings/ascii.h" +#include "absl/strings/escaping.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/protobuf_util.h" @@ -41,17 +46,15 @@ limitations under the License. 
#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/gtl/map_util.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/human_readable_json.h" #include "tensorflow/core/platform/logging.h" namespace xla { -using tensorflow::str_util::CEscape; -using ::tensorflow::str_util::Join; -using ::tensorflow::strings::StrAppend; -using ::tensorflow::strings::StrCat; +using absl::CEscape; +using absl::StrAppend; +using absl::StrCat; +using absl::StrJoin; /* static */ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( @@ -664,8 +667,7 @@ HloInstruction::CreateReducePrecision(const Shape& shape, HloInstruction::CreateCrossReplicaSum( const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands, HloComputation* reduce_computation, - const std::vector<ReplicaGroup>& replica_groups, - tensorflow::StringPiece barrier, + const std::vector<ReplicaGroup>& replica_groups, absl::string_view barrier, const absl::optional<int64>& all_reduce_id) { return absl::make_unique<HloAllReduceInstruction>( shape, operands, reduce_computation, replica_groups, barrier, @@ -688,7 +690,7 @@ HloInstruction::CreateCrossReplicaSum( /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateOutfeed( const Shape& outfeed_shape, HloInstruction* operand, - HloInstruction* token_operand, tensorflow::StringPiece outfeed_config) { + HloInstruction* token_operand, absl::string_view outfeed_config) { return absl::make_unique<HloOutfeedInstruction>( outfeed_shape, operand, token_operand, outfeed_config); } @@ -1066,7 +1068,7 @@ bool HloInstruction::HasSideEffect() const { /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCustomCall( const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands, - tensorflow::StringPiece custom_call_target) { + absl::string_view custom_call_target) { return 
absl::make_unique<HloCustomCallInstruction>(shape, operands, custom_call_target); } @@ -1345,7 +1347,7 @@ std::unique_ptr<HloInstruction> HloInstruction::Clone( // If names ends with .suffix[0-9]+ then replace with a suffix with the // numeric value incremented. int64 numeric_suffix; - if (tensorflow::strings::safe_strto64(after_suffix, &numeric_suffix)) { + if (absl::SimpleAtoi(after_suffix, &numeric_suffix)) { clone->name_ = StrCat(name().substr(0, index), dot_suffix, numeric_suffix + 1); } else { @@ -1817,7 +1819,7 @@ void HloInstruction::set_false_computation(HloComputation* false_computation) { string HloInstruction::SignatureString() const { string operands = - Join(operands_, ", ", [](string* out, HloInstruction* operand) { + StrJoin(operands_, ", ", [](string* out, HloInstruction* operand) { StrAppend(out, ShapeUtil::HumanString(operand->shape())); }); return StrCat("(", operands, ") -> ", ShapeUtil::HumanString(shape())); @@ -1964,7 +1966,7 @@ string HloInstruction::OperandsToStringWithCanonicalNameMap( slice.size() > kMaxOperandsToShowIfCompact) { slice.remove_suffix(slice.size() - kMaxOperandsToShowIfCompact); } - operands = Join(slice, ", ", [&](string* out, HloInstruction* operand) { + operands = StrJoin(slice, ", ", [&](string* out, HloInstruction* operand) { // If operand is already been deleted, put `null` to the string output. 
if (operand == nullptr) { StrAppend(out, "null "); @@ -1984,7 +1986,7 @@ string HloInstruction::OperandsToStringWithCanonicalNameMap( } else if (!options.compact_operands()) { str.push_back(PrintName(operand->name(), options)); } - StrAppend(out, Join(str, " ")); + StrAppend(out, StrJoin(str, " ")); }); const int64 remaining = operands_.size() - slice.size(); if (slice.size() != operands_.size()) { @@ -2030,8 +2032,9 @@ std::vector<string> HloInstruction::ExtraAttributesToString( extra.push_back( StrCat("to_apply=", PrintName(to_apply()->name(), options))); } else if (!called_computations().empty()) { - extra.push_back(StrCat( - "calls=", Join(called_computations(), ", ", + extra.push_back( + StrCat("calls=", + StrJoin(called_computations(), ", ", [&](string* out, const HloComputation* computation) { StrAppend(out, PrintName(computation->name(), options)); @@ -2068,12 +2071,12 @@ std::vector<string> HloInstruction::ExtraAttributesToString( break; default: if (!called_computations().empty()) { - extra.push_back( - StrCat("calls=\n", - Join(called_computations(), ", ", - [&](string* out, const HloComputation* computation) { - StrAppend(out, computation->ToString(new_options)); - }))); + extra.push_back(StrCat( + "calls=\n", + StrJoin(called_computations(), ", ", + [&](string* out, const HloComputation* computation) { + StrAppend(out, computation->ToString(new_options)); + }))); } break; } @@ -2084,11 +2087,11 @@ std::vector<string> HloInstruction::ExtraAttributesToString( } if (!control_predecessors_.empty()) { extra.push_back(StrCat("control-predecessors={", - Join(control_predecessors_, ", ", - [&](string* out, HloInstruction* pre) { - StrAppend(out, - PrintName(pre->name(), options)); - }), + StrJoin(control_predecessors_, ", ", + [&](string* out, HloInstruction* pre) { + StrAppend(out, + PrintName(pre->name(), options)); + }), "}")); } if (operand_side_metadata_ != nullptr && user_side_metadata_ != nullptr) { @@ -2102,10 +2105,10 @@ std::vector<string> 
HloInstruction::ExtraAttributesToString( string HloInstruction::ToShortString() const { return StrCat("%", name(), " = ", HloOpcodeString(opcode()), "(", - Join(operands_, ", ", - [](string* out, HloInstruction* operand) { - StrAppend(out, "%", operand->name()); - }), + StrJoin(operands_, ", ", + [](string* out, HloInstruction* operand) { + StrAppend(out, "%", operand->name()); + }), ")"); } @@ -2795,7 +2798,7 @@ string PaddingConfigToString(const PaddingConfig& padding) { [](const PaddingConfig::PaddingConfigDimension& dim) { return dim.interior_padding() != 0; }); - return Join( + return StrJoin( padding.dimensions(), "x", [&](string* out, const PaddingConfig::PaddingConfigDimension& dim) { StrAppend( @@ -2819,16 +2822,15 @@ string OpMetadataToString(const OpMetadata& metadata) { if (metadata.source_line() != 0) { result.push_back(StrCat("source_line=", metadata.source_line())); } - return Join(result, " "); + return StrJoin(result, " "); } string RandomDistributionToString(const RandomDistribution& distribution) { - return tensorflow::str_util::Lowercase(RandomDistribution_Name(distribution)); + return absl::AsciiStrToLower(RandomDistribution_Name(distribution)); } string PrecisionToString(const PrecisionConfigProto::Precision& precision) { - return tensorflow::str_util::Lowercase( - PrecisionConfigProto::Precision_Name(precision)); + return absl::AsciiStrToLower(PrecisionConfigProto::Precision_Name(precision)); } string ConvolutionDimensionNumbersToString( @@ -2856,8 +2858,8 @@ string ConvolutionDimensionNumbersToString( output_dims[dnums.output_spatial_dimensions(i)] = StrCat(i); } - return StrCat(Join(lhs_dims, ""), "_", Join(rhs_dims, ""), "->", - Join(output_dims, "")); + return StrCat(StrJoin(lhs_dims, ""), "_", StrJoin(rhs_dims, ""), "->", + StrJoin(output_dims, "")); } string HloInstruction::DotDimensionNumbersToString() const { @@ -2868,19 +2870,21 @@ string HloInstruction::DotDimensionNumbersToString() const { const DotDimensionNumbers& dnums = 
*dot_dimension_numbers_; if (!dnums.lhs_batch_dimensions().empty()) { result.push_back(StrCat("lhs_batch_dims={", - Join(dnums.lhs_batch_dimensions(), ","), "}")); + StrJoin(dnums.lhs_batch_dimensions(), ","), "}")); } result.push_back(StrCat("lhs_contracting_dims={", - Join(dnums.lhs_contracting_dimensions(), ","), "}")); + StrJoin(dnums.lhs_contracting_dimensions(), ","), + "}")); if (!dnums.rhs_batch_dimensions().empty()) { result.push_back(StrCat("rhs_batch_dims={", - Join(dnums.rhs_batch_dimensions(), ","), "}")); + StrJoin(dnums.rhs_batch_dimensions(), ","), "}")); } result.push_back(StrCat("rhs_contracting_dims={", - Join(dnums.rhs_contracting_dimensions(), ","), "}")); + StrJoin(dnums.rhs_contracting_dimensions(), ","), + "}")); - return Join(result, ", "); + return StrJoin(result, ", "); } StatusOr<RandomDistribution> StringToRandomDistribution(const string& name) { @@ -2894,7 +2898,7 @@ StatusOr<RandomDistribution> StringToRandomDistribution(const string& name) { } return map; }(); - auto found = map->find(tensorflow::str_util::Lowercase(name)); + auto found = map->find(absl::AsciiStrToLower(name)); if (found == map->end()) { return InvalidArgument("Unknown distribution"); } @@ -2907,15 +2911,14 @@ string HloInstruction::PrecisionConfigToString() const { } return StrCat( "operand_precision={", - Join(precision_config_.operand_precision(), ",", - [](string* out, int32 precision) { - CHECK(PrecisionConfigProto::Precision_IsValid(precision)) - << precision; - StrAppend( - out, - PrecisionToString( - static_cast<PrecisionConfigProto::Precision>(precision))); - }), + StrJoin(precision_config_.operand_precision(), ",", + [](string* out, int32 precision) { + CHECK(PrecisionConfigProto::Precision_IsValid(precision)) + << precision; + StrAppend(out, PrecisionToString( + static_cast<PrecisionConfigProto::Precision>( + precision))); + }), "}"); } @@ -2932,7 +2935,7 @@ StatusOr<PrecisionConfigProto::Precision> StringToPrecision( } return map; }(); - auto found = 
map->find(tensorflow::str_util::Lowercase(name)); + auto found = map->find(absl::AsciiStrToLower(name)); if (found == map->end()) { return InvalidArgument("Unknown distribution"); } diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 121a9e55f6..566c1c449a 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -34,6 +34,8 @@ limitations under the License. #include "absl/container/inlined_vector.h" #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/iterator_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/map_util.h" @@ -47,7 +49,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/iterator_range.h" @@ -222,7 +223,7 @@ class CanonicalNameMap { return iter->second; } - string new_name = tensorflow::strings::StrCat("tmp_", index++); + string new_name = absl::StrCat("tmp_", index++); canonical_name_map[old_name] = new_name; return new_name; } @@ -450,8 +451,7 @@ class HloInstruction { const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands, HloComputation* reduce_computation, const std::vector<ReplicaGroup>& replica_groups, - tensorflow::StringPiece barrier, - const absl::optional<int64>& all_reduce_id); + absl::string_view barrier, const absl::optional<int64>& all_reduce_id); // This op handles the communication of an Alltoall operation. On each core, // the operands are N ops in the same shape, where N is the number of cores @@ -493,7 +493,7 @@ class HloInstruction { // which is a TOKEN. 
static std::unique_ptr<HloInstruction> CreateOutfeed( const Shape& outfeed_shape, HloInstruction* operand, - HloInstruction* token_operand, tensorflow::StringPiece outfeed_config); + HloInstruction* token_operand, absl::string_view outfeed_config); // Creates an asynchronous send instruction with the given channel id, which // initiates sending the operand data to a unique receive instruction in @@ -706,7 +706,7 @@ class HloInstruction { // to the given operands. "shape" is the resultant shape. static std::unique_ptr<HloInstruction> CreateCustomCall( const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands, - tensorflow::StringPiece custom_call_target); + absl::string_view custom_call_target); // Creates a tuple instruction with the given elements. This is a convenience // wrapper around CreateVariadic. @@ -1037,6 +1037,8 @@ class HloInstruction { CHECK(has_sharding()); return *sharding_; } + std::shared_ptr<const HloSharding> sharding_ptr() const { return sharding_; } + // Returns the sharding applied to this operator, or default_ if none exists. const HloSharding& sharding_or_default(const HloSharding& default_) const { return sharding_ ? *sharding_ : default_; @@ -1051,7 +1053,10 @@ class HloInstruction { // Sets the sharding of this operator. Should only be called by HloModule or // HloComputation methods. void set_sharding(const HloSharding& sharding) { - sharding_ = absl::make_unique<HloSharding>(sharding); + sharding_ = std::make_shared<const HloSharding>(sharding); + } + void set_sharding(std::shared_ptr<const HloSharding> sharding) { + sharding_ = std::move(sharding); } void set_single_sharding(const HloSharding& sharding); // Sets a sharding that assigns the current instruction to device. @@ -1652,7 +1657,10 @@ class HloInstruction { bool copy_elision_allowed_ = true; // The sharding, if one exists. 
- std::unique_ptr<HloSharding> sharding_; + // Uses std::shared_ptr to allow reuse of the same sharding object between + // HloInstructions and other components as HloSharding can be very large for + // many element tuples. + std::shared_ptr<const HloSharding> sharding_; // Fields used by the kDomain instruction. std::unique_ptr<DomainMetadata> operand_side_metadata_; diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index 2a99d4d7c4..a0de253eda 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -19,6 +19,10 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/memory/memory.h" +#include "absl/strings/escaping.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -29,10 +33,10 @@ limitations under the License. 
namespace xla { namespace { -using ::tensorflow::str_util::CEscape; -using ::tensorflow::str_util::Join; -using ::tensorflow::strings::StrAppend; -using ::tensorflow::strings::StrCat; +using absl::CEscape; +using absl::StrAppend; +using absl::StrCat; +using absl::StrJoin; bool IsInstructionElementwiseOnOperand(const HloInstruction* instruction, const HloInstruction* operand) { @@ -160,7 +164,7 @@ HloInstructionProto HloFftInstruction::ToProto() const { std::vector<string> HloFftInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { return {StrCat("fft_type=", FftType_Name(fft_type())), - StrCat("fft_length={", Join(fft_length(), ","), "}")}; + StrCat("fft_length={", StrJoin(fft_length(), ","), "}")}; } bool HloFftInstruction::IdenticalSlowPath( @@ -320,10 +324,10 @@ std::vector<string> HloCollectiveInstruction::ExtraAttributesToStringImpl( std::vector<string> replica_group_str; for (const ReplicaGroup& group : replica_groups()) { replica_group_str.push_back( - StrCat("{", Join(group.replica_ids(), ","), "}")); + StrCat("{", StrJoin(group.replica_ids(), ","), "}")); } result.push_back( - StrCat("replica_groups={", Join(replica_group_str, ","), "}")); + StrCat("replica_groups={", StrJoin(replica_group_str, ","), "}")); return result; } @@ -343,11 +347,11 @@ bool HloCollectiveInstruction::IdenticalSlowPath( HloAllReduceInstruction::HloAllReduceInstruction( const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands, HloComputation* reduce_computation, - const std::vector<ReplicaGroup>& replica_groups, - tensorflow::StringPiece barrier, const absl::optional<int64>& all_reduce_id) + const std::vector<ReplicaGroup>& replica_groups, absl::string_view barrier, + const absl::optional<int64>& all_reduce_id) : HloCollectiveInstruction(HloOpcode::kCrossReplicaSum, shape, operands, replica_groups), - cross_replica_sum_barrier_(barrier.begin(), barrier.end()), + cross_replica_sum_barrier_(barrier), all_reduce_id_(all_reduce_id) { 
AppendComputation(reduce_computation); } @@ -430,7 +434,7 @@ HloInstructionProto HloReverseInstruction::ToProto() const { std::vector<string> HloReverseInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { - return {StrCat("dimensions={", Join(dimensions(), ","), "}")}; + return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")}; } bool HloReverseInstruction::IdenticalSlowPath( @@ -469,7 +473,7 @@ HloInstructionProto HloConcatenateInstruction::ToProto() const { std::vector<string> HloConcatenateInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { - return {StrCat("dimensions={", Join(dimensions(), ","), "}")}; + return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")}; } bool HloConcatenateInstruction::IdenticalSlowPath( @@ -512,7 +516,7 @@ HloInstructionProto HloReduceInstruction::ToProto() const { std::vector<string> HloReduceInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { - return {StrCat("dimensions={", Join(dimensions(), ","), "}")}; + return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")}; } bool HloReduceInstruction::IdenticalSlowPath( @@ -555,7 +559,7 @@ HloInstructionProto HloSortInstruction::ToProto() const { std::vector<string> HloSortInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { - return {StrCat("dimensions={", Join(dimensions(), ","), "}")}; + return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")}; } bool HloSortInstruction::IdenticalSlowPath( @@ -588,7 +592,7 @@ HloTransposeInstruction::HloTransposeInstruction( Permute(dimensions, shape.dimensions()).begin())) << "shape: " << ShapeUtil::HumanString(shape) << ", operand->shape(): " << ShapeUtil::HumanString(shape) - << ", dimensions: {" << Join(dimensions, ", ") << "}"; + << ", dimensions: {" << StrJoin(dimensions, ", ") << "}"; AppendOperand(operand); } @@ -609,7 +613,7 @@ HloInstructionProto HloTransposeInstruction::ToProto() const { 
std::vector<string> HloTransposeInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { - return {StrCat("dimensions={", Join(dimensions(), ","), "}")}; + return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")}; } bool HloTransposeInstruction::IdenticalSlowPath( @@ -648,7 +652,7 @@ HloInstructionProto HloBroadcastInstruction::ToProto() const { std::vector<string> HloBroadcastInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { - return {StrCat("dimensions={", Join(dimensions(), ","), "}")}; + return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")}; } bool HloBroadcastInstruction::IdenticalSlowPath( @@ -709,7 +713,7 @@ bool HloMapInstruction::IsElementwiseImpl( std::vector<string> HloMapInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { - return {StrCat("dimensions={", Join(dimensions(), ","), "}")}; + return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")}; } bool HloMapInstruction::IdenticalSlowPath( @@ -767,7 +771,7 @@ std::vector<string> HloSliceInstruction::ExtraAttributesToStringImpl( bounds.push_back( StrCat("[", slice_starts_[i], ":", slice_limits_[i], stride_str, "]")); } - return {StrCat("slice={", Join(bounds, ", "), "}")}; + return {StrCat("slice={", StrJoin(bounds, ", "), "}")}; } bool HloSliceInstruction::IdenticalSlowPath( @@ -853,7 +857,7 @@ string HloConstantInstruction::OperandsToStringWithCanonicalNameMap( // lines. Compact this into one line by stripping out white space. string tmp = literal().ToString(); std::replace(tmp.begin(), tmp.end(), '\n', ' '); - std::vector<string> v = tensorflow::str_util::Split(tmp, ' '); + std::vector<string> v = absl::StrSplit(tmp, ' '); bool first = true; // Concatenate elements in "v" with spaces separating them, but ignoring // empty entries. 
@@ -1554,12 +1558,13 @@ std::unique_ptr<HloInstruction> HloInfeedInstruction::CloneWithNewOperandsImpl( infeed_shape(), new_operands[0], infeed_config()); } -HloOutfeedInstruction::HloOutfeedInstruction( - const Shape& outfeed_shape, HloInstruction* operand, - HloInstruction* token_operand, tensorflow::StringPiece outfeed_config) +HloOutfeedInstruction::HloOutfeedInstruction(const Shape& outfeed_shape, + HloInstruction* operand, + HloInstruction* token_operand, + absl::string_view outfeed_config) : HloInstruction(HloOpcode::kOutfeed, ShapeUtil::MakeTokenShape()), outfeed_shape_(outfeed_shape), - outfeed_config_(outfeed_config.begin(), outfeed_config.end()) { + outfeed_config_(outfeed_config) { CHECK(ShapeUtil::Compatible(operand->shape(), outfeed_shape)) << "Outfeed shape " << outfeed_shape << " must be compatible with operand shape " << operand->shape(); @@ -1767,7 +1772,7 @@ HloSelectAndScatterInstruction::CloneWithNewOperandsImpl( HloCustomCallInstruction::HloCustomCallInstruction( const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands, - tensorflow::StringPiece custom_call_target) + absl::string_view custom_call_target) : HloInstruction(HloOpcode::kCustomCall, shape), custom_call_target_(custom_call_target.begin(), custom_call_target.end()) { @@ -1903,8 +1908,8 @@ HloInstructionProto HloDynamicSliceInstruction::ToProto() const { std::vector<string> HloDynamicSliceInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { - return { - StrCat("dynamic_slice_sizes={", Join(dynamic_slice_sizes(), ","), "}")}; + return {StrCat("dynamic_slice_sizes={", StrJoin(dynamic_slice_sizes(), ","), + "}")}; } bool HloDynamicSliceInstruction::IdenticalSlowPath( @@ -1940,17 +1945,17 @@ string HloGatherInstruction::GatherDimensionNumbersToString() const { CHECK(gather_dimension_numbers_ != nullptr); string offset_dims = StrCat("offset_dims={", - Join(gather_dimension_numbers_->offset_dims(), ","), "}"); - string collapsed_slice_dims = - 
StrCat("collapsed_slice_dims={", - Join(gather_dimension_numbers_->collapsed_slice_dims(), ","), "}"); + StrJoin(gather_dimension_numbers_->offset_dims(), ","), "}"); + string collapsed_slice_dims = StrCat( + "collapsed_slice_dims={", + StrJoin(gather_dimension_numbers_->collapsed_slice_dims(), ","), "}"); string start_index_map = StrCat("start_index_map={", - Join(gather_dimension_numbers_->start_index_map(), ","), "}"); + StrJoin(gather_dimension_numbers_->start_index_map(), ","), "}"); string index_vector_dim = StrCat( "index_vector_dim=", gather_dimension_numbers_->index_vector_dim()); - return Join<std::initializer_list<string>>( + return StrJoin<std::initializer_list<string>>( {offset_dims, collapsed_slice_dims, start_index_map, index_vector_dim}, ", "); } @@ -1987,7 +1992,7 @@ HloInstructionProto HloGatherInstruction::ToProto() const { std::vector<string> HloGatherInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { return {GatherDimensionNumbersToString(), - StrCat("slice_sizes={", Join(gather_slice_sizes(), ","), "}")}; + StrCat("slice_sizes={", StrJoin(gather_slice_sizes(), ","), "}")}; } bool HloGatherInstruction::IdenticalSlowPath( @@ -2026,20 +2031,20 @@ HloScatterInstruction::HloScatterInstruction( } string HloScatterInstruction::ScatterDimensionNumbersToString() const { - string update_window_dims = - StrCat("update_window_dims={", - Join(scatter_dimension_numbers().update_window_dims(), ","), "}"); + string update_window_dims = StrCat( + "update_window_dims={", + StrJoin(scatter_dimension_numbers().update_window_dims(), ","), "}"); string inserted_window_dims = StrCat( "inserted_window_dims={", - Join(scatter_dimension_numbers().inserted_window_dims(), ","), "}"); + StrJoin(scatter_dimension_numbers().inserted_window_dims(), ","), "}"); string scatter_dims_to_operand_dims = StrCat( "scatter_dims_to_operand_dims={", - Join(scatter_dimension_numbers().scatter_dims_to_operand_dims(), ","), + 
StrJoin(scatter_dimension_numbers().scatter_dims_to_operand_dims(), ","), "}"); string index_vector_dim = StrCat( "index_vector_dim=", scatter_dimension_numbers().index_vector_dim()); - return Join<std::initializer_list<string>>( + return StrJoin<std::initializer_list<string>>( {update_window_dims, inserted_window_dims, scatter_dims_to_operand_dims, index_vector_dim}, ", "); diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index 19e98c6fb4..efdb9e9781 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -248,8 +248,7 @@ class HloAllReduceInstruction : public HloCollectiveInstruction { const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands, HloComputation* reduce_computation, const std::vector<ReplicaGroup>& replica_groups, - tensorflow::StringPiece barrier, - const absl::optional<int64>& all_reduce_id); + absl::string_view barrier, const absl::optional<int64>& all_reduce_id); // Returns the barrier config used for the CrossReplicaSum implementation of // each backend. @@ -908,7 +907,7 @@ class HloOutfeedInstruction : public HloInstruction { explicit HloOutfeedInstruction(const Shape& outfeed_shape, HloInstruction* operand, HloInstruction* token_operand, - tensorflow::StringPiece outfeed_config); + absl::string_view outfeed_config); // Returns the shape for the Outfeed instruction. 
const Shape& outfeed_shape() const { TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(outfeed_shape_)); @@ -1061,7 +1060,7 @@ class HloCustomCallInstruction : public HloInstruction { public: explicit HloCustomCallInstruction( const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands, - tensorflow::StringPiece custom_call_target); + absl::string_view custom_call_target); const Window& window() const override { CHECK(window_ != nullptr); return *window_; diff --git a/tensorflow/compiler/xla/service/hlo_lexer.cc b/tensorflow/compiler/xla/service/hlo_lexer.cc index 2e01b090be..0e49d343d6 100644 --- a/tensorflow/compiler/xla/service/hlo_lexer.cc +++ b/tensorflow/compiler/xla/service/hlo_lexer.cc @@ -17,20 +17,20 @@ limitations under the License. #include <unordered_map> +#include "absl/strings/escaping.h" +#include "absl/strings/numbers.h" #include "absl/types/optional.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/strings/numbers.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/regexp.h" namespace xla { - -using ::tensorflow::StringPiece; - namespace { +using absl::string_view; + constexpr int kEOF = -1; constexpr int kError = -2; @@ -66,12 +66,12 @@ bool HloLexer::CanDereference(const char* ptr) const { return ptr < buf_.end() && ptr >= buf_.begin(); } -tensorflow::StringPiece HloLexer::StringPieceFromPointers( - const char* begin, const char* end) const { +absl::string_view HloLexer::StringPieceFromPointers(const char* begin, + const char* end) const { CHECK(begin <= end); CHECK(begin == buf_.end() || CanDereference(begin)); CHECK(end == buf_.end() || CanDereference(end)); - return tensorflow::StringPiece(begin, end - begin); + return absl::string_view(begin, end - begin); } tensorflow::RegexpStringPiece HloLexer::RegexpStringPieceFromPointers( @@ -235,7 +235,7 @@ TokKind 
HloLexer::LexIdentifier() { return TokKind::kAttributeName; } - tensorflow::StringPiece identifier = + absl::string_view identifier = StringPieceFromPointers(token_start_, current_ptr_); // See if this is a keyword. @@ -306,8 +306,8 @@ TokKind HloLexer::LexNumberOrPattern() { R"([-]?((\d+|\d+[.]\d*|\d*[.]\d+)([eE][+-]?\d+))|[-]?(\d+[.]\d*|\d*[.]\d+))"}; if (RE2::Consume(&consumable, *float_pattern)) { current_ptr_ = consumable.begin(); - tensorflow::strings::safe_strtod(string(token_start_, current_ptr_).c_str(), - &decimal_val_); + CHECK(absl::SimpleAtod(string(token_start_, current_ptr_).c_str(), + &decimal_val_)); return TokKind::kDecimal; } @@ -339,7 +339,7 @@ TokKind HloLexer::LexNumberOrPattern() { if (RE2::Consume(&consumable, *int_pattern)) { current_ptr_ = consumable.begin(); auto slice = StringPieceFromPointers(token_start_, current_ptr_); - if (tensorflow::strings::safe_strto64(slice, &int64_val_)) { + if (absl::SimpleAtoi(slice, &int64_val_)) { return TokKind::kInt; } LOG(ERROR) << "Failed to parse int literal: " << slice; @@ -375,24 +375,24 @@ std::pair<unsigned, unsigned> HloLexer::GetLineAndColumn(LocTy location) const { line_no_cache_.last_query = ptr; line_no_cache_.line_no_of_query = line_no; size_t line_offset = StringPieceFromPointers(start, ptr).rfind('\n'); - if (line_offset == tensorflow::StringPiece::npos) { + if (line_offset == absl::string_view::npos) { line_offset = 0; } return {line_no, ptr - start - line_offset}; } -tensorflow::StringPiece HloLexer::GetLine(LocTy loc) const { +absl::string_view HloLexer::GetLine(LocTy loc) const { if (!CanDereference(loc)) { return "LINE OUT OF RANGE"; } size_t line_start = StringPieceFromPointers(buf_.begin(), loc + 1).rfind('\n'); - const char* start = line_start == tensorflow::StringPiece::npos + const char* start = line_start == absl::string_view::npos ? 
buf_.begin() : buf_.begin() + line_start + 1; size_t line_end = StringPieceFromPointers(loc, buf_.end()).find('\n'); const char* end = - line_end == tensorflow::StringPiece::npos ? buf_.end() : loc + line_end; + line_end == absl::string_view::npos ? buf_.end() : loc + line_end; return StringPieceFromPointers(start, end); } @@ -404,10 +404,14 @@ TokKind HloLexer::LexString() { static LazyRE2 escaping_pattern = {R"("([^"\\]|\\.)*")"}; if (RE2::Consume(&consumable, *escaping_pattern)) { current_ptr_ = consumable.begin(); - tensorflow::StringPiece raw = + absl::string_view raw = StringPieceFromPointers(token_start_ + 1, current_ptr_ - 1); string error; - if (!tensorflow::str_util::CUnescape(raw, &str_val_, &error)) { + // TODO(b/113077997): Change to absl::CUnescape once it works properly with + // copy-on-write std::string implementations. + if (!tensorflow::str_util::CUnescape( // non-absl ok + tensorflow::StringPiece(raw.data(), raw.size()), // non-absl ok + &str_val_, &error)) { LOG(ERROR) << "Failed unescaping string: " << raw << ". error: " << error; return TokKind::kError; } diff --git a/tensorflow/compiler/xla/service/hlo_lexer.h b/tensorflow/compiler/xla/service/hlo_lexer.h index f9ecd9ccb9..3e2f8bcd52 100644 --- a/tensorflow/compiler/xla/service/hlo_lexer.h +++ b/tensorflow/compiler/xla/service/hlo_lexer.h @@ -18,10 +18,10 @@ limitations under the License. #include <string> +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/service/hlo_token.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/regexp.h" #include "tensorflow/core/platform/types.h" @@ -34,7 +34,7 @@ namespace xla { // it directly. 
class HloLexer { public: - explicit HloLexer(tensorflow::StringPiece buf) : buf_(buf) { + explicit HloLexer(absl::string_view buf) : buf_(buf) { current_ptr_ = buf_.begin(); } @@ -77,7 +77,7 @@ class HloLexer { std::pair<unsigned, unsigned> GetLineAndColumn(LocTy location) const; // Returns the whole line given the location. - tensorflow::StringPiece GetLine(LocTy loc) const; + absl::string_view GetLine(LocTy loc) const; private: // Returns the current character. If it's neither the end of input buffer nor @@ -89,8 +89,8 @@ class HloLexer { // Creates StringPiece with the given begin and end. Exits if the begin > end, // or it's out of the range of the current buffer. - tensorflow::StringPiece StringPieceFromPointers(const char* begin, - const char* end) const; + absl::string_view StringPieceFromPointers(const char* begin, + const char* end) const; tensorflow::RegexpStringPiece RegexpStringPieceFromPointers( const char* begin, const char* end) const; @@ -107,7 +107,7 @@ class HloLexer { TokKind LexNumberOrPattern(); TokKind LexString(); - const tensorflow::StringPiece buf_; + const absl::string_view buf_; const char* current_ptr_; // Information about the current token. diff --git a/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc b/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc index 18f17b75ae..3a1dd471c6 100644 --- a/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc @@ -18,6 +18,7 @@ limitations under the License. #include <deque> #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/call_graph.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -29,17 +30,14 @@ limitations under the License. 
#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" namespace xla { +namespace { using Worklist = std::deque<const HloInstruction*>; using Workset = std::unordered_set<const HloInstruction*>; -namespace { - void AddToWorklist(const HloInstruction* instruction, Worklist* worklist, Workset* workset) { if (workset->count(instruction) == 0) { diff --git a/tensorflow/compiler/xla/service/hlo_matchers.cc b/tensorflow/compiler/xla/service/hlo_matchers.cc index 7e4b883435..5269cad94d 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.cc +++ b/tensorflow/compiler/xla/service/hlo_matchers.cc @@ -15,15 +15,13 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/core/lib/strings/str_util.h" namespace xla { namespace testing { -using ::tensorflow::str_util::Join; - bool HloMatcher::MatchAndExplain( const HloInstruction* instruction, ::testing::MatchResultListener* listener) const { @@ -210,8 +208,8 @@ bool HloDotWithContractingDimsMatcher::MatchAndExplain( dim_nums.lhs_contracting_dimensions(0) != lhs_contracting_dim_) { *listener << instruction->ToString() << " has wrong lhs_contracting_dimensions (got {" - << Join(dim_nums.lhs_contracting_dimensions(), ",") << "} want {" - << lhs_contracting_dim_ << "})"; + << absl::StrJoin(dim_nums.lhs_contracting_dimensions(), ",") + << "} want {" << lhs_contracting_dim_ << "})"; return false; } @@ -219,8 +217,8 @@ bool HloDotWithContractingDimsMatcher::MatchAndExplain( dim_nums.rhs_contracting_dimensions(0) != rhs_contracting_dim_) { *listener << instruction->ToString() << " has wrong rhs_contracting_dimensions (got 
{" - << Join(dim_nums.rhs_contracting_dimensions(), ",") << "} want {" - << rhs_contracting_dim_ << "})"; + << absl::StrJoin(dim_nums.rhs_contracting_dimensions(), ",") + << "} want {" << rhs_contracting_dim_ << "})"; return false; } diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h index 0a442e77f0..9ace0d76e0 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.h +++ b/tensorflow/compiler/xla/service/hlo_matchers.h @@ -306,7 +306,7 @@ inline ::testing::Matcher<const ::xla::HloInstruction*> Shape( return ::testing::MakeMatcher(new ::xla::testing::HloShapeMatcher(shape)); } inline ::testing::Matcher<const ::xla::HloInstruction*> Shape( - tensorflow::StringPiece shape) { + absl::string_view shape) { return ::testing::MakeMatcher(new ::xla::testing::HloShapeMatcher( ShapeUtil::ParseShapeString(shape).ValueOrDie())); } @@ -316,7 +316,7 @@ inline ::testing::Matcher<const ::xla::HloInstruction*> ShapeWithLayout( new ::xla::testing::HloShapeAndLayoutMatcher(shape)); } inline ::testing::Matcher<const ::xla::HloInstruction*> ShapeWithLayout( - tensorflow::StringPiece shape) { + absl::string_view shape) { return ::testing::MakeMatcher(new ::xla::testing::HloShapeAndLayoutMatcher( ShapeUtil::ParseShapeString(shape).ValueOrDie())); } @@ -329,7 +329,7 @@ inline ::testing::Matcher<const ::xla::HloInstruction*> Sharding( } // Matcher for Sharding from sharding string inline ::testing::Matcher<const ::xla::HloInstruction*> Sharding( - tensorflow::StringPiece sharding) { + absl::string_view sharding) { return ::testing::MakeMatcher(new ::xla::testing::HloShardingMatcher( ParseSharding(sharding).ValueOrDie())); } diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index d60b76d63f..78167335c8 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -24,11 +24,11 @@ limitations under the License. 
#include "absl/algorithm/container.h" #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/gtl/map_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -410,7 +410,7 @@ HloInstruction* HloModule::OutlineExpressionFromComputation( string error_message = "The subcomputation to outline has multiple outputs:\n"; for (HloInstruction* output : outputs) { - tensorflow::strings::StrAppend(&error_message, output->ToString(), "\n"); + absl::StrAppend(&error_message, output->ToString(), "\n"); } LOG(FATAL) << error_message; } @@ -536,8 +536,7 @@ uint64 HloModule::RandomNew64() const { return rng_(); } -HloComputation* HloModule::GetComputationWithName( - tensorflow::StringPiece name) { +HloComputation* HloModule::GetComputationWithName(absl::string_view name) { auto computations_in_module = computations(); auto it = absl::c_find_if( computations_in_module, diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h index d2e726a0db..cf129b835d 100644 --- a/tensorflow/compiler/xla/service/hlo_module.h +++ b/tensorflow/compiler/xla/service/hlo_module.h @@ -24,6 +24,7 @@ limitations under the License. #include <unordered_map> #include <vector> +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/iterator_util.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_clone_context.h" @@ -32,7 +33,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/iterator_range.h" #include "tensorflow/core/platform/logging.h" @@ -142,7 +142,7 @@ class HloModule { // Returns the computation in this module that has the name `name`. Returns // null if there is no such computation. - HloComputation* GetComputationWithName(tensorflow::StringPiece name); + HloComputation* GetComputationWithName(absl::string_view name); // Gets the number of computations in this module. int64 computation_count() const { return computations_.size(); } diff --git a/tensorflow/compiler/xla/service/hlo_module_config.cc b/tensorflow/compiler/xla/service/hlo_module_config.cc index f9708283eb..9bfa3a5f45 100644 --- a/tensorflow/compiler/xla/service/hlo_module_config.cc +++ b/tensorflow/compiler/xla/service/hlo_module_config.cc @@ -19,14 +19,14 @@ limitations under the License. 
#include <vector> #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/shape_layout.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace xla { -using tensorflow::strings::StrAppend; +using absl::StrAppend; HloModuleConfig::HloModuleConfig(const ProgramShape& program_shape, bool ignore_layouts) @@ -39,15 +39,14 @@ void HloModuleConfig::SetDefaultComputationLayout( } string HloModuleConfig::compilation_cache_key() const { - string key = - tensorflow::strings::StrCat("profiling=", hlo_profiling_enabled()); + string key = absl::StrCat("profiling=", hlo_profiling_enabled()); StrAppend(&key, "::("); std::vector<string> params; for (const ShapeLayout& param_layout : entry_computation_layout_->parameter_layouts()) { params.push_back(param_layout.shape().DebugString()); } - StrAppend(&key, tensorflow::str_util::Join(params, ", "), ") => ", + StrAppend(&key, absl::StrJoin(params, ", "), ") => ", entry_computation_layout_->result_shape().SerializeAsString()); if (seed() != 0) { // TODO(b/32083678): force recompilation to reset global state. diff --git a/tensorflow/compiler/xla/service/hlo_module_dce.h b/tensorflow/compiler/xla/service/hlo_module_dce.h index 29024085c1..12ca2340a6 100644 --- a/tensorflow/compiler/xla/service/hlo_module_dce.h +++ b/tensorflow/compiler/xla/service/hlo_module_dce.h @@ -31,7 +31,7 @@ namespace xla { class HloModuleDCE : public HloPassInterface { public: ~HloModuleDCE() override {} - tensorflow::StringPiece name() const override { return "hlo-module-dce"; } + absl::string_view name() const override { return "hlo-module-dce"; } // Run the pass on the given module. Returns whether the module was changed // (instructions were removed). 
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_util.cc b/tensorflow/compiler/xla/service/hlo_module_group_util.cc index 1a4da388e4..b5c7681edd 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_util.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group_util.cc @@ -23,6 +23,7 @@ limitations under the License. #include <utility> #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -32,7 +33,6 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/flatset.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -270,8 +270,8 @@ Status HloModuleGroupUtil::VisitTopologicalOrder( string cyclic_instructions; for (const auto& state : *visit_state) { if (state.second == VisitState::kVisiting) { - tensorflow::strings::StrAppend(&cyclic_instructions, - state.first->ToString(), "\n"); + absl::StrAppend(&cyclic_instructions, state.first->ToString(), + "\n"); } } // TODO(b/64305524): Improve the error message to print out the diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc index 6c1e015f77..8fe91c7278 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering.cc @@ -18,6 +18,7 @@ limitations under the License. #include <utility> #include <vector> +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -25,7 +26,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" @@ -254,6 +254,10 @@ bool HloOrdering::LiveRangeStrictlyBefore( } // All uses of 'a' must be before 'b' is defined. for (const HloUse& use : a.uses()) { + if (dataflow.DoesNotUseOperandBuffer(a.instruction(), a.index(), + use.instruction)) { + continue; + } if (!UseIsBeforeValueDefinition(use, b, dataflow)) { VLOG(4) << "use of " << a << " (" << use << ") not before " << b << " is defined"; @@ -317,7 +321,7 @@ string PredecessorHloOrdering::ToStringHelper(const string& name) const { } } } - return tensorflow::str_util::Join(pieces, "\n"); + return absl::StrJoin(pieces, "\n"); } DependencyHloOrdering::DependencyHloOrdering(const HloModule* module) @@ -388,7 +392,7 @@ string SequentialHloOrdering::ToString() const { tensorflow::strings::Printf(" %s", instruction->name().c_str())); } } - return tensorflow::str_util::Join(pieces, "\n"); + return absl::StrJoin(pieces, "\n"); } std::ostream& operator<<( diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index 90a493d29f..df789e6222 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -17,6 +17,9 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" @@ -26,22 +29,18 @@ limitations under the License. 
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/gtl/map_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/stringprintf.h" namespace xla { namespace { -using ::absl::nullopt; -using ::absl::optional; -using ::tensorflow::StringPiece; -using ::tensorflow::str_util::Join; -using ::tensorflow::str_util::Split; -using ::tensorflow::str_util::SplitAndParseAsInts; +using absl::nullopt; +using absl::optional; +using absl::StrAppend; +using absl::StrCat; +using absl::StrJoin; using ::tensorflow::strings::Printf; -using ::tensorflow::strings::StrAppend; -using ::tensorflow::strings::StrCat; const double kF16max = 65504; @@ -50,7 +49,7 @@ class HloParser { public: using LocTy = HloLexer::LocTy; - explicit HloParser(StringPiece str, const HloModuleConfig& config) + explicit HloParser(absl::string_view str, const HloModuleConfig& config) : lexer_(str), config_(config) {} // Runs the parser. Returns false if an error occurred. @@ -60,7 +59,7 @@ class HloParser { std::unique_ptr<HloModule> ConsumeHloModule() { return std::move(module_); } // Returns the error information. - string GetError() const { return Join(error_, "\n"); } + string GetError() const { return StrJoin(error_, "\n"); } // Stand alone parsing utils for various aggregate data types. StatusOr<HloSharding> ParseShardingOnly(); @@ -253,8 +252,8 @@ class HloParser { bool CanBeParamListToShape(); // Logs the current parsing line and the given message. Always returns false. - bool TokenError(StringPiece msg); - bool Error(LocTy loc, StringPiece msg); + bool TokenError(absl::string_view msg); + bool Error(LocTy loc, absl::string_view msg); // If the current token is 'kind', eats it (i.e. lexes the next token) and // returns true. 
@@ -293,6 +292,17 @@ class HloParser { missing_instruction_hook_; }; +bool SplitToInt64s(absl::string_view s, char delim, std::vector<int64>* out) { + for (const auto& split : absl::StrSplit(s, delim)) { + int64 val; + if (!absl::SimpleAtoi(split, &val)) { + return false; + } + out->push_back(val); + } + return true; +} + // Creates replica groups from the provided nested array. groups[i] represents // the replica ids for group 'i'. std::vector<ReplicaGroup> CreateReplicaGroups( @@ -307,7 +317,7 @@ std::vector<ReplicaGroup> CreateReplicaGroups( return replica_groups; } -bool HloParser::Error(LocTy loc, StringPiece msg) { +bool HloParser::Error(LocTy loc, absl::string_view msg) { auto line_col = lexer_.GetLineAndColumn(loc); const unsigned line = line_col.first; const unsigned col = line_col.second; @@ -317,12 +327,12 @@ bool HloParser::Error(LocTy loc, StringPiece msg) { error_lines.push_back(std::string(lexer_.GetLine(loc))); error_lines.push_back(col == 0 ? "" : StrCat(string(col - 1, ' '), "^")); - error_.push_back(Join(error_lines, "\n")); + error_.push_back(StrJoin(error_lines, "\n")); VLOG(1) << "Error: " << error_.back(); return false; } -bool HloParser::TokenError(StringPiece msg) { +bool HloParser::TokenError(absl::string_view msg) { return Error(lexer_.GetLoc(), msg); } @@ -1806,10 +1816,10 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal, std::vector<tensorflow::int64> elems_seen_until_dim( elems_seen_per_dim.begin(), elems_seen_per_dim.begin() + dim); return StrCat("[", - Join(elems_seen_until_dim, ",", - [](string* out, const tensorflow::int64& num_elems) { - StrAppend(out, num_elems - 1); - }), + StrJoin(elems_seen_until_dim, ",", + [](string* out, const tensorflow::int64& num_elems) { + StrAppend(out, num_elems - 1); + }), "]"); }; do { @@ -1996,7 +2006,7 @@ bool HloParser::ParseSparseLiteralHelper(std::unique_ptr<Literal>* literal, return Error( index_loc, StrCat("invalid multi-dimension index for shape with rank ", rank, - ": 
[", Join(index, ", "), "]")); + ": [", StrJoin(index, ", "), "]")); } } if (!ParseToken(TokKind::kColon, @@ -2173,10 +2183,10 @@ bool HloParser::ParseAttributeHelper( } else { allowed_attrs = StrCat( "Allowed attributes: ", - Join(attrs, ", ", - [&](string* out, const std::pair<string, AttrConfig>& kv) { - StrAppend(out, kv.first); - })); + StrJoin(attrs, ", ", + [&](string* out, const std::pair<string, AttrConfig>& kv) { + StrAppend(out, kv.first); + })); } return Error(loc, Printf("unexpected attribute \"%s\". %s", name.c_str(), allowed_attrs.c_str())); @@ -2489,20 +2499,24 @@ bool HloParser::ParseConvolutionDimensionNumbers( } string str = lexer_.GetStrVal(); - // The str is expected to have 3 items, lhs, rhs, out, and it must looks like + // The str is expected to have 3 items, lhs, rhs, out, and it must look like // lhs_rhs->out, that is, the first separator is "_" and the second is "->". - // So we replace the "->" with "_" and then split on "_". - str = tensorflow::str_util::StringReplace(str, /*oldsub=*/"->", - /*newsub=*/"_", - /*replace_all=*/false); - std::vector<string> lhs_rhs_out = Split(str, "_"); - if (lhs_rhs_out.size() != 3) { + std::vector<string> split1 = absl::StrSplit(str, "_"); + if (split1.size() != 2) { LOG(FATAL) << "expects 3 items: lhs, rhs, and output dims, but sees " << str; } + std::vector<string> split2 = absl::StrSplit(split1[1], "->"); + if (split2.size() != 2) { + LOG(FATAL) << "expects 3 items: lhs, rhs, and output dims, but sees " + << str; + } + absl::string_view lhs = split1[0]; + absl::string_view rhs = split2[0]; + absl::string_view out = split2[1]; - const tensorflow::int64 rank = lhs_rhs_out[0].length(); - if (rank != lhs_rhs_out[1].length() || rank != lhs_rhs_out[2].length()) { + const tensorflow::int64 rank = lhs.length(); + if (rank != rhs.length() || rank != out.length()) { return TokenError( "convolution lhs, rhs, and output must have the same rank"); } @@ -2517,8 +2531,7 @@ bool 
HloParser::ParseConvolutionDimensionNumbers( // lhs { - const string& lhs = lhs_rhs_out[0]; - if (!is_unique(lhs)) { + if (!is_unique(string(lhs))) { return TokenError( StrCat("expects unique lhs dimension numbers, but sees ", lhs)); } @@ -2541,8 +2554,7 @@ bool HloParser::ParseConvolutionDimensionNumbers( } // rhs { - const string& rhs = lhs_rhs_out[1]; - if (!is_unique(rhs)) { + if (!is_unique(string(rhs))) { return TokenError( StrCat("expects unique rhs dimension numbers, but sees ", rhs)); } @@ -2565,8 +2577,7 @@ bool HloParser::ParseConvolutionDimensionNumbers( } // output { - const string& out = lhs_rhs_out[2]; - if (!is_unique(out)) { + if (!is_unique(string(out))) { return TokenError( StrCat("expects unique output dimension numbers, but sees ", out)); } @@ -2832,7 +2843,7 @@ bool HloParser::ParseDxD(const string& name, // 2D or higher. if (lexer_.GetKind() == TokKind::kDxD) { string str = lexer_.GetStrVal(); - if (!SplitAndParseAsInts(str, 'x', result)) { + if (!SplitToInt64s(str, 'x', result)) { return Error(loc, Printf("expects sub-attribute '%s=ixj...'", name.c_str())); } @@ -2852,10 +2863,9 @@ bool HloParser::ParseWindowPad( return TokenError("expects window pad pattern, e.g., '0_0x3_3'"); } string str = lexer_.GetStrVal(); - std::vector<string> padding_str = Split(str, 'x'); - for (int i = 0; i < padding_str.size(); i++) { + for (const auto& padding_dim_str : absl::StrSplit(str, 'x')) { std::vector<tensorflow::int64> low_high; - if (!SplitAndParseAsInts(padding_str[i], '_', &low_high) || + if (!SplitToInt64s(padding_dim_str, '_', &low_high) || low_high.size() != 2) { return Error(loc, "expects padding_low and padding_high separated by '_'"); @@ -2876,10 +2886,9 @@ bool HloParser::ParsePaddingConfig(PaddingConfig* padding) { } LocTy loc = lexer_.GetLoc(); string str = lexer_.GetStrVal(); - std::vector<string> padding_str = Split(str, 'x'); - for (const auto& padding_dim_str : padding_str) { + for (const auto& padding_dim_str : absl::StrSplit(str, 'x')) 
{ std::vector<tensorflow::int64> padding_dim; - if (!SplitAndParseAsInts(padding_dim_str, '_', &padding_dim) || + if (!SplitToInt64s(padding_dim_str, '_', &padding_dim) || (padding_dim.size() != 2 && padding_dim.size() != 3)) { return Error(loc, "expects padding config pattern like 'low_high_interior' or " @@ -3162,7 +3171,7 @@ Status HloParser::ParseSingleInstruction(HloComputation::Builder* builder, } // namespace StatusOr<std::unique_ptr<HloModule>> ParseHloString( - tensorflow::StringPiece str, const HloModuleConfig& config) { + absl::string_view str, const HloModuleConfig& config) { HloParser parser(str, config); if (!parser.Run()) { return InvalidArgument("Syntax error:\n%s", parser.GetError().c_str()); @@ -3170,39 +3179,38 @@ StatusOr<std::unique_ptr<HloModule>> ParseHloString( return parser.ConsumeHloModule(); } -StatusOr<std::unique_ptr<HloModule>> ParseHloString( - tensorflow::StringPiece str) { +StatusOr<std::unique_ptr<HloModule>> ParseHloString(absl::string_view str) { HloModuleConfig config; return ParseHloString(str, config); } StatusOr<std::unique_ptr<HloModule>> ParseHloOpToModule( - tensorflow::StringPiece str, tensorflow::StringPiece name) { + absl::string_view str, absl::string_view name) { HloModuleConfig config; HloParser parser(str, config); - auto builder = absl::make_unique<HloComputation::Builder>(name.ToString()); + auto builder = absl::make_unique<HloComputation::Builder>(string(name)); string root_name; TF_RETURN_IF_ERROR(parser.ParseSingleInstruction(builder.get(), &root_name)); std::unique_ptr<HloComputation> computation = builder->Build(); - auto module = absl::make_unique<HloModule>(name.ToString(), config); + auto module = absl::make_unique<HloModule>(string(name), config); module->AddEntryComputation(std::move(computation)); return std::move(module); } -StatusOr<HloSharding> ParseSharding(tensorflow::StringPiece str) { +StatusOr<HloSharding> ParseSharding(absl::string_view str) { HloModuleConfig config; HloParser parser(str, 
config); return parser.ParseShardingOnly(); } -StatusOr<Window> ParseWindow(tensorflow::StringPiece str) { +StatusOr<Window> ParseWindow(absl::string_view str) { HloModuleConfig config; HloParser parser(str, config); return parser.ParseWindowOnly(); } StatusOr<ConvolutionDimensionNumbers> ParseConvolutionDimensionNumbers( - tensorflow::StringPiece str) { + absl::string_view str) { HloModuleConfig config; HloParser parser(str, config); return parser.ParseConvolutionDimensionNumbersOnly(); diff --git a/tensorflow/compiler/xla/service/hlo_parser.h b/tensorflow/compiler/xla/service/hlo_parser.h index 6c184bfe9a..0c64b50481 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.h +++ b/tensorflow/compiler/xla/service/hlo_parser.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PARSER_H_ #include "absl/memory/memory.h" +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_lexer.h" @@ -32,32 +33,31 @@ namespace xla { // The api of the hlo parser. Given a string in the HloModule::ToString() // format, parses the string and creates a HloModule with the given config. StatusOr<std::unique_ptr<HloModule>> ParseHloString( - tensorflow::StringPiece str, const HloModuleConfig& config); + absl::string_view str, const HloModuleConfig& config); // Parses the text for a single HLO operation into an HLO module with a function // that runs that operation (with the same parameters) as its entry computation. StatusOr<std::unique_ptr<HloModule>> ParseHloOpToModule( - tensorflow::StringPiece str, tensorflow::StringPiece name = "single_op"); + absl::string_view str, absl::string_view name = "single_op"); // The api of the hlo parser. Given a string in the HloModule::ToString() // format, parses the string and creates a HloModule with default config. 
-StatusOr<std::unique_ptr<HloModule>> ParseHloString( - tensorflow::StringPiece str); +StatusOr<std::unique_ptr<HloModule>> ParseHloString(absl::string_view str); // Parses the result of HloSharding::ToString(), e.g. "{replicated}". -StatusOr<HloSharding> ParseSharding(tensorflow::StringPiece str); +StatusOr<HloSharding> ParseSharding(absl::string_view str); // Parses the result of window_util::ToString(const Window&). -StatusOr<Window> ParseWindow(tensorflow::StringPiece str); +StatusOr<Window> ParseWindow(absl::string_view str); // Parses the result of ConvolutionDimensionNumbersToString(), e.g. // "b0f_0io->b0f". StatusOr<ConvolutionDimensionNumbers> ParseConvolutionDimensionNumbers( - tensorflow::StringPiece str); + absl::string_view str); // ParseHloString sharding from str. str is supposed to contain the body of the // sharding, i.e. just the rhs of the "sharding={...}" attribute string. -StatusOr<HloSharding> ParseSharding(tensorflow::StringPiece str); +StatusOr<HloSharding> ParseSharding(absl::string_view str); } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index f310b36bfb..b3d3ccda74 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -16,20 +16,19 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_parser.h" #include <string> +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" -namespace op = ::xla::testing::opcode_matchers; - namespace xla { - namespace { -using ::tensorflow::StringPiece; +namespace op = ::xla::testing::opcode_matchers; +using absl::string_view; struct TestData { string test_name; @@ -1128,8 +1127,8 @@ ENTRY Computation { class HloParserTest : public ::testing::Test, public ::testing::WithParamInterface<TestData> { protected: - static void ExpectHasSubstr(StringPiece s, StringPiece expected) { - EXPECT_TRUE(tensorflow::str_util::StrContains(s, expected)) + static void ExpectHasSubstr(string_view s, string_view expected) { + EXPECT_TRUE(absl::StrContains(s, expected)) << "'" << s << "' does not contain '" << expected << "'"; } @@ -1393,15 +1392,14 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2 )"; - ExpectHasSubstr(ParseHloString(tensorflow::strings::StrCat( - prefix, ",dim_labels=00_01_10", suffix)) - .status() - .error_message(), - "expects dim labels pattern"); + ExpectHasSubstr( + ParseHloString(absl::StrCat(prefix, ",dim_labels=00_01_10", suffix)) + .status() + .error_message(), + "expects dim labels pattern"); ExpectHasSubstr( - ParseHloString(tensorflow::strings::StrCat( - prefix, ",dim_labels=010_1100->010", suffix)) + ParseHloString(absl::StrCat(prefix, ",dim_labels=010_1100->010", suffix)) .status() .error_message(), "must have the same rank"); diff --git a/tensorflow/compiler/xla/service/hlo_pass_interface.h b/tensorflow/compiler/xla/service/hlo_pass_interface.h index 0cddf8fb8f..f1ad0f9b01 
100644 --- a/tensorflow/compiler/xla/service/hlo_pass_interface.h +++ b/tensorflow/compiler/xla/service/hlo_pass_interface.h @@ -29,7 +29,7 @@ namespace xla { class HloPassInterface { public: virtual ~HloPassInterface() = default; - virtual tensorflow::StringPiece name() const = 0; + virtual absl::string_view name() const = 0; // Run the pass on the given HLO module. Return whether it modified the // module. diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc index d8f1ab916b..df99e131d8 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc @@ -17,22 +17,22 @@ limitations under the License. #include <functional> +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/service/hlo_proto_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/gtl/flatset.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" -using ::tensorflow::strings::StrAppend; -using ::tensorflow::strings::StrCat; - namespace xla { - namespace { + +using absl::StrAppend; +using absl::StrCat; + void DumpModuleGraph(const HloModule& module, const string& message) { hlo_graph_dumper::MaybeDumpHloModule(module, message); VLOG(3) << "HLO " << message << ":"; @@ -68,7 +68,7 @@ StatusOr<bool> HloPassPipeline::Run(HloModule* module) { repeated_field.end()); if (!disabled_passes.empty()) { VLOG(1) << "Passes disabled by --xla_disable_hlo_passes: " - << tensorflow::str_util::Join(disabled_passes, ", "); + << absl::StrJoin(disabled_passes, ", "); } auto run_invariant_checkers = [this, diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.h 
b/tensorflow/compiler/xla/service/hlo_pass_pipeline.h index 3bb1342aa3..1d41a4dac1 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.h +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.h @@ -34,7 +34,7 @@ namespace xla { class HloPassPipeline : public HloPassInterface { public: explicit HloPassPipeline(const string& name) : name_(name) {} - tensorflow::StringPiece name() const override { return name_; } + absl::string_view name() const override { return name_; } // Add a pass to the pipeline. It should be called with the arguments for the // pass constructor: diff --git a/tensorflow/compiler/xla/service/hlo_proto_util_test.cc b/tensorflow/compiler/xla/service/hlo_proto_util_test.cc index b9cca13870..c3cacd7ce6 100644 --- a/tensorflow/compiler/xla/service/hlo_proto_util_test.cc +++ b/tensorflow/compiler/xla/service/hlo_proto_util_test.cc @@ -22,7 +22,6 @@ limitations under the License. #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/strings/str_util.h" namespace xla { namespace { diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index 04e4a29359..9cc1f5a10e 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -21,6 +21,8 @@ limitations under the License. #include <string> #include "absl/container/inlined_vector.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/service/buffer_value.h" @@ -38,17 +40,14 @@ limitations under the License. 
#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" -using ::tensorflow::strings::HumanReadableNumBytes; - namespace xla { - namespace { +using ::tensorflow::strings::HumanReadableNumBytes; + // Potential optimizations: // . TODO(b/35244891): Avoid N^2 behavior by keeping a priority queue // of candidates. @@ -207,11 +206,10 @@ class InstructionList { Item* to_insert, tensorflow::gtl::ArraySlice<Item*> before_instructions) { VLOG(3) << "InsertBeforeInstructions: " << to_insert->instruction->name() << " before {" - << tensorflow::str_util::Join(before_instructions, ", ", - [](string* out, Item* item) { - tensorflow::strings::StrAppend( - out, item->instruction->name()); - }) + << absl::StrJoin(before_instructions, ", ", + [](string* out, Item* item) { + absl::StrAppend(out, item->instruction->name()); + }) << "}"; // Find the minimal position number of any instruction in @@ -394,10 +392,9 @@ class MemoryUsageTracker { int64 unfinished_user_count; string ToString() const { - return tensorflow::strings::StrCat( - "Buffer ", id, " (defined by ", - defining_instruction->instruction->name(), ", size ", size, - " bytes)"); + return absl::StrCat("Buffer ", id, " (defined by ", + defining_instruction->instruction->name(), ", size ", + size, " bytes)"); } }; @@ -741,29 +738,27 @@ Status MemoryUsageTracker::AddRematerializedInstruction(Item* original_item, } string MemoryUsageTracker::ToString() const { - string output = tensorflow::strings::StrCat("MemoryUsageTracker for ", - computation_->name(), "\n"); - tensorflow::strings::StrAppend( - &output, "Memory usage: ", HumanReadableNumBytes(memory_usage()), " (", - memory_usage(), " bytes)"); + string output = + absl::StrCat("MemoryUsageTracker for ", 
computation_->name(), "\n"); + absl::StrAppend(&output, + "Memory usage: ", HumanReadableNumBytes(memory_usage()), " (", + memory_usage(), " bytes)"); for (auto* item = instruction_list_.first(); item != nullptr; item = instruction_list_.next(item)) { const HloInstruction* instruction = item->instruction; string inprogress = item == in_progress_item_ ? " in-progress" : ""; string placed = item->placed ? " placed" : ""; - tensorflow::strings::StrAppend(&output, " ", instruction->name(), - inprogress, placed, "\n Defines:\n"); + absl::StrAppend(&output, " ", instruction->name(), inprogress, placed, + "\n Defines:\n"); for (BufferId buffer_id : item->buffers_defined) { const Buffer& buffer = buffers_[buffer_id]; string live = IsCurrentlyLive(buffer_id) ? " live" : ""; - tensorflow::strings::StrAppend(&output, " ", buffer.ToString(), live, - ", ", buffer.unfinished_user_count, - " unfinished uses\n"); + absl::StrAppend(&output, " ", buffer.ToString(), live, ", ", + buffer.unfinished_user_count, " unfinished uses\n"); } - tensorflow::strings::StrAppend(&output, " Uses:\n"); + absl::StrAppend(&output, " Uses:\n"); for (BufferId buffer_id : item->buffers_used) { - tensorflow::strings::StrAppend(&output, " ", - buffers_[buffer_id].ToString(), "\n"); + absl::StrAppend(&output, " ", buffers_[buffer_id].ToString(), "\n"); } } return output; @@ -781,10 +776,9 @@ bool MemoryUsageTracker::Check() const { CHECK(elements_are_unique(defined_buffers)) << "Instruction " << instruction->name() << " does not have unique defined buffers: " - << tensorflow::str_util::Join( + << absl::StrJoin( defined_buffers, ", ", [this](string* out, BufferId buffer_id) { - tensorflow::strings::StrAppend( - out, buffers_.at(buffer_id).ToString()); + absl::StrAppend(out, buffers_.at(buffer_id).ToString()); }); for (const Buffer& buffer : buffers_) { @@ -804,10 +798,9 @@ bool MemoryUsageTracker::Check() const { CHECK(elements_are_unique(used_buffers)) << "Instruction " << instruction->name() << " does not 
have unique used buffers: " - << tensorflow::str_util::Join( + << absl::StrJoin( used_buffers, ", ", [this](string* out, BufferId buffer_id) { - tensorflow::strings::StrAppend( - out, buffers_.at(buffer_id).ToString()); + absl::StrAppend(out, buffers_.at(buffer_id).ToString()); }); } for (const Buffer& buffer : buffers_) { diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc index 8f3ae9c621..7bd8a4a544 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.cc +++ b/tensorflow/compiler/xla/service/hlo_runner.cc @@ -32,7 +32,7 @@ limitations under the License. namespace xla { /*static*/ StatusOr<std::unique_ptr<HloModule>> -HloRunner::CreateModuleFromString(const tensorflow::StringPiece hlo_string, +HloRunner::CreateModuleFromString(const absl::string_view hlo_string, const DebugOptions& debug_options) { HloModuleConfig config; config.set_debug_options(debug_options); diff --git a/tensorflow/compiler/xla/service/hlo_runner.h b/tensorflow/compiler/xla/service/hlo_runner.h index 65537f07f5..cfc519063e 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.h +++ b/tensorflow/compiler/xla/service/hlo_runner.h @@ -87,8 +87,7 @@ class HloRunner { // Converts an HloModule from the given hlo textual IR string (in // HloModule::ToString format). static StatusOr<std::unique_ptr<HloModule>> CreateModuleFromString( - const tensorflow::StringPiece hlo_string, - const DebugOptions& debug_options); + const absl::string_view hlo_string, const DebugOptions& debug_options); // Reads the proto file in xla.HloProto format, creates and returns the // HloModule. diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.cc b/tensorflow/compiler/xla/service/hlo_scheduling.cc index 27cc5361cd..393824d920 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling.cc +++ b/tensorflow/compiler/xla/service/hlo_scheduling.cc @@ -28,16 +28,14 @@ limitations under the License. 
#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" -using ::tensorflow::strings::HumanReadableNumBytes; - namespace xla { - namespace { +using ::tensorflow::strings::HumanReadableNumBytes; + // Class implementing a list scheduler of HLO instructions which produces a // sequence which minimizes memory usage by preferring to schedule the node that // frees bigger buffer and defines smaller outputs. diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc index 903fbbec1a..980dae07ce 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding.cc @@ -15,13 +15,14 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_sharding.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/str_util.h" namespace xla { -using ::tensorflow::str_util::Join; -using ::tensorflow::strings::StrCat; +using absl::StrCat; +using absl::StrJoin; HloSharding HloSharding::AssignDevice(int64 device_id) { return HloSharding(device_id); @@ -71,12 +72,9 @@ HloSharding HloSharding::SingleTuple(const Shape& tuple_shape, const HloSharding& sharding) { CHECK(ShapeUtil::IsTuple(tuple_shape)) << ShapeUtil::HumanString(tuple_shape); CHECK(!sharding.IsTuple()) << sharding.ToString(); - int64 leaf_count = ShapeUtil::GetLeafCount(tuple_shape); + int64 leaf_count = RequiredLeaves(tuple_shape); std::vector<HloSharding> flattened_list; - flattened_list.reserve(leaf_count); - for (int64 i = 0; i < leaf_count; ++i) { - flattened_list.push_back(sharding); - } + flattened_list.resize(leaf_count, sharding); return HloSharding(flattened_list); } @@ -92,7 +90,7 
@@ string HloSharding::ToString() const { for (const HloSharding& element : tuple_elements_) { parts.push_back(element.ToString()); } - return StrCat("{", tensorflow::str_util::Join(parts, ", "), "}"); + return StrCat("{", absl::StrJoin(parts, ", "), "}"); } if (replicated_) { @@ -101,8 +99,8 @@ string HloSharding::ToString() const { return StrCat( "{maximal device=", static_cast<int64>(*tile_assignment_.begin()), "}"); } else { - return StrCat("{devices=[", Join(tile_assignment_.dimensions(), ","), "]", - Join(tile_assignment_, ","), "}"); + return StrCat("{devices=[", StrJoin(tile_assignment_.dimensions(), ","), + "]", StrJoin(tile_assignment_, ","), "}"); } } @@ -445,7 +443,7 @@ absl::optional<HloSharding> HloSharding::ExtractSingleSharding() const { } for (int64 i = 1; i < tuple_elements_.size(); ++i) { if (tuple_elements_[0] != tuple_elements_[i]) { - return absl::optional<HloSharding>(); + return absl::nullopt; } } return tuple_elements_.front(); diff --git a/tensorflow/compiler/xla/service/hlo_sharding.h b/tensorflow/compiler/xla/service/hlo_sharding.h index 4c64ac60c5..be51c3f55b 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.h +++ b/tensorflow/compiler/xla/service/hlo_sharding.h @@ -260,9 +260,9 @@ class HloSharding { bool maximal_; bool tuple_; Array<int64> tile_assignment_; - // Only non-empty when tuple_ is true, but because empty tuples are allowed - // may also be empty even then. This is a flattened list of all the leaf - // shardings in a tuple shape, by pre-order walk (ShapeTree iterator order). + // Only non-empty when tuple_ is true. If a tuple is empty then one entry is + // present for the root. This is a flattened list of all the leaf shardings in + // a tuple shape, by pre-order walk (ShapeTree iterator order). 
std::vector<HloSharding> tuple_elements_; }; diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc index 6f0353ee5f..a9b3b66934 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc @@ -118,13 +118,17 @@ Status FixupPassThroughDomainLinks(const DomainMetadata::Domain& domain, return Status::OK(); } -std::unique_ptr<HloSharding> CloneShardingForDomain( - const HloSharding& sharding) { - auto single_sharding = sharding.ExtractSingleSharding(); +// For tuple shardings if every element has the same sharding then we want to +// treat them as single element shardings to insert less domain separation as a +// domain can prevent some optimizations and we want to minimize that from +// happening. +std::shared_ptr<const HloSharding> CloneShardingForDomain( + std::shared_ptr<const HloSharding> sharding) { + auto single_sharding = sharding->ExtractSingleSharding(); if (!single_sharding) { - return absl::make_unique<HloSharding>(sharding); + return sharding; } - return absl::make_unique<HloSharding>(*single_sharding); + return std::make_shared<const HloSharding>(*single_sharding); } Status ApplyDomainSingleSharding(const DomainMetadata::Domain& domain, @@ -280,66 +284,18 @@ Status ApplyDomainSharding(const DomainMetadata::Domain& domain, return Status::OK(); } -// Creates a kDomain instruction to be placed between instruction and operand. -// The kDomain instruction will be created only if the sharding differ between -// the instruction and the operand. -std::unique_ptr<HloInstruction> CreateDomain(HloInstruction* instruction, - HloInstruction* root, - HloInstruction* operand) { - const HloSharding* instruction_sharding = - instruction->has_sharding() ? &instruction->sharding() : nullptr; - const HloSharding* root_sharding = - root->has_sharding() ? 
&root->sharding() : nullptr; - // No need for domain if they both have no sharding. - if (instruction_sharding == nullptr && root_sharding == nullptr) { - return nullptr; - } - // No need for domain if they match. - if (instruction_sharding != nullptr && root_sharding != nullptr && - ShardingMatches(*instruction_sharding, *root_sharding)) { - return nullptr; - } - std::unique_ptr<HloSharding> real_instruction_sharding; - std::unique_ptr<HloSharding> real_operand_sharding; - if (instruction_sharding != nullptr) { - real_instruction_sharding = CloneShardingForDomain(*instruction_sharding); - } - if (root_sharding != nullptr) { - real_operand_sharding = CloneShardingForDomain(*root_sharding); - } - VLOG(3) << "Creating domain:"; - VLOG(3) << " Instruction: " << instruction->name(); - VLOG(3) << " Operand: " << operand->name(); - VLOG(3) << " User side sharding: " - << (real_instruction_sharding != nullptr - ? real_instruction_sharding->ToString() - : "None"); - VLOG(3) << " Operand side sharding: " - << (real_operand_sharding != nullptr - ? real_operand_sharding->ToString() - : "None"); - - std::unique_ptr<DomainMetadata> operand_side_metadata = - absl::make_unique<ShardingMetadata>(std::move(real_operand_sharding)); - std::unique_ptr<DomainMetadata> user_side_metadata = - absl::make_unique<ShardingMetadata>(std::move(real_instruction_sharding)); - return HloInstruction::CreateDomain(operand->shape(), operand, - std::move(operand_side_metadata), - std::move(user_side_metadata)); -} - -StatusOr<std::unique_ptr<HloSharding>> ExtractOriginalCommonSharding( +StatusOr<std::shared_ptr<const HloSharding>> ExtractOriginalCommonSharding( tensorflow::gtl::ArraySlice<HloInstruction*> instructions) { // If we are here, all the instructions being passed had the same sharding // (or no sharding), by the means of the ShardingMatches() API. // As such, no kDomain was inserted, and here we are asked to extract the // original common sharding. 
// All the instructions passed to this API are part of the same computation. - const HloSharding* sharding = nullptr; + std::shared_ptr<const HloSharding> sharding; for (HloInstruction* instruction : instructions) { if (instruction->has_sharding()) { if (sharding == nullptr) { - sharding = &instruction->sharding(); + sharding = instruction->sharding_ptr(); } else { TF_RET_CHECK(ShardingMatches(*sharding, instruction->sharding())) << "Sharding " << *sharding << " does not match the one in " @@ -348,10 +304,10 @@ StatusOr<std::unique_ptr<HloSharding>> ExtractOriginalCommonSharding( } } } if (sharding == nullptr) { - return std::unique_ptr<HloSharding>(); + return std::shared_ptr<const HloSharding>(); } VLOG(4) << "Extracted sharding is " << *sharding; - return CloneShardingForDomain(*sharding); + return CloneShardingForDomain(sharding); } } // namespace @@ -405,7 +361,7 @@ Status ShardingMetadata::NormalizeShardingDomain( TF_RETURN_IF_ERROR(FixupPassThroughDomainLinks(domain, *sharding)); } } else { - TF_ASSIGN_OR_RETURN(std::unique_ptr<HloSharding> sharding, + TF_ASSIGN_OR_RETURN(std::shared_ptr<const HloSharding> sharding, ExtractOriginalCommonSharding(domain.instructions)); if (sharding != nullptr) { VLOG(4) << "Normalizing sharding-less domain to " << sharding->ToString(); @@ -417,10 +373,75 @@ Status ShardingMetadata::NormalizeShardingDomain( return Status::OK(); } -std::unique_ptr<HloInstruction> CreateShardingDomain( - HloInstruction* instruction, HloInstruction* root, - HloInstruction* operand) { - return CreateDomain(instruction, root, operand); +// Creates a kDomain instruction to be placed between instruction and operand. +// The kDomain instruction will be created only if the shardings differ between +// the instruction and the operand. 
+HloInstruction* ShardingDomainCreator::operator()(HloInstruction* instruction, + HloInstruction* root, + HloInstruction* operand) { + auto instruction_sharding = instruction->sharding_ptr(); + auto root_sharding = root->sharding_ptr(); + // No need for domain if they both have no sharding. + if (instruction_sharding == nullptr && root_sharding == nullptr) { + return nullptr; + } + // No need for domain if they match. + if (instruction_sharding != nullptr && root_sharding != nullptr && + ShardingMatches(*instruction_sharding, *root_sharding)) { + return nullptr; + } + + if (instruction_sharding != nullptr) { + instruction_sharding = CloneShardingForDomain(instruction_sharding); + } + if (root_sharding != nullptr) { + root_sharding = CloneShardingForDomain(root_sharding); + } + + auto it = domain_cse_map_.find({operand, instruction_sharding}); + if (it != domain_cse_map_.end()) { + return it->second; + } + + VLOG(3) << "Creating domain:"; + VLOG(3) << " Instruction: " << instruction->name(); + VLOG(3) << " Operand: " << operand->name(); + VLOG(3) << " User side sharding: " + << (instruction_sharding != nullptr ? instruction_sharding->ToString() + : "None"); + VLOG(3) << " Operand side sharding: " + << (root_sharding != nullptr ? 
root_sharding->ToString() : "None"); + + HloInstruction* domain = + operand->parent()->AddInstruction(HloInstruction::CreateDomain( + operand->shape(), operand, + absl::make_unique<ShardingMetadata>(root_sharding), + absl::make_unique<ShardingMetadata>(instruction_sharding))); + domain_cse_map_.emplace(DomainCseMapKey{operand, instruction_sharding}, + domain); + return domain; +} + +bool ShardingDomainCreator::DomainCseMapKey::operator==( + const ShardingDomainCreator::DomainCseMapKey& other) const { + if (instruction != other.instruction) { + return false; + } + if (sharding == nullptr && other.sharding == nullptr) { + return true; + } + if (sharding == nullptr || other.sharding == nullptr) { + return false; + } + return *sharding == *other.sharding; +} + +size_t ShardingDomainCreator::DomainCseMapHasher::operator()( + const ShardingDomainCreator::DomainCseMapKey& key) const { + return tensorflow::Hash64Combine( + std::hash<const HloInstruction*>{}(key.instruction), + key.sharding ? key.sharding->Hash() + : static_cast<size_t>(0x297814aaad196e6dULL)); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.h b/tensorflow/compiler/xla/service/hlo_sharding_metadata.h index dc258e4094..7a6b0d9abc 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_metadata.h +++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.h @@ -27,12 +27,12 @@ namespace xla { // A DomainMetadata implementation that internally wraps a sharding attribute. 
class ShardingMetadata : public DomainMetadata { public: - explicit ShardingMetadata(std::unique_ptr<HloSharding> sharding) + explicit ShardingMetadata(std::shared_ptr<const HloSharding> sharding) : sharding_(std::move(sharding)) {} std::unique_ptr<DomainMetadata> Clone() const override; - tensorflow::StringPiece Kind() const override { return KindName(); } + absl::string_view Kind() const override { return KindName(); } bool Matches(const DomainMetadata& other) const override; @@ -40,7 +40,7 @@ class ShardingMetadata : public DomainMetadata { const HloSharding* sharding() const { return sharding_.get(); } - static tensorflow::StringPiece KindName() { return "sharding"; } + static absl::string_view KindName() { return "sharding"; } static StatusOr<const ShardingMetadata*> ToShardingMetadata( const DomainMetadata* metadata); @@ -55,15 +55,33 @@ class ShardingMetadata : public DomainMetadata { const DomainMetadata* metadata); private: - std::unique_ptr<HloSharding> sharding_; + std::shared_ptr<const HloSharding> sharding_; }; -// Given an HLO graph edge between instruction and one of its operands, creates -// a ShardingMetadata based kDomain instruction if the sharding between -// instruction and parent changes. Returns nullptr if there is no need for a -// domain separation. -std::unique_ptr<HloInstruction> CreateShardingDomain( - HloInstruction* instruction, HloInstruction* root, HloInstruction* operand); +// If the sharding between root and instruction changes then returns a +// ShardingMetadata based kDomain instruction that can be used to separate +// operand and instruction. +// Returns nullptr if there is no need for a domain separation. +class ShardingDomainCreator { + public: + HloInstruction* operator()(HloInstruction* instruction, HloInstruction* root, + HloInstruction* operand); + + private: + // Map from instruction and user sharding to domain users to CSE identical + // domains. 
+ struct DomainCseMapKey { + const HloInstruction* instruction; + std::shared_ptr<const HloSharding> sharding; + + bool operator==(const DomainCseMapKey& other) const; + }; + struct DomainCseMapHasher { + size_t operator()(const DomainCseMapKey& key) const; + }; + std::unordered_map<DomainCseMapKey, HloInstruction*, DomainCseMapHasher> + domain_cse_map_; +}; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_sharding_test.cc b/tensorflow/compiler/xla/service/hlo_sharding_test.cc index 45fc300fca..2341f8ada0 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding_test.cc @@ -115,6 +115,13 @@ TEST_F(HloShardingTest, Tile) { } } +// Tests that empty tuple is supported. +TEST_F(HloShardingTest, EmptySingleTuple) { + HloSharding sharding = HloSharding::SingleTuple(ShapeUtil::MakeTupleShape({}), + HloSharding::AssignDevice(0)); + EXPECT_TRUE(sharding.ExtractSingleSharding()); +} + TEST_F(HloShardingTest, NestedTuple) { // nested_tuple_shape = (f32[], (f32[3]), f32[4, 6]) Shape nested_tuple_shape = ShapeUtil::MakeTupleShape({ diff --git a/tensorflow/compiler/xla/service/hlo_subcomputation_unification.h b/tensorflow/compiler/xla/service/hlo_subcomputation_unification.h index 2ef38821af..d1cf644f82 100644 --- a/tensorflow/compiler/xla/service/hlo_subcomputation_unification.h +++ b/tensorflow/compiler/xla/service/hlo_subcomputation_unification.h @@ -24,7 +24,7 @@ namespace xla { // one arbitrarily to use and delete the others. 
class HloSubcomputationUnification : public HloPassInterface { public: - tensorflow::StringPiece name() const override { + absl::string_view name() const override { return "subcomputation-unification"; } diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc index b78bfa0cdf..4876533449 100644 --- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc +++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -21,28 +23,25 @@ limitations under the License. #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/tensor_shape.pb.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" - -using ::tensorflow::GraphDef; -using ::tensorflow::NodeDef; -using ::tensorflow::TensorShapeProto; -using ::tensorflow::strings::StrAppend; -using ::tensorflow::strings::StrCat; -using ::tensorflow::str_util::Join; namespace xla { namespace hlo_graph_dumper { namespace { +using absl::StrAppend; +using absl::StrCat; +using tensorflow::GraphDef; +using tensorflow::NodeDef; +using tensorflow::TensorShapeProto; + string GetOpDefName(const HloInstruction* instruction) { string name = StrCat("hlo-", HloOpcodeString(instruction->opcode())); - tensorflow::str_util::TitlecaseString(&name, "-"); + tensorflow::str_util::TitlecaseString(&name, "-"); // non-absl ok name.erase(std::remove(name.begin(), name.end(), '-'), name.end()); if (instruction->opcode() == HloOpcode::kFusion) { string 
fusion_name = ToString(instruction->fusion_kind()); - StrAppend(&name, tensorflow::StringPiece(fusion_name).substr(1)); + StrAppend(&name, absl::string_view(fusion_name).substr(1)); } return name; } @@ -166,7 +165,9 @@ void HloTfGraphBuilder::SetNodeAttrs(const HloInstruction* instruction, layout_string = ShapeUtil::HumanStringWithLayout(instruction->shape()); } else { layout_string = StrCat( - "{", Join(LayoutUtil::MinorToMajor(instruction->shape()), ","), "}"); + "{", + absl::StrJoin(LayoutUtil::MinorToMajor(instruction->shape()), ","), + "}"); } attrs["layout"].set_s(layout_string); } diff --git a/tensorflow/compiler/xla/service/hlo_value.cc b/tensorflow/compiler/xla/service/hlo_value.cc index 14703aaf64..e0c1326177 100644 --- a/tensorflow/compiler/xla/service/hlo_value.cc +++ b/tensorflow/compiler/xla/service/hlo_value.cc @@ -19,6 +19,8 @@ limitations under the License. #include <utility> #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -30,16 +32,13 @@ limitations under the License. 
#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/flatset.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" namespace xla { -using ::tensorflow::str_util::Join; -using ::tensorflow::strings::StrAppend; -using ::tensorflow::strings::StrCat; +using absl::StrAppend; +using absl::StrCat; const Shape& HloPosition::shape() const { return ShapeUtil::GetSubshape(instruction->shape(), index); @@ -216,10 +215,11 @@ void HloValueSet::SortAndUniquifyValues() { } string HloValueSet::ToString() const { - return StrCat("HloValueSet: ", - Join(values_, ", ", [](string* result, const HloValue* value) { - result->append(value->ToShortString()); - })); + return StrCat( + "HloValueSet: ", + absl::StrJoin(values_, ", ", [](string* result, const HloValue* value) { + result->append(value->ToShortString()); + })); } bool HloValueSet::AssignUnionOf( diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 7acf58e252..f60c4eab42 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include <set> +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -122,29 +123,26 @@ Status ShapeVerifier::HandleReducePrecision(HloInstruction* reduce_precision) { reduce_precision->mantissa_bits())); } -namespace { - -Status CheckIsTokenOperand(const HloInstruction* instruction, - int64 operand_no) { +Status ShapeVerifier::CheckIsTokenOperand(const HloInstruction* instruction, + int64 operand_no) { const HloInstruction* token = instruction->operand(operand_no); if (!ShapeUtil::Equal(token->shape(), ShapeUtil::MakeTokenShape())) { return InternalError( "Expected operand %lld to be token-shaped, actual shape is " "%s:\n%s", - operand_no, ShapeUtil::HumanString(token->shape()).c_str(), + operand_no, StringifyShape(token->shape()).c_str(), instruction->ToString().c_str()); } return Status::OK(); } -Status CheckOperandAndParameter(const HloInstruction* instruction, - int64 operand_number, - const HloComputation* computation, - int64 parameter_number) { +Status ShapeVerifier::CheckOperandAndParameter( + const HloInstruction* instruction, int64 operand_number, + const HloComputation* computation, int64 parameter_number) { const HloInstruction* operand = instruction->operand(operand_number); const HloInstruction* parameter = computation->parameter_instruction(parameter_number); - if (!ShapeUtil::Compatible(operand->shape(), parameter->shape())) { + if (!ShapesSame(operand->shape(), parameter->shape())) { return InternalError("Operand %s shape does not match parameter's %s in %s", operand->ToString().c_str(), parameter->ToString().c_str(), @@ -153,8 +151,6 @@ Status CheckOperandAndParameter(const HloInstruction* instruction, return Status::OK(); } -} // namespace - Status ShapeVerifier::HandleInfeed(HloInstruction* instruction) { HloInfeedInstruction* infeed = Cast<HloInfeedInstruction>(instruction); 
TF_RETURN_IF_ERROR(CheckIsTokenOperand(instruction, 0)); @@ -171,13 +167,12 @@ Status ShapeVerifier::HandleOutfeed(HloInstruction* instruction) { // Outfeed has a separate shape field for the value which is outfed to the // host. The shape of the instruction itself is always a token. - if (!ShapeUtil::Compatible(outfeed->outfeed_shape(), - outfeed->operand(0)->shape())) { + if (!ShapesSame(outfeed->outfeed_shape(), outfeed->operand(0)->shape())) { return InternalError( - "Expected outfeed shape to be compatible with operand's shape %s, " + "Expected outfeed shape to be equal to operand's shape %s, " "actual shape is %s:\n%s", - ShapeUtil::HumanString(outfeed->operand(0)->shape()).c_str(), - ShapeUtil::HumanString(outfeed->outfeed_shape()).c_str(), + StringifyShape(outfeed->operand(0)->shape()).c_str(), + StringifyShape(outfeed->outfeed_shape()).c_str(), outfeed->ToString().c_str()); } return CheckShape(outfeed, ShapeUtil::MakeTokenShape()); @@ -258,8 +253,8 @@ Status ShapeVerifier::HandleSort(HloInstruction* sort) { return InternalError( "Expected sort to have to have the same dimensions for the keys and " "the values. 
Keys shape is: %s\n, Values shape is: %s", - ShapeUtil::HumanString(sort->operand(0)->shape()).c_str(), - ShapeUtil::HumanString(sort->operand(1)->shape()).c_str()); + StringifyShape(sort->operand(0)->shape()).c_str(), + StringifyShape(sort->operand(1)->shape()).c_str()); } return CheckVariadicShape(sort); } @@ -333,7 +328,18 @@ Status ShapeVerifier::HandleParameter(HloInstruction* hlo) { return Status::OK(); } -Status ShapeVerifier::HandleFusion(HloInstruction*) { return Status::OK(); } +Status ShapeVerifier::HandleFusion(HloInstruction* fusion) { + for (HloInstruction* fused_param : fusion->fused_parameters()) { + int64 param_no = fused_param->parameter_number(); + if (!ShapesSame(fused_param->shape(), fusion->operand(param_no)->shape())) { + return InternalError( + "Shape mismatch between parameter number %lld and its operand in " + "%s.", + param_no, fusion->ToString().c_str()); + } + } + return Status::OK(); +} Status ShapeVerifier::HandleCall(HloInstruction* call) { for (int64 i = 0; i < call->to_apply()->num_parameters(); ++i) { @@ -415,12 +421,11 @@ Status ShapeVerifier::HandleWhile(HloInstruction* xla_while) { CheckOperandAndParameter(xla_while, 0, xla_while->while_condition(), 0)); const Shape& conditional_shape = xla_while->while_condition()->root_instruction()->shape(); - if (!ShapeUtil::Compatible(conditional_shape, - ShapeUtil::MakeShape(PRED, {}))) { + if (!ShapesSame(conditional_shape, ShapeUtil::MakeShape(PRED, {}))) { return InternalError( "Conditional computation shape does not lead to a scalar predicate " "shape: %s", - ShapeUtil::HumanString(conditional_shape).c_str()); + StringifyShape(conditional_shape).c_str()); } // The shape of kWhile should match the shape of the body computation it // calls. @@ -598,52 +603,51 @@ Status ShapeVerifier::CheckShape(const HloInstruction* instruction, } // Check if the output shape matches the expected shape. 
- bool compatible; + // // We treat BF16 and F32 as compatible types if mixed precision is allowed, // but only when the instruction defines the BF16/F32 buffer. - switch (instruction->opcode()) { - case HloOpcode::kTupleSelect: - // TupleSelect only defines the top-level buffer, which in this case is - // the tuple, so we cannot allow mixed precision. - compatible = ShapeUtil::Compatible(instruction->shape(), inferred_shape); - break; - case HloOpcode::kGetTupleElement: - case HloOpcode::kTuple: - // Tuple and GetTupleElement do not define BF16/F32 buffers, so mixed - // precision is disallowed. - case HloOpcode::kConstant: - case HloOpcode::kBitcast: - case HloOpcode::kBitcastConvert: - case HloOpcode::kCall: - case HloOpcode::kConditional: - case HloOpcode::kConvert: - case HloOpcode::kCustomCall: - case HloOpcode::kInfeed: - case HloOpcode::kOutfeed: - case HloOpcode::kParameter: - case HloOpcode::kRecv: - case HloOpcode::kRecvDone: - case HloOpcode::kSend: - case HloOpcode::kSendDone: - case HloOpcode::kWhile: - // The above opcodes should match the expected shapes exactly. - compatible = ShapeUtil::Compatible(instruction->shape(), inferred_shape); - break; - default: - if (allow_mixed_precision_) { - compatible = ShapeUtil::CompatibleIgnoringFpPrecision( - instruction->shape(), inferred_shape); - } else { - compatible = - ShapeUtil::Compatible(instruction->shape(), inferred_shape); - } - } - if (!compatible) { + bool equal = [&] { + switch (instruction->opcode()) { + // The opcodes below can't have implicit layout conversions, nor can they + // implicitly transform f32 -> bf16. Fundamentally these are either + // reinterpreting existing data (e.g. kBitcast) or shuffling data around + // without modifying it (e.g. kGetTupleElement, kTupleSelect). 
+ case HloOpcode::kBitcast: + case HloOpcode::kCall: + case HloOpcode::kConditional: + case HloOpcode::kConstant: + case HloOpcode::kCustomCall: + case HloOpcode::kGetTupleElement: + case HloOpcode::kInfeed: + case HloOpcode::kOutfeed: + case HloOpcode::kParameter: + case HloOpcode::kRecv: + case HloOpcode::kRecvDone: + case HloOpcode::kSend: + case HloOpcode::kSendDone: + case HloOpcode::kTuple: + case HloOpcode::kTupleSelect: + case HloOpcode::kWhile: + return ShapesSame(instruction->shape(), inferred_shape); + + // We allow arbitrary layout and f32->bf16 transformations on all other + // instructions, although this may be made more strict pending discussion + // in b/112709536. + default: + if (allow_mixed_precision_) { + return ShapeUtil::CompatibleIgnoringFpPrecision(instruction->shape(), + inferred_shape); + } else { + return ShapeUtil::Compatible(instruction->shape(), inferred_shape); + } + } + }(); + if (!equal) { return InternalError( - "Expected instruction to have shape compatible with %s, actual " + "Expected instruction to have shape equal to %s, actual " "shape is %s:\n%s", - ShapeUtil::HumanString(inferred_shape).c_str(), - ShapeUtil::HumanString(instruction->shape()).c_str(), + StringifyShape(inferred_shape).c_str(), + StringifyShape(instruction->shape()).c_str(), instruction->ToString().c_str()); } return Status::OK(); @@ -688,10 +692,10 @@ Status ShapeVerifier::CheckVariadicShape(const HloInstruction* instruction) { string ComputationsToString( tensorflow::gtl::ArraySlice<HloComputation*> computations) { - return tensorflow::str_util::Join( - computations, ",", [](string* s, const HloComputation* computation) { - s->append(computation->name()); - }); + return absl::StrJoin(computations, ",", + [](string* s, const HloComputation* computation) { + s->append(computation->name()); + }); } // Verifies various invariants about the structure of the HLO: @@ -827,7 +831,7 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { } // 
Fused parameter instructions must be numbered contiguously and match up - // (shapes compatible) with their respective operand. + // (shapes equal) with their respective operand. CHECK_EQ(fusion->operands().size(), fused_parameters.size()); std::vector<bool> parameter_numbers(fused_parameters.size(), false); for (auto fused_param : fused_parameters) { @@ -848,13 +852,6 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { param_no, fusion->ToString().c_str()); } parameter_numbers[param_no] = true; - if (!ShapeUtil::Compatible(fused_param->shape(), - fusion->operand(param_no)->shape())) { - return InternalError( - "Shape mismatch between parameter number %lld and its operand in " - "%s.", - param_no, fusion->ToString().c_str()); - } } // Make sure all the parameter_numbers entries were seen. for (int i = 0; i < parameter_numbers.size(); i++) { @@ -916,7 +913,7 @@ Status HloVerifier::CheckElementwiseInstruction(HloInstruction* instruction) { if (!ShapeUtil::CompatibleIgnoringElementType(operand_shape, out_shape)) { return FailedPrecondition( "Implicit broadcast is not allowed in HLO." - "Found non-compatible shapes for instruction %s.\n" + "Found different shapes for instruction %s.\n" "output: %s\noperand: %s\n", HloOpcodeString(instruction->opcode()).c_str(), ShapeUtil::HumanString(out_shape).c_str(), diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h index 523bf4d70c..b6093d667c 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.h +++ b/tensorflow/compiler/xla/service/hlo_verifier.h @@ -28,9 +28,9 @@ namespace xla { // TODO(b/26024837): Check output shape for all instruction types. 
class ShapeVerifier : public DfsHloVisitor { public: - explicit ShapeVerifier() : allow_mixed_precision_(false) {} - explicit ShapeVerifier(bool allow_mixed_precision) - : allow_mixed_precision_(allow_mixed_precision) {} + explicit ShapeVerifier(bool layout_sensitive, bool allow_mixed_precision) + : layout_sensitive_(layout_sensitive), + allow_mixed_precision_(allow_mixed_precision) {} Status HandleElementwiseUnary(HloInstruction* hlo) override; Status HandleElementwiseBinary(HloInstruction* hlo) override; @@ -106,13 +106,42 @@ class ShapeVerifier : public DfsHloVisitor { Status CheckVariadicShape(const HloInstruction* instruction); private: - // Return true if the shapes of the two operands have the same element type, - // and the result shape either has the same element type as the operand - // shapes or mixed precision is allowed and the result shape and the operand - // shapes have floating point element types. + // Helpers that switch on layout_sensitive_. + bool ShapesSame(const Shape& a, const Shape& b) { + return layout_sensitive_ ? ShapeUtil::Equal(a, b) + : ShapeUtil::Compatible(a, b); + } + bool ShapesSameIgnoringFpPrecision(const Shape& a, const Shape& b) { + return layout_sensitive_ ? ShapeUtil::EqualIgnoringFpPrecision(a, b) + : ShapeUtil::CompatibleIgnoringFpPrecision(a, b); + } + string StringifyShape(const Shape& s) { + return layout_sensitive_ ? ShapeUtil::HumanStringWithLayout(s) + : ShapeUtil::HumanString(s); + } + + // Checks that the given operand of the given instruction is of type TOKEN. + Status CheckIsTokenOperand(const HloInstruction* instruction, + int64 operand_no); + + // Checks that the shape of the given operand of the given instruction matches + // the given parameter of the given computation. 
+ Status CheckOperandAndParameter(const HloInstruction* instruction, + int64 operand_number, + const HloComputation* computation, + int64 parameter_number); + + // Returns true if the shapes of the two operands have the same element type, + // and the result shape either has the same element type as the operand shapes + // or mixed precision is allowed and the result shape and the operand shapes + // have floating point element types. bool HasCompatibleElementTypes(const Shape& shape_0, const Shape& shape_1, const Shape& result_shape); + // If the verifier is layout-sensitive, shapes must be equal to what's + // expected. Otherwise, the shapes must simply be compatible. + bool layout_sensitive_; + // Whether the inputs and output of an instruction can contain both F32s and // BF16s. Tuples that include both F32s and BF16s are allowed regardless of // this flag. @@ -125,14 +154,10 @@ class HloVerifier : public HloPassInterface { public: using ShapeVerifierFactory = std::function<std::unique_ptr<ShapeVerifier>()>; - // Uses standard shape inference. - explicit HloVerifier() - : shape_verifier_factory_( - [] { return absl::make_unique<ShapeVerifier>(false); }) {} - - explicit HloVerifier(bool allow_mixed_precision) - : shape_verifier_factory_([allow_mixed_precision] { - return absl::make_unique<ShapeVerifier>(allow_mixed_precision); + explicit HloVerifier(bool layout_sensitive, bool allow_mixed_precision) + : shape_verifier_factory_([layout_sensitive, allow_mixed_precision] { + return absl::make_unique<ShapeVerifier>(layout_sensitive, + allow_mixed_precision); }) {} // Uses custom shape verification. 
@@ -140,10 +165,9 @@ class HloVerifier : public HloPassInterface { : shape_verifier_factory_(std::move(shape_verifier_factory)) {} ~HloVerifier() override = default; - tensorflow::StringPiece name() const override { return "verifier"; } + absl::string_view name() const override { return "verifier"; } - // Note: always returns false (no instructions are ever modified by this - // pass). + // Never returns true; no instructions are ever modified by this pass. StatusOr<bool> Run(HloModule* module) override; private: diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc index d764964f3c..70b741353d 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc @@ -37,13 +37,15 @@ using ::testing::HasSubstr; class HloVerifierTest : public HloTestBase { public: HloVerifierTest() - : HloTestBase(/*allow_mixed_precision_in_hlo_verifier=*/false) {} + : HloTestBase(/*verifier_layout_sensitive=*/false, + /*allow_mixed_precision_in_hlo_verifier=*/false) {} }; class HloVerifierTestAllowMixedPrecision : public HloTestBase { public: HloVerifierTestAllowMixedPrecision() - : HloTestBase(/*allow_mixed_precision_in_hlo_verifier=*/true) {} + : HloTestBase(/*verifier_layout_sensitive=*/false, + /*allow_mixed_precision_in_hlo_verifier=*/true) {} }; TEST_F(HloVerifierTest, NullInstructionParent) { diff --git a/tensorflow/compiler/xla/service/human_readable_profile_builder.cc b/tensorflow/compiler/xla/service/human_readable_profile_builder.cc index bb5b40a8a8..581b3ce1e0 100644 --- a/tensorflow/compiler/xla/service/human_readable_profile_builder.cc +++ b/tensorflow/compiler/xla/service/human_readable_profile_builder.cc @@ -14,20 +14,20 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/xla/service/human_readable_profile_builder.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/metric_table_report.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/strings/numbers.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/stringprintf.h" namespace xla { +using absl::StrAppend; +using absl::StrCat; using tensorflow::strings::Appendf; using tensorflow::strings::HumanReadableElapsedTime; using tensorflow::strings::HumanReadableNumBytes; using tensorflow::strings::Printf; -using tensorflow::strings::StrAppend; -using tensorflow::strings::StrCat; string HumanReadableProfileBuilder::ToString() const { string s; diff --git a/tensorflow/compiler/xla/service/human_readable_profile_builder.h b/tensorflow/compiler/xla/service/human_readable_profile_builder.h index 6f56c3aa82..b99624460e 100644 --- a/tensorflow/compiler/xla/service/human_readable_profile_builder.h +++ b/tensorflow/compiler/xla/service/human_readable_profile_builder.h @@ -18,8 +18,8 @@ limitations under the License. #include <vector> +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -29,7 +29,7 @@ namespace xla { // computation, suitable for consumption by humans. class HumanReadableProfileBuilder { public: - explicit HumanReadableProfileBuilder(tensorflow::StringPiece computation_name, + explicit HumanReadableProfileBuilder(absl::string_view computation_name, int64 total_cycles, double clock_rate_ghz) : computation_name_(std::string(computation_name)), @@ -43,9 +43,8 @@ class HumanReadableProfileBuilder { // Adds an operation to the profile. 
If you don't know the number of // floating-point ops or bytes touched by the op, or if you don't know how // fast it would run optimally, pass -1 for that param. - void AddOp(tensorflow::StringPiece op_name, - tensorflow::StringPiece short_name, - tensorflow::StringPiece category, int64 cycles, int64 flop_count, + void AddOp(absl::string_view op_name, absl::string_view short_name, + absl::string_view category, int64 cycles, int64 flop_count, int64 transcendental_count, int64 bytes_accessed, float optimal_seconds) { op_infos_.push_back({std::string(op_name), std::string(short_name), diff --git a/tensorflow/compiler/xla/service/implicit_broadcast_remover.h b/tensorflow/compiler/xla/service/implicit_broadcast_remover.h index aa325dc8a3..85bb4a8b24 100644 --- a/tensorflow/compiler/xla/service/implicit_broadcast_remover.h +++ b/tensorflow/compiler/xla/service/implicit_broadcast_remover.h @@ -30,7 +30,7 @@ class ImplicitBroadcastRemover : public HloPassInterface { ImplicitBroadcastRemover() {} ~ImplicitBroadcastRemover() override {} - tensorflow::StringPiece name() const override { + absl::string_view name() const override { return "implicit-broadcast-remover"; } diff --git a/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc b/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc index f85d31d522..df88587492 100644 --- a/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc +++ b/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc @@ -26,6 +26,11 @@ namespace xla { namespace { class ImplicitBroadcastRemoverTest : public HloVerifiedTestBase { + public: + ImplicitBroadcastRemoverTest() + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false) {} + protected: ImplicitBroadcastRemover remover_; }; diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.cc b/tensorflow/compiler/xla/service/indexed_array_analysis.cc index 256c8e5573..43ef30d1eb 100644 --- 
a/tensorflow/compiler/xla/service/indexed_array_analysis.cc +++ b/tensorflow/compiler/xla/service/indexed_array_analysis.cc @@ -17,12 +17,13 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/container/inlined_vector.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "absl/types/optional.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/hlo_evaluator.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/gtl/flatset.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace xla { namespace gtl = ::tensorflow::gtl; @@ -33,32 +34,30 @@ using UnknownArray = Analysis::UnknownArray; using ConstantArray = Analysis::ConstantArray; using ReshapedArray = Analysis::ReshapedArray; using ScalarIndexedArray = Analysis::ScalarIndexedArray; +using absl::StrJoin; using tensorflow::gtl::ArraySlice; -using tensorflow::str_util::Join; } // namespace string IndexedArrayAnalysis::ToString(Array* root, bool print_constants) { switch (root->kind()) { case Array::kUnknown: { auto* unknown_tensor = root->as<UnknownArray>(); - return tensorflow::strings::StrCat("%", - unknown_tensor->instruction().name()); + return absl::StrCat("%", unknown_tensor->instruction().name()); } case Array::kConstant: { if (print_constants) { string contents = root->as<ConstantArray>()->literal()->ToString(); - return tensorflow::strings::StrCat( - "(constant ", ShapeUtil::HumanString(root->shape()), " ", contents, - ")"); + return absl::StrCat("(constant ", ShapeUtil::HumanString(root->shape()), + " ", contents, ")"); } - return tensorflow::strings::StrCat( - "(constant ", ShapeUtil::HumanString(root->shape()), ")"); + return absl::StrCat("(constant ", ShapeUtil::HumanString(root->shape()), + ")"); } case Array::kReshaped: { ReshapedArray* reshaped_array = root->as<ReshapedArray>(); - return tensorflow::strings::StrCat( + return absl::StrCat( "(reshape ", 
ToString(reshaped_array->operand(), print_constants), " to ", ShapeUtil::HumanString(reshaped_array->shape()), ")"); } @@ -69,11 +68,11 @@ string IndexedArrayAnalysis::ToString(Array* root, bool print_constants) { string name = root->kind() == Array::kScalarIndexedConstant ? "scalar-indexed-const" : "scalar-indexed"; - return tensorflow::strings::StrCat( + return absl::StrCat( "(", name, " ", ToString(indexed_array->source(), print_constants), " ", ToString(indexed_array->indices(), print_constants), " ", indexed_array->source_dim(), "->[", - Join(indexed_array->output_dims(), ","), "])"); + StrJoin(indexed_array->output_dims(), ","), "])"); } } } @@ -379,8 +378,8 @@ std::vector<ReshapePassthroughDimPair> ComputeReshapePassthroughDimPairs( CHECK_NE(candidate_operand_dim, 0) << "result_dim = " << result_dim << ", result_subarray_size = " << result_subarray_size - << ", result_shape = [" << Join(result_shape, ",") << "]" - << ", operand_shape = [" << Join(operand_shape, ",") << "]"; + << ", result_shape = [" << StrJoin(result_shape, ",") << "]" + << ", operand_shape = [" << StrJoin(operand_shape, ",") << "]"; if (candidate_operand_dim != -1 && result_shape[result_dim] == operand_shape[candidate_operand_dim - 1]) { @@ -396,12 +395,13 @@ std::vector<ReshapePassthroughDimPair> ComputeReshapePassthroughDimPairs( std::vector<string> result_strings; absl::c_transform(result, std::back_inserter(result_strings), [](ReshapePassthroughDimPair value) { - return tensorflow::strings::StrCat( - value.result_dim, "->", value.operand_dim); + return absl::StrCat(value.result_dim, "->", + value.operand_dim); }); - VLOG(3) << "For a reshape from [" << Join(operand_shape, ",") << "] to [" - << Join(result_shape, ",") << "] passthrough indices are [" - << Join(result_strings, ",") << "] (legend: `result`->`operand`)"; + VLOG(3) << "For a reshape from [" << StrJoin(operand_shape, ",") << "] to [" + << StrJoin(result_shape, ",") << "] passthrough indices are [" + << StrJoin(result_strings, 
",") + << "] (legend: `result`->`operand`)"; } DCHECK(absl::c_is_sorted( @@ -443,7 +443,7 @@ int64 FindSourcePositionForPassthroughResultDim(ArraySlice<int64> operand_shape, ArraySlice<int64> result_shape, int64 source_passthrough_dim) { VLOG(3) << "FindSourcePositionForPassthroughResultDim([" - << Join(operand_shape, ",") << "], [" << Join(result_shape, ",") + << StrJoin(operand_shape, ",") << "], [" << StrJoin(result_shape, ",") << "], " << source_passthrough_dim << ")"; int64 indexed_source_subarray_size = @@ -755,9 +755,9 @@ IndexedArrayAnalysis::FoldReshapeOfGatherNoDegenerateDims( if (source_dim_for_new_scalar_indexed_node == -1) { VLOG(3) << "Could not compute the source dim for the new scalar indexed " "node: scalar_indexed_source_shape = [" - << Join(scalar_indexed_source_shape.dimensions(), ",") + << StrJoin(scalar_indexed_source_shape.dimensions(), ",") << "] and new_scalar_indexed_source_shape = [" - << Join(new_scalar_indexed_source_shape, ",") << "]"; + << StrJoin(new_scalar_indexed_source_shape, ",") << "]"; return nullptr; } @@ -997,8 +997,7 @@ absl::optional<int64> GetOnlyNonContractingNonBatchDim( // `contracting_dims` and `batch_dims` are the contracting and batch dimensions // of whatever operand `indexed_array` is to the dot (LHS or RHS). 
bool CanFoldDotIntoIndexedArray( - tensorflow::StringPiece tag, - Analysis::ScalarIndexedConstantArray* indexed_array, + absl::string_view tag, Analysis::ScalarIndexedConstantArray* indexed_array, ArraySlice<int64> contracting_dims, ArraySlice<int64> batch_dims) { absl::optional<int64> non_contracting_non_batch_dim = GetOnlyNonContractingNonBatchDim(ShapeUtil::Rank(indexed_array->shape()), @@ -1135,7 +1134,7 @@ StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForDot( return nullptr; } -tensorflow::StringPiece IndexedArrayAnalysisPrinterPass::name() const { +absl::string_view IndexedArrayAnalysisPrinterPass::name() const { return "indexed-array-analysis-printer-pass"; } diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.h b/tensorflow/compiler/xla/service/indexed_array_analysis.h index 675eb31d26..3fa7d749e1 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis.h +++ b/tensorflow/compiler/xla/service/indexed_array_analysis.h @@ -371,7 +371,7 @@ class IndexedArrayAnalysis { // unconditionally add to the regular HLO pass pipeline. class IndexedArrayAnalysisPrinterPass : public HloPassInterface { public: - tensorflow::StringPiece name() const override; + absl::string_view name() const override; StatusOr<bool> Run(HloModule* module) override; }; diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc index 97052edf7d..c34c32f7d3 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc +++ b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc @@ -22,6 +22,11 @@ limitations under the License. 
namespace xla { namespace { class IndexedArrayAnalysisTest : public HloVerifiedTestBase { + public: + IndexedArrayAnalysisTest() + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false) {} + protected: void AssertArrayForRootExpressionIs(const string& hlo_text, const string& root_expression) { @@ -634,9 +639,9 @@ ENTRY main { AssertArrayWithConstantsForRootExpressionIs(hlo_text, 1 + R"( (scalar-indexed-const (constant f32[3,4] f32[3,4] { - { 0.761594176, 0.964027584, 0.995054781, 0.999329329 }, - { 0.761594176, 0.995054781, 0.964027584, 0.999329329 }, - { 0.999329329, 0.995054781, 0.964027584, 0.761594176 } + { 0.761594, 0.964028, 0.995055, 0.999329 }, + { 0.761594, 0.995055, 0.964028, 0.999329 }, + { 0.999329, 0.995055, 0.964028, 0.761594 } }) %indices 0->[0]))"); } diff --git a/tensorflow/compiler/xla/service/inliner.h b/tensorflow/compiler/xla/service/inliner.h index a523811f6c..efa8ed3abc 100644 --- a/tensorflow/compiler/xla/service/inliner.h +++ b/tensorflow/compiler/xla/service/inliner.h @@ -27,7 +27,7 @@ namespace xla { class Inliner : public HloPassInterface { public: ~Inliner() override = default; - tensorflow::StringPiece name() const override { return "inline"; } + absl::string_view name() const override { return "inline"; } // Run inlining on the given computation. Returns whether the computation was // changed. 
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.h b/tensorflow/compiler/xla/service/instruction_fusion.h index f73ca9adf7..8489c3d9ad 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.h +++ b/tensorflow/compiler/xla/service/instruction_fusion.h @@ -36,7 +36,7 @@ class InstructionFusion : public HloPassInterface { bool may_duplicate = true) : is_expensive_(is_expensive), may_duplicate_(may_duplicate) {} ~InstructionFusion() override = default; - tensorflow::StringPiece name() const override { return "fusion"; } + absl::string_view name() const override { return "fusion"; } // Run instruction fusion on the given computation. Returns whether the // computation was changed (instructions were fused). diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index c75bffc63d..5741864282 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -27,6 +27,8 @@ limitations under the License. #include <tuple> #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/computation_layout.h" @@ -49,20 +51,12 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" namespace xla { -// For now moving only one API here, but we should have a single top level -// anonymous namespace, instead of three or four spread all over this file. 
-namespace { - -} // namespace - std::ostream& operator<<(std::ostream& out, const LayoutConstraint& constraint) { out << constraint.ToString(); @@ -368,31 +362,27 @@ const ShapeLayout* LayoutConstraints::ResultLayout() const { string LayoutConstraints::ToString() const { string output; - tensorflow::strings::StrAppend(&output, "LayoutConstraints for computation ", - computation_->name(), ":\n"); + absl::StrAppend(&output, "LayoutConstraints for computation ", + computation_->name(), ":\n"); for (auto* instruction : computation_->MakeInstructionPostOrder()) { - tensorflow::strings::StrAppend(&output, " ", instruction->ToShortString(), - "\n"); + absl::StrAppend(&output, " ", instruction->ToShortString(), "\n"); for (int64 i = 0; i < instruction->operand_count(); ++i) { if (OperandLayout(instruction, i) != nullptr) { - tensorflow::strings::StrAppend( - &output, " operand (", i, - "): ", OperandLayout(instruction, i)->ToString(), "\n"); + absl::StrAppend(&output, " operand (", i, + "): ", OperandLayout(instruction, i)->ToString(), "\n"); } } for (const LogicalBuffer* buffer : points_to_analysis_.GetBuffersDefinedByInstruction(instruction)) { if (BufferLayout(*buffer) != nullptr) { - tensorflow::strings::StrAppend( - &output, " ", buffer->ToString(), " : ", - LayoutUtil::HumanString(*BufferLayout(*buffer)), "\n"); + absl::StrAppend(&output, " ", buffer->ToString(), " : ", + LayoutUtil::HumanString(*BufferLayout(*buffer)), "\n"); } } } if (ResultLayout() != nullptr) { - tensorflow::strings::StrAppend(&output, " => ", ResultLayout()->ToString(), - "\n"); + absl::StrAppend(&output, " => ", ResultLayout()->ToString(), "\n"); } return output; } @@ -909,7 +899,7 @@ Status LayoutAssignment::CheckLayouts(HloModule* module) { "Layout of instruction %s at index {%s} does not match " "source LogicalBuffer %s: %s vs %s", instruction->name().c_str(), - tensorflow::str_util::Join(index, ",").c_str(), + absl::StrJoin(index, ",").c_str(), buffer->ToString().c_str(), 
ShapeUtil::HumanStringWithLayout(instruction_subshape) .c_str(), @@ -1400,8 +1390,8 @@ StatusOr<Layout> InferArrayLayout( return FailedPrecondition( "Array at index {%s} in instruction %s aliases buffers %s " "and %s which have different layouts", - tensorflow::str_util::Join(index, ",").c_str(), - instruction->name().c_str(), source_buffers[0]->ToString().c_str(), + absl::StrJoin(index, ",").c_str(), instruction->name().c_str(), + source_buffers[0]->ToString().c_str(), source_buffer->ToString().c_str()); } } diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h index f9e8dbea2f..3e000ec2df 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.h +++ b/tensorflow/compiler/xla/service/layout_assignment.h @@ -297,7 +297,7 @@ class LayoutAssignment : public HloPassInterface { ComputationLayout* entry_computation_layout, ChannelLayoutConstraints* channel_constraints = nullptr); ~LayoutAssignment() override {} - tensorflow::StringPiece name() const override { return "layout-assignment"; } + absl::string_view name() const override { return "layout-assignment"; } // Assign layouts to the given module. Returns whether the module was changed // (any layouts were changed). 
diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc index a16fa75e30..6d05fa5fe2 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -59,7 +59,7 @@ class LayoutAssignmentTest : public HloTestBase { EXPECT_IS_OK(layout_assignment.Run(module).status()); } - std::vector<int64> LayoutOf(HloModule* module, tensorflow::StringPiece name) { + std::vector<int64> LayoutOf(HloModule* module, absl::string_view name) { auto minor_to_major = FindInstruction(module, name)->shape().layout().minor_to_major(); return std::vector<int64>(minor_to_major.begin(), minor_to_major.end()); diff --git a/tensorflow/compiler/xla/service/llvm_ir/BUILD b/tensorflow/compiler/xla/service/llvm_ir/BUILD index 539a9522c1..fc3289f30d 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/BUILD +++ b/tensorflow/compiler/xla/service/llvm_ir/BUILD @@ -38,6 +38,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:logical_buffer", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", "@llvm//:core", ], ) @@ -69,6 +70,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_module_config", "//tensorflow/compiler/xla/service:name_uniquer", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", "@llvm//:core", "@llvm//:support", "@llvm//:target", @@ -89,6 +91,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", "@llvm//:core", ], ) @@ -104,6 +107,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", "@llvm//:core", ], ) @@ -192,6 +196,7 @@ cc_library( "//tensorflow/compiler/xla/service/gpu:parallel_loop_emitter", "//tensorflow/compiler/xla/service/gpu:partition_assignment", 
"//tensorflow/core:lib", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", "@llvm//:core", ], @@ -219,7 +224,7 @@ cc_library( deps = [ ":llvm_loop", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", - "//tensorflow/core:lib", + "@com_google_absl//absl/strings", "@llvm//:core", ], ) @@ -230,6 +235,7 @@ cc_library( hdrs = ["buffer_assignment_util.h"], deps = [ "//tensorflow/compiler/xla/service:buffer_assignment", + "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h index fe9eab93aa..8d9fa99d82 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h +++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_ALIAS_ANALYSIS_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_ALIAS_ANALYSIS_H_ +#include "absl/strings/str_cat.h" #include "llvm/IR/Module.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -23,7 +24,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/flatset.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace xla { namespace llvm_ir { diff --git a/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.cc b/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.cc index 4eb5d9fb47..bdce4a171b 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.cc @@ -14,6 +14,7 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h" +#include "absl/strings/str_cat.h" namespace xla { namespace llvm_ir { @@ -48,7 +49,7 @@ string ConstantBufferAllocationToGlobalName( c = '_'; } } - return tensorflow::strings::StrCat("buffer_for_", instr_name); + return absl::StrCat("buffer_for_", instr_name); } const Literal& LiteralForConstantAllocation( diff --git a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc index 27fbb11e2e..ad350613dd 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc @@ -40,7 +40,7 @@ static Status EmitDynamicUpdateSliceInPlaceImpl( const Shape& update_shape, const ElementGenerator& start_indices_generator, bool is_signed, ElementGenerator update_array_generator, const IrArray& output_array, const gpu::LaunchDimensions* launch_dimensions, - tensorflow::StringPiece name, llvm::IRBuilder<>* b) { + absl::string_view name, llvm::IRBuilder<>* b) { const Shape& output_shape = output_array.GetShape(); // Read start indices from start_indices_generator. 
@@ -101,8 +101,7 @@ static Status EmitDynamicUpdateSliceInPlaceImpl( Status EmitDynamicUpdateSliceInPlace( tensorflow::gtl::ArraySlice<IrArray> operand_arrays, - const IrArray& output_array, tensorflow::StringPiece name, - llvm::IRBuilder<>* b) { + const IrArray& output_array, absl::string_view name, llvm::IRBuilder<>* b) { VLOG(2) << "EmitDynamicUpdateSliceInPlace for " << name; // No need to use operand_arrays[0], the input array of the diff --git a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h index 3502577d23..e1631a62ae 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h +++ b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h @@ -65,8 +65,7 @@ inline bool CanEmitFusedDynamicUpdateSliceInPlace( // modify the input/output buffer without touching any of the other elements. Status EmitDynamicUpdateSliceInPlace( tensorflow::gtl::ArraySlice<IrArray> operand_arrays, - const IrArray& output_array, tensorflow::StringPiece name, - llvm::IRBuilder<>* b); + const IrArray& output_array, absl::string_view name, llvm::IRBuilder<>* b); // Given a loop-fusion node whose root is a dynamic-update-slice op whose // array-to-be-updated and output share the same buffer slice, emits diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc index 2b6caee6aa..6971220022 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc @@ -342,9 +342,9 @@ llvm::Value* IrArray::Index::Linearize( return logical_linear_index; } -llvm::Value* IrArray::EmitArrayElementAddress( - const IrArray::Index& index, llvm::IRBuilder<>* b, - tensorflow::StringPiece name) const { +llvm::Value* IrArray::EmitArrayElementAddress(const IrArray::Index& index, + llvm::IRBuilder<>* b, + absl::string_view name) const { if (ShapeUtil::IsScalar(*shape_)) { // 
Special handling of scalars: a scalar pretends to have the same value for // every index, thus effectively implementing broadcasting of its value @@ -402,7 +402,7 @@ void IrArray::AnnotateLoadStoreInstructionWithMetadata( llvm::Value* IrArray::EmitReadArrayElement(const Index& index, llvm::IRBuilder<>* b, - tensorflow::StringPiece name) const { + absl::string_view name) const { llvm::Value* element_address = EmitArrayElementAddress(index, b, name); llvm::LoadInst* load = b->CreateLoad(element_address); AnnotateLoadStoreInstructionWithMetadata(load); diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h index cbfd2e7012..e913c109b3 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h @@ -20,12 +20,12 @@ limitations under the License. #include <vector> #include "absl/algorithm/container.h" +#include "absl/strings/string_view.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -241,7 +241,7 @@ class IrArray { // The optional name is useful for debugging when looking at // the emitted LLVM IR. llvm::Value* EmitArrayElementAddress(const Index& index, llvm::IRBuilder<>* b, - tensorflow::StringPiece name = "") const; + absl::string_view name = "") const; // Attach metadata this IrArray instance knows about to "instruction". void AnnotateLoadStoreInstructionWithMetadata( @@ -255,7 +255,7 @@ class IrArray { // The optional name is useful for debugging when looking at // the emitted LLVM IR. 
llvm::Value* EmitReadArrayElement(const Index& index, llvm::IRBuilder<>* b, - tensorflow::StringPiece name = "") const; + absl::string_view name = "") const; // Emit IR to write the given value to the array element at the given index. void EmitWriteArrayElement(const Index& index, llvm::Value* value, diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc index b79567369a..bd0139f85b 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc @@ -19,7 +19,7 @@ limitations under the License. namespace xla { Status KernelSupportLibrary::For( - tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + absl::string_view name, llvm::Value* start, llvm::Value* end, llvm::Value* step, const std::function<Status(llvm::Value*, bool)>& for_body_generator) { return If(b_->CreateICmpSLT(start, end), [&]() -> Status { @@ -30,7 +30,7 @@ Status KernelSupportLibrary::For( } Status KernelSupportLibrary::For( - tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + absl::string_view name, llvm::Value* start, llvm::Value* end, llvm::Value* step, bool peel_first_iteration, const std::function<Status(llvm::Value*, llvm::Value*)>& for_body_generator) { @@ -56,7 +56,7 @@ Status KernelSupportLibrary::For( } Status KernelSupportLibrary::If( - tensorflow::StringPiece name, llvm::Value* condition, + absl::string_view name, llvm::Value* condition, const std::function<Status()>& true_block_generator, const std::function<Status()>& false_block_generator) { llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(condition, name, b_); @@ -70,7 +70,7 @@ Status KernelSupportLibrary::If( void KernelSupportLibrary::EmitAndCallOutlinedKernel( bool enable_fast_math, bool optimize_for_size, llvm::IRBuilder<>* b, - tensorflow::StringPiece kernel_name, + absl::string_view kernel_name, 
KernelSupportLibrary::ArgumentVector arguments, const std::function<void(KernelSupportLibrary::ArgumentVector)>& kernel_body_generator) { diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h index c5354a8c42..b152cf9275 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h @@ -18,12 +18,12 @@ limitations under the License. #include <string> +#include "absl/strings/string_view.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" -#include "tensorflow/core/lib/core/stringpiece.h" namespace xla { // A thin wrapper around llvm_loop.h to make code generating structured control @@ -49,13 +49,13 @@ class KernelSupportLibrary { // `for_body_generator(/*ind_var=*/,i, /*is_first_iteration=*/false)`; // } Status For( - tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + absl::string_view name, llvm::Value* start, llvm::Value* end, llvm::Value* step, const std::function<Status(llvm::Value* ind_var, bool is_first_iteration)>& for_body_generator); void ForReturnVoid( - tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + absl::string_view name, llvm::Value* start, llvm::Value* end, llvm::Value* step, const std::function<void(llvm::Value* ind_var, bool is_first_iteration)>& for_body_generator) { @@ -67,7 +67,7 @@ class KernelSupportLibrary { })); } - Status For(tensorflow::StringPiece name, int64 start, int64 end, int64 step, + Status For(absl::string_view name, int64 start, int64 end, int64 step, const std::function<Status(llvm::Value* ind_var, bool is_first_iteration)>& for_body_generator) { @@ -77,7 +77,7 @@ class KernelSupportLibrary { } void ForReturnVoid( - tensorflow::StringPiece name, int64 
start, int64 end, int64 step, + absl::string_view name, int64 start, int64 end, int64 step, const std::function<void(llvm::Value* ind_var, bool is_first_iteration)>& for_body_generator) { ForReturnVoid(name, /*start=*/b_->getInt64(start), @@ -99,13 +99,13 @@ class KernelSupportLibrary { // for (i64 i = `start`; i s< `end`; i += `step`) // `for_body_generator(/*ind_var=*/,i, // /*is_first_iteration=*/,(i != `start`))`; - Status For(tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + Status For(absl::string_view name, llvm::Value* start, llvm::Value* end, llvm::Value* step, bool peel_first_iteration, const std::function<Status(llvm::Value* ind_var, llvm::Value* is_first_iteration)>& for_body_generator); - void ForReturnVoid(tensorflow::StringPiece name, llvm::Value* start, + void ForReturnVoid(absl::string_view name, llvm::Value* start, llvm::Value* end, llvm::Value* step, bool peel_first_iteration, const std::function<void(llvm::Value* ind_var, @@ -119,7 +119,7 @@ class KernelSupportLibrary { })); } - Status For(tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + Status For(absl::string_view name, llvm::Value* start, llvm::Value* end, int64 step, bool peel_first_iteration, const std::function<Status(llvm::Value* ind_var, llvm::Value* is_first_iteration)>& @@ -129,7 +129,7 @@ class KernelSupportLibrary { peel_first_iteration, for_body_generator); } - void ForReturnVoid(tensorflow::StringPiece name, llvm::Value* start, + void ForReturnVoid(absl::string_view name, llvm::Value* start, llvm::Value* end, int64 step, bool peel_first_iteration, const std::function<void(llvm::Value* ind_var, llvm::Value* is_first_iteration)>& @@ -140,7 +140,7 @@ class KernelSupportLibrary { } Status For( - tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + absl::string_view name, llvm::Value* start, llvm::Value* end, llvm::Value* step, const std::function<Status(llvm::Value* ind_var)>& for_body_generator) { return For(name, start, end, 
step, @@ -151,7 +151,7 @@ class KernelSupportLibrary { } void ForReturnVoid( - tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + absl::string_view name, llvm::Value* start, llvm::Value* end, llvm::Value* step, const std::function<void(llvm::Value* ind_var)>& for_body_generator) { ForReturnVoid(name, start, end, step, @@ -162,8 +162,7 @@ class KernelSupportLibrary { } Status For( - tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, - int64 step, + absl::string_view name, llvm::Value* start, llvm::Value* end, int64 step, const std::function<Status(llvm::Value* ind_var)>& for_body_generator) { return For(name, start, end, llvm::ConstantInt::get(start->getType(), step), /*peel_first_iteration=*/false, @@ -173,8 +172,7 @@ class KernelSupportLibrary { } void ForReturnVoid( - tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, - int64 step, + absl::string_view name, llvm::Value* start, llvm::Value* end, int64 step, const std::function<void(llvm::Value* ind_var)>& for_body_generator) { ForReturnVoid(name, start, end, llvm::ConstantInt::get(start->getType(), step), @@ -182,7 +180,7 @@ class KernelSupportLibrary { } Status For( - tensorflow::StringPiece name, int64 start, int64 end, int64 step, + absl::string_view name, int64 start, int64 end, int64 step, const std::function<Status(llvm::Value* ind_var)>& for_body_generator) { return For(name, /*start=*/b_->getInt64(start), /*end=*/b_->getInt64(end), @@ -190,7 +188,7 @@ class KernelSupportLibrary { } void ForReturnVoid( - tensorflow::StringPiece name, int64 start, int64 end, int64 step, + absl::string_view name, int64 start, int64 end, int64 step, const std::function<void(llvm::Value* ind_var)>& for_body_generator) { ForReturnVoid(name, /*start=*/b_->getInt64(start), /*end=*/b_->getInt64(end), @@ -203,7 +201,7 @@ class KernelSupportLibrary { // `true_block_generator()`; // else // `false_block_generator()`; - Status If(tensorflow::StringPiece name, llvm::Value* condition, 
+ Status If(absl::string_view name, llvm::Value* condition, const std::function<Status()>& true_block_generator, const std::function<Status()>& false_block_generator = []() -> Status { return Status::OK(); }); @@ -222,7 +220,7 @@ class KernelSupportLibrary { IfReturnVoid("", condition, true_block_generator, false_block_generator); } - void IfReturnVoid(tensorflow::StringPiece name, llvm::Value* condition, + void IfReturnVoid(absl::string_view name, llvm::Value* condition, const std::function<void()>& true_block_generator, const std::function<void()>& false_block_generator = []() { }) { @@ -259,13 +257,13 @@ class KernelSupportLibrary { // Currently we only support at most one nullptr value in `arguments`. static void EmitAndCallOutlinedKernel( bool enable_fast_math, bool optimize_for_size, llvm::IRBuilder<>* b, - tensorflow::StringPiece kernel_name, ArgumentVector arguments, + absl::string_view kernel_name, ArgumentVector arguments, const std::function<void(ArgumentVector)>& kernel_body_generator); // Thin wrappers around the more general EmitAndCallOutlinedKernel above. 
static void EmitAndCallOutlinedKernel( bool enable_fast_math, bool optimize_for_size, llvm::IRBuilder<>* b, - tensorflow::StringPiece kernel_name, llvm::Value* arg0, llvm::Value* arg1, + absl::string_view kernel_name, llvm::Value* arg0, llvm::Value* arg1, llvm::Value* arg2, const std::function<void(llvm::Value*, llvm::Value*, llvm::Value*)>& kernel_body_generator) { @@ -278,7 +276,7 @@ class KernelSupportLibrary { static void EmitAndCallOutlinedKernel( bool enable_fast_math, bool optimize_for_size, llvm::IRBuilder<>* b, - tensorflow::StringPiece kernel_name, llvm::Value* arg0, llvm::Value* arg1, + absl::string_view kernel_name, llvm::Value* arg0, llvm::Value* arg1, llvm::Value* arg2, llvm::Value* arg3, const std::function<void(llvm::Value*, llvm::Value*, llvm::Value*, llvm::Value*)>& kernel_body_generator) { diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc index ba7f94834c..978fa5b453 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc @@ -18,6 +18,7 @@ limitations under the License. #include <numeric> #include <vector> +#include "absl/strings/str_cat.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Function.h" #include "llvm/IR/Instructions.h" @@ -25,14 +26,13 @@ limitations under the License. 
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" namespace xla { namespace llvm_ir { -ForLoop::ForLoop(tensorflow::StringPiece prefix, tensorflow::StringPiece suffix, +ForLoop::ForLoop(absl::string_view prefix, absl::string_view suffix, llvm::Value* start_index, llvm::Value* end_index, llvm::Value* step, UnrollMode unroll_mode, bool prevent_vectorization) @@ -46,9 +46,9 @@ ForLoop::ForLoop(tensorflow::StringPiece prefix, tensorflow::StringPiece suffix, prevent_vectorization_(prevent_vectorization) {} /* static */ std::unique_ptr<ForLoop> ForLoop::EmitForLoop( - tensorflow::StringPiece prefix, llvm::Value* start_index, - llvm::Value* end_index, llvm::Value* step, llvm::IRBuilder<>* b, - UnrollMode unroll_mode, bool prevent_vectorization) { + absl::string_view prefix, llvm::Value* start_index, llvm::Value* end_index, + llvm::Value* step, llvm::IRBuilder<>* b, UnrollMode unroll_mode, + bool prevent_vectorization) { std::unique_ptr<ForLoop> loop(new ForLoop(prefix, /*suffix=*/"", start_index, end_index, step, unroll_mode, prevent_vectorization)); @@ -168,16 +168,16 @@ std::vector<llvm::Metadata*> ForLoop::GetLoopMetadata(llvm::IRBuilder<>* b) { return result; } -string ForLoop::GetQualifiedName(tensorflow::StringPiece name) { +string ForLoop::GetQualifiedName(absl::string_view name) { return llvm_ir::IrName(prefix_, llvm_ir::IrName(name, suffix_)); } -llvm::BasicBlock* ForLoop::CreateLoopBB(tensorflow::StringPiece name, +llvm::BasicBlock* ForLoop::CreateLoopBB(absl::string_view name, llvm::IRBuilder<>* b) { return CreateBasicBlock(insert_before_bb_, GetQualifiedName(name), b); } -std::unique_ptr<ForLoop> ForLoopNest::AddLoop(tensorflow::StringPiece suffix, +std::unique_ptr<ForLoop> ForLoopNest::AddLoop(absl::string_view suffix, 
llvm::Value* start_index, llvm::Value* end_index, UnrollMode unroll_mode, @@ -186,12 +186,9 @@ std::unique_ptr<ForLoop> ForLoopNest::AddLoop(tensorflow::StringPiece suffix, unroll_mode, prevent_vectorization); } -std::unique_ptr<ForLoop> ForLoopNest::AddLoop(tensorflow::StringPiece suffix, - llvm::Value* start_index, - llvm::Value* end_index, - llvm::Value* stride, - UnrollMode unroll_mode, - bool prevent_vectorization) { +std::unique_ptr<ForLoop> ForLoopNest::AddLoop( + absl::string_view suffix, llvm::Value* start_index, llvm::Value* end_index, + llvm::Value* stride, UnrollMode unroll_mode, bool prevent_vectorization) { if (inner_loop_body_bb_ != nullptr) { // Create this loop inside the previous one. b_->SetInsertPoint(&*inner_loop_body_bb_->getFirstInsertionPt()); @@ -216,7 +213,7 @@ std::unique_ptr<ForLoop> ForLoopNest::AddLoop(tensorflow::StringPiece suffix, std::unique_ptr<ForLoop> ForLoopNest::AddLoop(int64 start_index, int64 end_index, - tensorflow::StringPiece suffix, + absl::string_view suffix, UnrollMode unroll_mode, bool prevent_vectorization) { CHECK_LE(start_index, end_index); @@ -227,7 +224,7 @@ std::unique_ptr<ForLoop> ForLoopNest::AddLoop(int64 start_index, std::unique_ptr<ForLoop> ForLoopNest::AddLoop(int64 start_index, int64 end_index, int64 stride, - tensorflow::StringPiece suffix, + absl::string_view suffix, UnrollMode unroll_mode, bool prevent_vectorization) { CHECK_LE(start_index, end_index); @@ -238,7 +235,7 @@ std::unique_ptr<ForLoop> ForLoopNest::AddLoop(int64 start_index, } IrArray::Index ForLoopNest::AddLoopsForShape(const Shape& shape, - tensorflow::StringPiece suffix) { + absl::string_view suffix) { std::vector<int64> dimensions(ShapeUtil::Rank(shape)); std::iota(dimensions.begin(), dimensions.end(), 0); return AddLoopsForShapeOnDimensions(shape, dimensions, suffix); @@ -246,14 +243,14 @@ IrArray::Index ForLoopNest::AddLoopsForShape(const Shape& shape, IrArray::Index ForLoopNest::AddLoopsForShapeOnDimensions( const Shape& shape, 
tensorflow::gtl::ArraySlice<int64> dimensions, - tensorflow::StringPiece suffix) { + absl::string_view suffix) { llvm_ir::IrArray::Index index(index_type_, shape.dimensions_size()); for (int64 dimension : dimensions) { std::unique_ptr<llvm_ir::ForLoop> loop = AddLoop( /*start_index=*/0, /*end_index=*/shape.dimensions(dimension), /*suffix=*/ - llvm_ir::IrName(suffix, tensorflow::strings::StrCat(dimension))); + llvm_ir::IrName(suffix, absl::StrCat(dimension))); index[dimension] = loop->GetIndVarValue(); } return index; @@ -261,7 +258,7 @@ IrArray::Index ForLoopNest::AddLoopsForShapeOnDimensions( IrArray::Index ForLoopNest::EmitOperandArrayLoopNest( const llvm_ir::IrArray& operand_array, int64 dimension_to_skip, - tensorflow::StringPiece name_suffix) { + absl::string_view name_suffix) { // Prepares the dimension list we will use to emit the loop nest. Outermost // loops are added first. Add loops in major-to-minor order, and skip the // 'dimension_to_skip' dimension. diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h index a4fed5c8dc..62aa15fe2d 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h @@ -19,15 +19,15 @@ limitations under the License. 
#include <memory> #include <string> +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -78,7 +78,7 @@ class ForLoop { // `unroll_mode` specifies the desired LLVM unrolling behavior for generated // loop. static std::unique_ptr<ForLoop> EmitForLoop( - tensorflow::StringPiece prefix, llvm::Value* start_index, + absl::string_view prefix, llvm::Value* start_index, llvm::Value* end_index, llvm::Value* step, llvm::IRBuilder<>* b, UnrollMode unroll_mode = llvm_ir::UnrollMode::kDefaultUnroll, bool prevent_vectorization = false); @@ -133,19 +133,18 @@ class ForLoop { // Allow ForLoopNest to call this private constructor. friend class ForLoopNest; - ForLoop(tensorflow::StringPiece prefix, tensorflow::StringPiece suffix, + ForLoop(absl::string_view prefix, absl::string_view suffix, llvm::Value* start_index, llvm::Value* end_index, llvm::Value* step, UnrollMode unroll_mode, bool prevent_vectorization); // Emit the loop at the insert point of the builder. void Emit(llvm::IRBuilder<>* b); - llvm::BasicBlock* CreateLoopBB(tensorflow::StringPiece name, - llvm::IRBuilder<>* b); + llvm::BasicBlock* CreateLoopBB(absl::string_view name, llvm::IRBuilder<>* b); // Creates a name for an LLVM construct, appending prefix_ and suffix_, if // they are set. - string GetQualifiedName(tensorflow::StringPiece name); + string GetQualifiedName(absl::string_view name); // Return a list of metadata nodes that should be associated with the // llvm::Loop for this `ForLoop`. 
@@ -182,7 +181,7 @@ class ForLoopNest { SetIndexType(index_ty); } - ForLoopNest(tensorflow::StringPiece name, llvm::IRBuilder<>* b, + ForLoopNest(absl::string_view name, llvm::IRBuilder<>* b, llvm::Type* index_ty = nullptr) : name_(std::string(name)), outer_loop_preheader_bb_(nullptr), @@ -197,14 +196,14 @@ class ForLoopNest { // been added then emit loop inside the body of the last added loop. // unroll_mode is used to emit metadata that controls LLVM unrolling. std::unique_ptr<ForLoop> AddLoop( - tensorflow::StringPiece suffix, llvm::Value* start_index, + absl::string_view suffix, llvm::Value* start_index, llvm::Value* end_index, llvm::Value* stride, UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll, bool prevent_vectorization = false); // Like the above, except that it defaults to a stride of one. std::unique_ptr<ForLoop> AddLoop( - tensorflow::StringPiece suffix, llvm::Value* start_index, + absl::string_view suffix, llvm::Value* start_index, llvm::Value* end_index, UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll, bool prevent_vectorization = false); @@ -213,13 +212,13 @@ class ForLoopNest { // end index are constant. std::unique_ptr<ForLoop> AddLoop( int64 start_index, int64 end_index, int64 stride, - tensorflow::StringPiece suffix, + absl::string_view suffix, UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll, bool prevent_vectorization = false); // Like the above, except that it defaults to a stride of one. std::unique_ptr<ForLoop> AddLoop( - int64 start_index, int64 end_index, tensorflow::StringPiece suffix, + int64 start_index, int64 end_index, absl::string_view suffix, UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll, bool prevent_vectorization = false); @@ -234,8 +233,7 @@ class ForLoopNest { // within the shape. 
One possible order for that sequence would be: // // (0,0), (0,1), (0,2), (1,0), (1,1), (1,2) - IrArray::Index AddLoopsForShape(const Shape& shape, - tensorflow::StringPiece suffix); + IrArray::Index AddLoopsForShape(const Shape& shape, absl::string_view suffix); // Add a loop for each dimension in "dimensions". "suffix" is the // name suffix of the indvar and basic blocks in this new loop nest. @@ -245,7 +243,7 @@ class ForLoopNest { // dimension that is not in "dimensions". IrArray::Index AddLoopsForShapeOnDimensions( const Shape& shape, tensorflow::gtl::ArraySlice<int64> dimensions, - tensorflow::StringPiece suffix); + absl::string_view suffix); // Emits a series of nested loops for iterating over an operand array. Loops // are constructed in major to minor dimension layout order. No loop is @@ -256,7 +254,7 @@ class ForLoopNest { // basic blocks) constructed by this method. IrArray::Index EmitOperandArrayLoopNest(const llvm_ir::IrArray& operand_array, int64 dimension_to_skip, - tensorflow::StringPiece name_suffix); + absl::string_view name_suffix); // Convenience methods which return particular basic blocks of the outermost // or innermost loops. These methods return nullptr if no loops have been diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc index e6126881af..f0db2a3761 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc @@ -19,6 +19,8 @@ limitations under the License. #include <memory> #include <vector> +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/GlobalValue.h" #include "llvm/IR/MDBuilder.h" @@ -34,8 +36,6 @@ limitations under the License. 
#include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/byte_order.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" @@ -61,7 +61,7 @@ string AsString(const std::string& str) { return string(str.data(), str.length()); } -llvm::StringRef AsStringRef(tensorflow::StringPiece str) { +llvm::StringRef AsStringRef(absl::string_view str) { return llvm::StringRef(str.data(), str.size()); } @@ -262,15 +262,17 @@ llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal, } llvm::AllocaInst* EmitAllocaAtFunctionEntry(llvm::Type* type, - tensorflow::StringPiece name, + absl::string_view name, llvm::IRBuilder<>* b, int alignment) { return EmitAllocaAtFunctionEntryWithCount(type, nullptr, name, b, alignment); } -llvm::AllocaInst* EmitAllocaAtFunctionEntryWithCount( - llvm::Type* type, llvm::Value* element_count, tensorflow::StringPiece name, - llvm::IRBuilder<>* b, int alignment) { +llvm::AllocaInst* EmitAllocaAtFunctionEntryWithCount(llvm::Type* type, + llvm::Value* element_count, + absl::string_view name, + llvm::IRBuilder<>* b, + int alignment) { llvm::IRBuilder<>::InsertPoint insert_point = b->saveIP(); llvm::Function* function = b->GetInsertBlock()->getParent(); b->SetInsertPoint(&function->getEntryBlock(), @@ -285,7 +287,7 @@ llvm::AllocaInst* EmitAllocaAtFunctionEntryWithCount( } llvm::BasicBlock* CreateBasicBlock(llvm::BasicBlock* insert_before, - tensorflow::StringPiece name, + absl::string_view name, llvm::IRBuilder<>* b) { return llvm::BasicBlock::Create( /*Context=*/b->getContext(), @@ -294,27 +296,25 @@ llvm::BasicBlock* CreateBasicBlock(llvm::BasicBlock* insert_before, /*InsertBefore*/ insert_before); } -LlvmIfData EmitIfThenElse(llvm::Value* condition, tensorflow::StringPiece name, +LlvmIfData 
EmitIfThenElse(llvm::Value* condition, absl::string_view name, llvm::IRBuilder<>* b, bool emit_else) { llvm_ir::LlvmIfData if_data; if_data.if_block = b->GetInsertBlock(); if_data.true_block = - CreateBasicBlock(nullptr, tensorflow::strings::StrCat(name, "-true"), b); + CreateBasicBlock(nullptr, absl::StrCat(name, "-true"), b); if_data.false_block = - emit_else ? CreateBasicBlock( - nullptr, tensorflow::strings::StrCat(name, "-false"), b) + emit_else ? CreateBasicBlock(nullptr, absl::StrCat(name, "-false"), b) : nullptr; // Add a terminator to the if block, if necessary. if (if_data.if_block->getTerminator() == nullptr) { b->SetInsertPoint(if_data.if_block); - if_data.after_block = CreateBasicBlock( - nullptr, tensorflow::strings::StrCat(name, "-after"), b); + if_data.after_block = + CreateBasicBlock(nullptr, absl::StrCat(name, "-after"), b); b->CreateBr(if_data.after_block); } else { if_data.after_block = if_data.if_block->splitBasicBlock( - b->GetInsertPoint(), - AsStringRef(tensorflow::strings::StrCat(name, "-after"))); + b->GetInsertPoint(), AsStringRef(absl::StrCat(name, "-after"))); } // Our basic block should now end with an unconditional branch. 
Remove it; @@ -413,14 +413,14 @@ string IrName(string a) { return a; } -string IrName(tensorflow::StringPiece a, tensorflow::StringPiece b) { +string IrName(absl::string_view a, absl::string_view b) { if (!a.empty() && !b.empty()) { - return IrName(tensorflow::strings::StrCat(a, ".", b)); + return IrName(absl::StrCat(a, ".", b)); } - return IrName(tensorflow::strings::StrCat(a, b)); + return IrName(absl::StrCat(a, b)); } -string IrName(const HloInstruction* a, tensorflow::StringPiece b) { +string IrName(const HloInstruction* a, absl::string_view b) { return IrName(a->name(), b); } @@ -556,7 +556,7 @@ std::map<int, llvm::MDNode*> MergeMetadata( return result; } -static string GetProcessUniqueIrFileName(tensorflow::StringPiece prefix) { +static string GetProcessUniqueIrFileName(absl::string_view prefix) { static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED); static NameUniquer* uniquer = new NameUniquer(/*separator=*/"-"); @@ -584,18 +584,16 @@ Status DumpIRToDirectory(const string& directory_name, // XlaJitCompiledCpuFunction::Compile. Avoid overwriting IR files previously // dumped from the same process in such cases. string unique_and_safe_file_name = GetProcessUniqueIrFileName( - tensorflow::strings::StrCat("ir-", SanitizeFileName(hlo_module_name), "-", - optimized ? "with" : "no", "-opt")); + absl::StrCat("ir-", SanitizeFileName(hlo_module_name), "-", + optimized ? "with" : "no", "-opt")); string ir_file_name = tensorflow::io::JoinPath( - directory_name, - tensorflow::strings::StrCat(unique_and_safe_file_name, ".ll")); + directory_name, absl::StrCat(unique_and_safe_file_name, ".ll")); // For some models the embedded constants can be huge, so also dump the module // with the constants stripped to get IR that is easier to manipulate. 
string ir_no_constant_initializers_file_name = tensorflow::io::JoinPath( - directory_name, - tensorflow::strings::StrCat(unique_and_safe_file_name, "-noconst.ll")); + directory_name, absl::StrCat(unique_and_safe_file_name, "-noconst.ll")); TF_RETURN_IF_ERROR(CreateAndWriteStringToFile( directory_name, ir_file_name, DumpModuleToString(llvm_module))); @@ -607,8 +605,7 @@ Status DumpIRToDirectory(const string& directory_name, llvm::Function* CreateFunction(llvm::FunctionType* function_type, llvm::GlobalValue::LinkageTypes linkage, bool enable_fast_math, bool optimize_for_size, - tensorflow::StringPiece name, - llvm::Module* module) { + absl::string_view name, llvm::Module* module) { llvm::Function* function = llvm::Function::Create(function_type, linkage, AsStringRef(name), module); function->setCallingConv(llvm::CallingConv::C); @@ -638,7 +635,7 @@ void InitializeLLVMCommandLineOptions(const HloModuleConfig& config) { fake_argv_storage.push_back(""); for (const auto& it : options) { // Skip options the XLA backend itself consumes. - if (!tensorflow::str_util::StartsWith(it.first, "xla_")) { + if (!absl::StartsWith(it.first, "xla_")) { if (it.second.empty()) { fake_argv_storage.push_back(it.first); } else { diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h index 0958398534..dde50e19d1 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h @@ -20,6 +20,7 @@ limitations under the License. #include <string> #include <vector> +#include "absl/strings/string_view.h" #include "llvm/ADT/StringRef.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/IRBuilder.h" @@ -32,7 +33,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/types.h" @@ -47,11 +47,11 @@ namespace llvm_ir { // Convert a std::string (used by LLVM's interfaces) to string. string AsString(const std::string& str); -// Convert a tensorflow::StringPiece to a llvm::StringRef. Note: both -// tensorflow::StringPiece and llvm::StringRef are non-owning pointers into a +// Convert a absl::string_view to a llvm::StringRef. Note: both +// absl::string_view and llvm::StringRef are non-owning pointers into a // string in memory. This method is used to feed strings to LLVM // & Clang APIs that expect llvm::StringRef. -llvm::StringRef AsStringRef(tensorflow::StringPiece str); +llvm::StringRef AsStringRef(absl::string_view str); template <typename T> llvm::ArrayRef<T> AsArrayRef(const std::vector<T>& vec) { @@ -88,8 +88,8 @@ string DumpModuleToString(const llvm::Module& module); // - removing all '%'s. // string IrName(string a); -string IrName(tensorflow::StringPiece a, tensorflow::StringPiece b); -string IrName(const HloInstruction* a, tensorflow::StringPiece b = ""); +string IrName(absl::string_view a, absl::string_view b); +string IrName(const HloInstruction* a, absl::string_view b = ""); // Removes special characters from a function name. // @@ -164,21 +164,23 @@ llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal, // This can be useful to avoid e.g. executing an alloca every time // through a loop. llvm::AllocaInst* EmitAllocaAtFunctionEntry(llvm::Type* type, - tensorflow::StringPiece name, + absl::string_view name, llvm::IRBuilder<>* b, int alignment = 0); // As EmitAllocaAtFunctionEntry, but allocates element_count entries // instead of a single element. 
-llvm::AllocaInst* EmitAllocaAtFunctionEntryWithCount( - llvm::Type* type, llvm::Value* element_count, tensorflow::StringPiece name, - llvm::IRBuilder<>* b, int alignment = 0); +llvm::AllocaInst* EmitAllocaAtFunctionEntryWithCount(llvm::Type* type, + llvm::Value* element_count, + absl::string_view name, + llvm::IRBuilder<>* b, + int alignment = 0); // Creates a basic block with the same context and function as for the // builder. Inserts at the end of the function if insert_before is // null. llvm::BasicBlock* CreateBasicBlock(llvm::BasicBlock* insert_before, - tensorflow::StringPiece name, + absl::string_view name, llvm::IRBuilder<>* b); // Struct with data on a conditional branch in a diamond shape created @@ -210,7 +212,7 @@ struct LlvmIfData { // Currently the insertion point of the builder must be a well-formed // block with a terminator. If you need to use this for a // non-terminated block, just make the function able to do that too. -LlvmIfData EmitIfThenElse(llvm::Value* condition, tensorflow::StringPiece name, +LlvmIfData EmitIfThenElse(llvm::Value* condition, absl::string_view name, llvm::IRBuilder<>* b, bool emit_else = true); // Emits a compare operation between "lhs" and "rhs" with the given predicate, @@ -285,8 +287,7 @@ Status DumpIRToDirectory(const string& directory_name, llvm::Function* CreateFunction(llvm::FunctionType* function_type, llvm::GlobalValue::LinkageTypes linkage, bool enable_fast_math, bool optimize_for_size, - tensorflow::StringPiece name, - llvm::Module* module); + absl::string_view name, llvm::Module* module); // Extracts the xla_backend_extra_options from `config` and passes those that // don't start with xla_ to LLVM. 
diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc index 36f5fa1952..cf7445804c 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc @@ -86,7 +86,7 @@ LoopEmitter::LoopEmitter(const ElementGenerator& target_element_generator, } std::vector<IrArray::Index> LoopEmitter::EmitIndexAndSetExitBasicBlock( - tensorflow::StringPiece loop_name, llvm::Type* index_type) { + absl::string_view loop_name, llvm::Type* index_type) { CHECK_NE(index_type, nullptr); if (ShapeUtil::IsScalar(shape_)) { // No loop needed, so set exit_bb_ to nullptr. @@ -122,7 +122,7 @@ std::vector<IrArray::Index> LoopEmitter::EmitIndexAndSetExitBasicBlock( return {array_index}; } -Status LoopEmitter::EmitLoop(tensorflow::StringPiece loop_name, +Status LoopEmitter::EmitLoop(absl::string_view loop_name, llvm::Type* index_type) { if (index_type == nullptr) { index_type = b_->getInt64Ty(); diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h index c4f5c82086..57d9d8bbc6 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h +++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h @@ -69,10 +69,10 @@ class LoopEmitter { } virtual std::vector<IrArray::Index> EmitIndexAndSetExitBasicBlock( - tensorflow::StringPiece loop_name, llvm::Type* index_type); + absl::string_view loop_name, llvm::Type* index_type); // Emits a complete loop nest for every element in the given shape. 
- Status EmitLoop(tensorflow::StringPiece loop_name = "", + Status EmitLoop(absl::string_view loop_name = "", llvm::Type* index_type = nullptr); protected: diff --git a/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc b/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc index c333311a7e..00dd3f1638 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/sort_util.h" // IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc" +#include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constants.h" @@ -30,7 +31,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -88,7 +88,7 @@ void EmitCompareLoop(int64 dimension_to_sort, const IrArray::Index& keys_index, Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array, const absl::optional<IrArray>& values_array, - tensorflow::StringPiece name, llvm::Value* xor_mask, + absl::string_view name, llvm::Value* xor_mask, llvm::IRBuilder<>* b, const gpu::LaunchDimensions* launch_dimensions) { const Shape& keys_shape = keys_array.GetShape(); diff --git a/tensorflow/compiler/xla/service/llvm_ir/sort_util.h b/tensorflow/compiler/xla/service/llvm_ir/sort_util.h index 39fffea931..527ed10374 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/sort_util.h +++ b/tensorflow/compiler/xla/service/llvm_ir/sort_util.h @@ -16,12 +16,12 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_SORT_UTIL_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_SORT_UTIL_H_ +#include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -32,7 +32,7 @@ namespace llvm_ir { // the inner compare loop will not be parallelized. Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array, const absl::optional<IrArray>& values_array, - tensorflow::StringPiece name, llvm::Value* xor_mask, + absl::string_view name, llvm::Value* xor_mask, llvm::IRBuilder<>* b, const gpu::LaunchDimensions* launch_dimensions); } // namespace llvm_ir diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc index b7cb782a7e..ea59adadea 100644 --- a/tensorflow/compiler/xla/service/local_service.cc +++ b/tensorflow/compiler/xla/service/local_service.cc @@ -20,6 +20,7 @@ limitations under the License. #include <vector> #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/client/executable_build_options.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/execution_options_util.h" @@ -37,7 +38,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/gtl/cleanup.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" diff --git a/tensorflow/compiler/xla/service/logical_buffer.cc b/tensorflow/compiler/xla/service/logical_buffer.cc index c742d35a7b..e1f56727bd 100644 --- a/tensorflow/compiler/xla/service/logical_buffer.cc +++ b/tensorflow/compiler/xla/service/logical_buffer.cc @@ -15,11 +15,11 @@ limitations under the License. #include "tensorflow/compiler/xla/service/logical_buffer.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace xla { @@ -34,11 +34,10 @@ LogicalBuffer::~LogicalBuffer() {} string LogicalBuffer::ToString() const { string color_string; if (has_color()) { - color_string = tensorflow::strings::StrCat(" @", color().value()); + color_string = absl::StrCat(" @", color().value()); } - return tensorflow::strings::StrCat(instruction_->name(), "[", - tensorflow::str_util::Join(index_, ","), - "](#", id(), color_string, ")"); + return absl::StrCat(instruction_->name(), "[", absl::StrJoin(index_, ","), + "](#", id(), color_string, ")"); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.h b/tensorflow/compiler/xla/service/multi_output_fusion.h index 6aa639a954..4c8cb7d379 100644 --- a/tensorflow/compiler/xla/service/multi_output_fusion.h +++ b/tensorflow/compiler/xla/service/multi_output_fusion.h @@ -19,10 +19,10 @@ limitations under the License. 
#include <queue> #include <vector> +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" #include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/core/lib/core/stringpiece.h" namespace xla { @@ -48,9 +48,7 @@ class MultiOutputFusion : public HloPassInterface { public: MultiOutputFusion(int64 fuel) : fuel_(fuel) {} - tensorflow::StringPiece name() const override { - return "multi_output_fusion"; - } + absl::string_view name() const override { return "multi_output_fusion"; } // Run multi-output fusion on the given module. Returns whether the module // was changed. diff --git a/tensorflow/compiler/xla/service/name_uniquer.cc b/tensorflow/compiler/xla/service/name_uniquer.cc index f6e7578a89..70cd0a339a 100644 --- a/tensorflow/compiler/xla/service/name_uniquer.cc +++ b/tensorflow/compiler/xla/service/name_uniquer.cc @@ -15,8 +15,9 @@ limitations under the License. #include "tensorflow/compiler/xla/service/name_uniquer.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -52,7 +53,7 @@ NameUniquer::NameUniquer(const string& separator) { return result; } -string NameUniquer::GetUniqueName(tensorflow::StringPiece prefix) { +string NameUniquer::GetUniqueName(absl::string_view prefix) { string root = GetSanitizedName(prefix.empty() ? "name" : std::string(prefix)); // Strip away numeric suffix (if any). 
Only recognize separator if it is in @@ -63,20 +64,22 @@ string NameUniquer::GetUniqueName(tensorflow::StringPiece prefix) { if (separator_index != string::npos && (separator_index > 0) && (separator_index < root.size() - 1)) { string after_suffix = root.substr(separator_index + 1); - if (tensorflow::strings::safe_strto64(after_suffix, &numeric_suffix)) { + if (absl::SimpleAtoi(after_suffix, &numeric_suffix)) { has_numeric_suffix = true; // Remove numeric suffix from root. root = root.substr(0, separator_index); + } else { + // absl::SimpleAtoi may modify numeric_suffix even if it returns false. + numeric_suffix = 0; } } SequentialIdGenerator& id_generator = generated_names_[root]; numeric_suffix = id_generator.RegisterId(numeric_suffix); if (numeric_suffix == 0) { - return has_numeric_suffix ? tensorflow::strings::StrCat(root, separator_, 0) - : root; + return has_numeric_suffix ? absl::StrCat(root, separator_, 0) : root; } - tensorflow::strings::StrAppend(&root, separator_, numeric_suffix); + absl::StrAppend(&root, separator_, numeric_suffix); return root; } diff --git a/tensorflow/compiler/xla/service/name_uniquer.h b/tensorflow/compiler/xla/service/name_uniquer.h index 4423d61069..6dd89c240f 100644 --- a/tensorflow/compiler/xla/service/name_uniquer.h +++ b/tensorflow/compiler/xla/service/name_uniquer.h @@ -18,8 +18,8 @@ limitations under the License. #include <string> +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/macros.h" @@ -38,7 +38,7 @@ class NameUniquer { // Get a sanitized unique name in a string, with an optional prefix for // convenience. - string GetUniqueName(tensorflow::StringPiece prefix = ""); + string GetUniqueName(absl::string_view prefix = ""); // Sanitizes and returns the name. Unallowed characters will be replaced with // '_'. 
The result will match the regexp "[a-zA-Z_][a-zA-Z0-9_.-]*". diff --git a/tensorflow/compiler/xla/service/pattern_matcher.h b/tensorflow/compiler/xla/service/pattern_matcher.h index ac6ea4c72f..ccc06ce613 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher.h +++ b/tensorflow/compiler/xla/service/pattern_matcher.h @@ -16,11 +16,11 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_H_ +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/core/lib/core/stringpiece.h" namespace xla { @@ -622,7 +622,7 @@ template <typename Previous> class HloInstructionPatternNameImpl { public: explicit HloInstructionPatternNameImpl(const Previous& previous, - tensorflow::StringPiece name) + absl::string_view name) : previous_(previous), name_(name) {} bool Match(const ::xla::HloInstruction* inst) const { @@ -631,7 +631,7 @@ class HloInstructionPatternNameImpl { private: Previous previous_; - tensorflow::StringPiece name_; + absl::string_view name_; }; // An HloInstructionPattern implementation that matches only if the instruction @@ -784,7 +784,7 @@ class HloInstructionPattern { // Modifies the pattern to match only if the instruction has the given name. 
HloInstructionPattern<HloInstructionType, HloInstructionPatternNameImpl<Impl>> - WithName(tensorflow::StringPiece name) const { + WithName(absl::string_view name) const { return HloInstructionPattern<HloInstructionType, HloInstructionPatternNameImpl<Impl>>( HloInstructionPatternNameImpl<Impl>(impl_, name), matched_inst_); diff --git a/tensorflow/compiler/xla/service/platform_util.cc b/tensorflow/compiler/xla/service/platform_util.cc index 39fe3c7835..150af0cd93 100644 --- a/tensorflow/compiler/xla/service/platform_util.cc +++ b/tensorflow/compiler/xla/service/platform_util.cc @@ -19,20 +19,19 @@ limitations under the License. #include <string> #include <utility> +#include "absl/strings/ascii.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/threadpool.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace xla { -using tensorflow::str_util::Lowercase; - // Minimum supported CUDA compute capability is 3.5. constexpr int kMinCudaComputeCapabilityMajor = 3; constexpr int kMinCudaComputeCapabilityMinor = 5; @@ -43,7 +42,7 @@ constexpr char kInterpreter[] = "interpreter"; namespace { string CanonicalPlatformName(const string& name) { - string platform_str = Lowercase(name); + string platform_str = absl::AsciiStrToLower(name); // "cpu" and "host" mean the same thing. if (platform_str == "cpu") { platform_str = "host"; @@ -94,7 +93,7 @@ PlatformUtil::GetSupportedPlatforms() { } // Multiple platforms present and we can't pick a reasonable default. 
- string platforms_string = tensorflow::str_util::Join( + string platforms_string = absl::StrJoin( platforms, ", ", [](string* out, const se::Platform* p) { out->append(p->Name()); }); return InvalidArgument( @@ -110,15 +109,15 @@ PlatformUtil::GetSupportedPlatforms() { return platforms[0]; } else if (platforms.size() == 2) { for (int i = 0; i < 2; i++) { - if (Lowercase(platforms[i]->Name()) == kInterpreter && - Lowercase(platforms[1 - i]->Name()) != kInterpreter) { + if (absl::AsciiStrToLower(platforms[i]->Name()) == kInterpreter && + absl::AsciiStrToLower(platforms[1 - i]->Name()) != kInterpreter) { return platforms[1 - i]; } } } // Multiple platforms present and we can't pick a reasonable default. - string platforms_string = tensorflow::str_util::Join( + string platforms_string = absl::StrJoin( platforms, ", ", [](string* out, const se::Platform* p) { out->append(p->Name()); }); return InvalidArgument( @@ -132,7 +131,7 @@ PlatformUtil::GetSupportedPlatforms() { string platform_str = CanonicalPlatformName(platform_name); TF_ASSIGN_OR_RETURN(auto platforms, PlatformUtil::GetSupportedPlatforms()); for (se::Platform* platform : platforms) { - if (Lowercase(platform->Name()) == platform_str) { + if (absl::AsciiStrToLower(platform->Name()) == platform_str) { return platform; } } @@ -146,7 +145,7 @@ PlatformUtil::GetSupportedPlatforms() { TF_ASSIGN_OR_RETURN(auto platforms, PlatformUtil::GetSupportedPlatforms()); std::vector<se::Platform*> matched; for (se::Platform* platform : platforms) { - if (Lowercase(platform->Name()) != platform_name) { + if (absl::AsciiStrToLower(platform->Name()) != platform_name) { matched.push_back(platform); } } @@ -157,7 +156,7 @@ PlatformUtil::GetSupportedPlatforms() { if (matched.size() == 1) { return matched[0]; } - string matched_string = tensorflow::str_util::Join( + string matched_string = absl::StrJoin( matched, ", ", [](string* out, const se::Platform* p) { out->append(p->Name()); }); return InvalidArgument( diff --git 
a/tensorflow/compiler/xla/service/reduce_precision_insertion.h b/tensorflow/compiler/xla/service/reduce_precision_insertion.h index afde3cf95c..256b231e3a 100644 --- a/tensorflow/compiler/xla/service/reduce_precision_insertion.h +++ b/tensorflow/compiler/xla/service/reduce_precision_insertion.h @@ -59,7 +59,7 @@ class ReducePrecisionInsertion : public HloPassInterface { ~ReducePrecisionInsertion() override{}; - tensorflow::StringPiece name() const override { + absl::string_view name() const override { return "reduce-precision-insertion"; } diff --git a/tensorflow/compiler/xla/service/reshape_mover.h b/tensorflow/compiler/xla/service/reshape_mover.h index 1f59e3b314..1e86a0823a 100644 --- a/tensorflow/compiler/xla/service/reshape_mover.h +++ b/tensorflow/compiler/xla/service/reshape_mover.h @@ -26,7 +26,7 @@ namespace xla { // them inputward also. class ReshapeMover : public HloPassInterface { public: - tensorflow::StringPiece name() const override { return "reshape-mover"; } + absl::string_view name() const override { return "reshape-mover"; } StatusOr<bool> Run(HloModule* module) override; }; diff --git a/tensorflow/compiler/xla/service/reshape_mover_test.cc b/tensorflow/compiler/xla/service/reshape_mover_test.cc index 7534a3f7e3..a395dd5333 100644 --- a/tensorflow/compiler/xla/service/reshape_mover_test.cc +++ b/tensorflow/compiler/xla/service/reshape_mover_test.cc @@ -28,13 +28,18 @@ limitations under the License. 
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/strings/str_util.h" - -namespace op = xla::testing::opcode_matchers; namespace xla { namespace { -using ReshapeMoverTest = HloVerifiedTestBase; + +namespace op = xla::testing::opcode_matchers; + +class ReshapeMoverTest : public HloVerifiedTestBase { + public: + ReshapeMoverTest() + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false) {} +}; TEST_F(ReshapeMoverTest, ReshapesWithDifferentInputShapesNotMoved) { HloComputation::Builder builder(TestName()); diff --git a/tensorflow/compiler/xla/service/scatter_expander.h b/tensorflow/compiler/xla/service/scatter_expander.h index 8f735e877d..14f062c89c 100644 --- a/tensorflow/compiler/xla/service/scatter_expander.h +++ b/tensorflow/compiler/xla/service/scatter_expander.h @@ -22,7 +22,7 @@ namespace xla { class ScatterExpander : public HloPassInterface { public: - tensorflow::StringPiece name() const override { return "scatter_expander"; } + absl::string_view name() const override { return "scatter_expander"; } StatusOr<bool> Run(HloModule* module) override; private: diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index 18d1b7732b..d39a5191b8 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -21,6 +21,7 @@ limitations under the License. #include <vector> #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" @@ -46,7 +47,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/cleanup.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" @@ -55,8 +55,8 @@ limitations under the License. #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/ptr_util.h" +using absl::StrCat; using ::tensorflow::strings::Printf; -using ::tensorflow::strings::StrCat; namespace xla { diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index ec6aa6df55..6a22f8bef4 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -22,6 +22,9 @@ limitations under the License. #include <string> #include "absl/algorithm/container.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" @@ -29,28 +32,24 @@ limitations under the License. 
#include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/math/math_util.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" -using tensorflow::str_util::Join; -using tensorflow::strings::Printf; - namespace xla { - namespace { +using absl::StrJoin; +using tensorflow::strings::Printf; + // Returns true if no element is present in slice more than once. bool AllUnique(tensorflow::gtl::ArraySlice<int64> slice) { return std::set<int64>(slice.begin(), slice.end()).size() == slice.size(); } -Status ExpectArray(const Shape& shape, tensorflow::StringPiece op_type) { +Status ExpectArray(const Shape& shape, absl::string_view op_type) { if (!ShapeUtil::IsArray(shape)) { return InvalidArgument("Expected array argument for %s, but got %s.", std::string(op_type).c_str(), @@ -234,10 +233,12 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape, switch (opcode) { case HloOpcode::kFloor: case HloOpcode::kCeil: + case HloOpcode::kRoundNearestAfz: if (!ShapeUtil::ElementIsFloating(shape)) { return InvalidArgument( - "Expected element type in shape to be floating for floor/ceil " - "operation; got %s.", + "Expected element type in shape to be floating for %s operation; " + "got %s.", + HloOpcodeString(opcode).c_str(), PrimitiveType_Name(shape.element_type()).c_str()); } return shape; @@ -251,8 +252,9 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape, if (!ShapeUtil::ElementIsFloating(shape) && !ShapeUtil::ElementIsComplex(shape)) { return InvalidArgument( - "Expected element type in shape to be floating or complex for " - "sin/cos/exp/log/tanh operation; 
got %s.", + "Expected element type in shape to be floating or complex for %s " + "operation; got %s.", + HloOpcodeString(opcode).c_str(), PrimitiveType_Name(shape.element_type()).c_str()); } return shape; @@ -265,19 +267,51 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape, } else { return InvalidArgument( "Expected element type in shape to be floating or complex for " - "real/imag operation; got %s.", + "%s operation; got %s.", + HloOpcodeString(opcode).c_str(), PrimitiveType_Name(shape.element_type()).c_str()); } case HloOpcode::kAbs: if (ShapeUtil::ElementIsComplex(shape)) { return ShapeUtil::ChangeElementType( shape, primitive_util::ComplexComponentType(shape.element_type())); + } else if (ShapeUtil::ElementIsSigned(shape)) { + return shape; + } else { + return InvalidArgument( + "Expected element type in shape to be floating or complex for " + "%s operation; got %s.", + HloOpcodeString(opcode).c_str(), + PrimitiveType_Name(shape.element_type()).c_str()); } - return shape; case HloOpcode::kClz: + if (!ShapeUtil::ElementIsIntegral(shape)) { + return InvalidArgument( + "Expected an integral element type in argument to Clz " + "operation; got %s.", + PrimitiveType_Name(shape.element_type()).c_str()); + } + return shape; case HloOpcode::kNegate: - case HloOpcode::kRoundNearestAfz: + if (!ShapeUtil::ElementIsIntegral(shape) && + !ShapeUtil::ElementIsFloating(shape) && + !ShapeUtil::ElementIsComplex(shape)) { + return InvalidArgument( + "Expected element type in shape to be integral, floating or " + "complex for %s operation; got %s.", + HloOpcodeString(opcode).c_str(), + PrimitiveType_Name(shape.element_type()).c_str()); + } + return shape; case HloOpcode::kSign: + if (!ShapeUtil::ElementIsSigned(shape) && + !ShapeUtil::ElementIsComplex(shape)) { + return InvalidArgument( + "Expected element type in shape to be signed or complex for " + "%s operation; got %s.", + HloOpcodeString(opcode).c_str(), + PrimitiveType_Name(shape.element_type()).c_str()); + 
} return shape; case HloOpcode::kNot: @@ -879,16 +913,14 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, "inferring shape for <%s>(%s, %s) with broadcast_dimensions={%s}", HloOpcodeString(opcode).c_str(), ShapeUtil::HumanString(lhs).c_str(), ShapeUtil::HumanString(rhs).c_str(), - Join(broadcast_dimensions, ", ").c_str()); + StrJoin(broadcast_dimensions, ", ").c_str()); TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs)); TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs)); - TF_RETURN_IF_ERROR( - ExpectArray(lhs, tensorflow::strings::StrCat("lhs of binary operation ", - HloOpcodeString(opcode)))); - TF_RETURN_IF_ERROR( - ExpectArray(rhs, tensorflow::strings::StrCat("rhs of binary operation ", - HloOpcodeString(opcode)))); + TF_RETURN_IF_ERROR(ExpectArray( + lhs, absl::StrCat("lhs of binary operation ", HloOpcodeString(opcode)))); + TF_RETURN_IF_ERROR(ExpectArray( + rhs, absl::StrCat("rhs of binary operation ", HloOpcodeString(opcode)))); switch (opcode) { case HloOpcode::kMaximum: case HloOpcode::kMinimum: @@ -1059,7 +1091,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "Map operation requires all operands to have the same shape; got: " "%s.", - Join(pieces, ", ").c_str()); + StrJoin(pieces, ", ").c_str()); } // Check that dimensions.size == arg_shape.dimensions_size() (we currently @@ -1076,7 +1108,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (dimensions[i] != i) { return InvalidArgument( "Map requires monotonically increasing dimension numbers; got: %s.", - Join(dimensions, ", ").c_str()); + StrJoin(dimensions, ", ").c_str()); } } @@ -1977,14 +2009,14 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, "%s in slice operation; argument shape: %s; starts: {%s}; limits: " "{%s}; strides: {%s}.", message.c_str(), ShapeUtil::HumanString(arg).c_str(), - Join(starts, ",").c_str(), Join(limits, 
",").c_str(), - Join(strides, ",").c_str()); + StrJoin(starts, ",").c_str(), StrJoin(limits, ",").c_str(), + StrJoin(strides, ",").c_str()); }; TF_RETURN_IF_ERROR(ExpectArray(arg, "operand of slice")); VLOG(2) << tensorflow::strings::Printf( "slicing shape %s starts={%s} limits={%s}", - ShapeUtil::HumanString(arg).c_str(), Join(starts, ", ").c_str(), - Join(limits, ", ").c_str()); + ShapeUtil::HumanString(arg).c_str(), StrJoin(starts, ", ").c_str(), + StrJoin(limits, ", ").c_str()); if (starts.size() != limits.size()) { return error(Printf("slice start and limit sizes differ: %zu vs %zu", @@ -2047,7 +2079,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, "slicing shape %s at dynamic start_indices %s with slice_sizes={%s}", ShapeUtil::HumanString(operand_shape).c_str(), ShapeUtil::HumanString(start_indices_shape).c_str(), - Join(slice_sizes, ", ").c_str()); + StrJoin(slice_sizes, ", ").c_str()); if (ShapeUtil::Rank(start_indices_shape) != 1) { return InvalidArgument( @@ -2344,7 +2376,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "Reshape dimensions [%s] are not a permutation of the operand " "dimensions (operand shape is %s).", - Join(dimensions, ",").c_str(), ShapeUtil::HumanString(operand).c_str()); + StrJoin(dimensions, ",").c_str(), + ShapeUtil::HumanString(operand).c_str()); } return inferred_shape; @@ -2464,8 +2497,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (arg_shapes.size() != to_apply.parameters_size()) { string computation_signature = ShapeUtil::HumanString(to_apply); string argument_shapes = - Join(arg_shapes, ", ", [](string* out, const Shape* shape) { - tensorflow::strings::StrAppend(out, ShapeUtil::HumanString(*shape)); + StrJoin(arg_shapes, ", ", [](string* out, const Shape* shape) { + absl::StrAppend(out, ShapeUtil::HumanString(*shape)); }); return InvalidArgument( "Call applied function arity must match number of 
arguments; got: " @@ -2498,14 +2531,14 @@ static Status ValidateGatherDimensionNumbers( if (!absl::c_is_sorted(dim_numbers.offset_dims())) { return InvalidArgument( "Output window dimensions in gather op must be ascending; got: %s.", - Join(dim_numbers.offset_dims(), ", ").c_str()); + StrJoin(dim_numbers.offset_dims(), ", ").c_str()); } if (absl::c_adjacent_find(dim_numbers.offset_dims()) != dim_numbers.offset_dims().end()) { return InvalidArgument( "Output window dimensions in gather op must not repeat; got: %s.", - Join(dim_numbers.offset_dims(), ", ").c_str()); + StrJoin(dim_numbers.offset_dims(), ", ").c_str()); } const int64 output_offset_dim_count = dim_numbers.offset_dims_size(); @@ -2554,7 +2587,7 @@ static Status ValidateGatherDimensionNumbers( return InvalidArgument( "Repeated dimensions are not allowed in start_index_map; " "got: %s.", - Join(dim_numbers.start_index_map(), ", ").c_str()); + StrJoin(dim_numbers.start_index_map(), ", ").c_str()); } for (int64 collapsed_dim : dim_numbers.collapsed_slice_dims()) { @@ -2569,7 +2602,7 @@ static Status ValidateGatherDimensionNumbers( if (!absl::c_is_sorted(dim_numbers.collapsed_slice_dims())) { return InvalidArgument( "collapsed_slice_dims in gather op must be sorted; got: %s", - Join(dim_numbers.collapsed_slice_dims(), ", ").c_str()); + StrJoin(dim_numbers.collapsed_slice_dims(), ", ").c_str()); } if (absl::c_adjacent_find(dim_numbers.collapsed_slice_dims()) != @@ -2577,7 +2610,7 @@ static Status ValidateGatherDimensionNumbers( return InvalidArgument( "Repeated dimensions not allowed in collapsed_slice_dims in gather op; " "got: %s.", - Join(dim_numbers.collapsed_slice_dims(), ", ").c_str()); + StrJoin(dim_numbers.collapsed_slice_dims(), ", ").c_str()); } return Status::OK(); @@ -2639,8 +2672,9 @@ static Status ValidateGatherDimensionNumbers( "All components of the offset index in a gather op must either be a " "offset dimension or explicitly collapsed; got len(slice_sizes)=%lu, " "output_slice_sizes=%s, 
collapsed_slice_dims=%s.", - slice_sizes.size(), Join(gather_dim_numbers.offset_dims(), ",").c_str(), - Join(gather_dim_numbers.collapsed_slice_dims(), ",").c_str()); + slice_sizes.size(), + StrJoin(gather_dim_numbers.offset_dims(), ",").c_str(), + StrJoin(gather_dim_numbers.collapsed_slice_dims(), ",").c_str()); } for (int i = 0; i < slice_sizes.size(); i++) { @@ -2703,13 +2737,13 @@ Status ValidateScatterDimensionNumbers( if (!absl::c_is_sorted(dim_numbers.update_window_dims())) { return InvalidArgument( "update_window_dims in scatter op must be sorted; got: %s.", - Join(dim_numbers.update_window_dims(), ", ").c_str()); + StrJoin(dim_numbers.update_window_dims(), ", ").c_str()); } if (absl::c_adjacent_find(dim_numbers.update_window_dims()) != dim_numbers.update_window_dims().end()) { return InvalidArgument( "update_window_dims in scatter op must not repeat; got: %s.", - Join(dim_numbers.update_window_dims(), ", ").c_str()); + StrJoin(dim_numbers.update_window_dims(), ", ").c_str()); } const int64 updates_rank = ShapeUtil::Rank(updates_shape); for (int64 window_dim : dim_numbers.update_window_dims()) { @@ -2725,13 +2759,13 @@ Status ValidateScatterDimensionNumbers( if (!absl::c_is_sorted(dim_numbers.inserted_window_dims())) { return InvalidArgument( "inserted_window_dims in scatter op must be sorted; got: %s.", - Join(dim_numbers.inserted_window_dims(), ", ").c_str()); + StrJoin(dim_numbers.inserted_window_dims(), ", ").c_str()); } if (absl::c_adjacent_find(dim_numbers.inserted_window_dims()) != dim_numbers.inserted_window_dims().end()) { return InvalidArgument( "inserted_window_dims in scatter op must not repeat; got: %s.", - Join(dim_numbers.inserted_window_dims(), ", ").c_str()); + StrJoin(dim_numbers.inserted_window_dims(), ", ").c_str()); } for (int64 inserted_dim : dim_numbers.inserted_window_dims()) { if (inserted_dim < 0 || inserted_dim >= operand_shape.dimensions_size()) { @@ -2773,7 +2807,7 @@ Status ValidateScatterDimensionNumbers( return 
InvalidArgument( "Repeated dimensions not allowed in scatter_dims_to_operand_dims; " "got: %s.", - Join(dim_numbers.scatter_dims_to_operand_dims(), ", ").c_str()); + StrJoin(dim_numbers.scatter_dims_to_operand_dims(), ", ").c_str()); } return Status::OK(); diff --git a/tensorflow/compiler/xla/service/shaped_buffer.cc b/tensorflow/compiler/xla/service/shaped_buffer.cc index 70714ffff0..5c12dc37b7 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer.cc +++ b/tensorflow/compiler/xla/service/shaped_buffer.cc @@ -19,6 +19,7 @@ limitations under the License. #include <utility> #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -76,7 +77,7 @@ void ShapedBuffer::clear() { } string ShapedBuffer::ToString() const { - string s = tensorflow::strings::StrCat( + string s = absl::StrCat( "ShapedBuffer(", platform_->Name(), ":", device_ordinal(), "), on-host shape=" + ShapeUtil::HumanStringWithLayout(on_host_shape()), ", on-device shape=" + diff --git a/tensorflow/compiler/xla/service/transfer_manager.cc b/tensorflow/compiler/xla/service/transfer_manager.cc index e0f995fd0d..0c577ec67a 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.cc +++ b/tensorflow/compiler/xla/service/transfer_manager.cc @@ -19,6 +19,7 @@ limitations under the License. #include <utility> #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" @@ -28,7 +29,7 @@ limitations under the License. 
#include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/notification.h" -using ::tensorflow::strings::StrCat; +using absl::StrCat; namespace xla { /* static */ tensorflow::mutex diff --git a/tensorflow/compiler/xla/service/transpose_folding.h b/tensorflow/compiler/xla/service/transpose_folding.h index 71e8446452..3e5aa2db60 100644 --- a/tensorflow/compiler/xla/service/transpose_folding.h +++ b/tensorflow/compiler/xla/service/transpose_folding.h @@ -49,7 +49,7 @@ class TransposeFolding : public HloPassInterface { explicit TransposeFolding( TransposableGemmOperandsFn transposable_gemm_operands, TransposableConvOperandsFn transposable_conv_operands); - tensorflow::StringPiece name() const override { return "transpose-folding"; } + absl::string_view name() const override { return "transpose-folding"; } StatusOr<bool> Run(HloModule* module) override; diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc index 0c2f2112af..cb07b8d4d3 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc @@ -20,6 +20,8 @@ limitations under the License. #include <vector> #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -27,17 +29,14 @@ limitations under the License. 
#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" namespace xla { string BufferAlias::ToString() const { - return tensorflow::strings::StrCat("BufferAlias(", instruction_->name(), "[", - tensorflow::str_util::Join(index_, ","), - "])"); + return absl::StrCat("BufferAlias(", instruction_->name(), "[", + absl::StrJoin(index_, ","), "])"); } std::ostream& operator<<(std::ostream& out, const BufferAlias& buffer_alias) { @@ -496,8 +495,7 @@ StatusOr<const LogicalBuffer*> TuplePointsToAnalysis::GetBufferDefinedAt( if (buffers.size() != 1 || buffers[0]->instruction() != instruction) { return FailedPrecondition( "instruction %s does not define buffer at index {%s}", - instruction->name().c_str(), - tensorflow::str_util::Join(index, ",").c_str()); + instruction->name().c_str(), absl::StrJoin(index, ",").c_str()); } return buffers[0]; } @@ -563,8 +561,7 @@ string TuplePointsToAnalysis::ToString() const { for (const auto* computation : module_->MakeNonfusionComputations()) { const char* entry = computation == module_->entry_computation() ? 
"entry " : ""; - tensorflow::strings::StrAppend(&output, entry, "computation ", - computation->name(), ":\n"); + absl::StrAppend(&output, entry, "computation ", computation->name(), ":\n"); for (const HloInstruction* instruction : computation->MakeInstructionPostOrder()) { InstructionToString(instruction, &output); @@ -576,12 +573,11 @@ string TuplePointsToAnalysis::ToString() const { } } - tensorflow::strings::StrAppend(&output, "LogicalBuffers:\n"); + absl::StrAppend(&output, "LogicalBuffers:\n"); for (const auto& b : logical_buffer_analysis_->logical_buffers()) { - tensorflow::strings::StrAppend(&output, " buffer ", b->ToString(), ":\n"); + absl::StrAppend(&output, " buffer ", b->ToString(), ":\n"); for (const BufferAlias& alias : logical_buffer_aliases_.at(b->id())) { - tensorflow::strings::StrAppend(&output, " alias ", alias.ToString(), - "\n"); + absl::StrAppend(&output, " alias ", alias.ToString(), "\n"); } } return output; @@ -590,20 +586,18 @@ string TuplePointsToAnalysis::ToString() const { void TuplePointsToAnalysis::InstructionToString( const HloInstruction* instruction, string* output) const { const string prefix = instruction->IsFused() ? 
" " : ""; - tensorflow::strings::StrAppend(output, prefix, " instruction ", - instruction->ToShortString(), ":\n"); + absl::StrAppend(output, prefix, " instruction ", + instruction->ToShortString(), ":\n"); const PointsToSet& points_to_set = GetPointsToSet(instruction); points_to_set.ForEachElement([&prefix, &output]( const ShapeIndex& index, const PointsToSet::BufferList& points_to) { - tensorflow::strings::StrAppend( - output, prefix, " {", tensorflow::str_util::Join(index, ","), "}: ", - tensorflow::str_util::Join( - points_to, ", ", - [](string* out, const LogicalBuffer* source) { - out->append(source->ToString()); - }), - "\n"); + absl::StrAppend(output, prefix, " {", absl::StrJoin(index, ","), "}: ", + absl::StrJoin(points_to, ", ", + [](string* out, const LogicalBuffer* source) { + out->append(source->ToString()); + }), + "\n"); }); } diff --git a/tensorflow/compiler/xla/service/tuple_simplifier.h b/tensorflow/compiler/xla/service/tuple_simplifier.h index 7509501883..8c91d6e69d 100644 --- a/tensorflow/compiler/xla/service/tuple_simplifier.h +++ b/tensorflow/compiler/xla/service/tuple_simplifier.h @@ -30,7 +30,7 @@ class TupleSimplifier : public HloPassInterface { TupleSimplifier() : TupleSimplifier(/*exclude_entry_computation=*/false) {} explicit TupleSimplifier(bool exclude_entry_computation); ~TupleSimplifier() override {} - tensorflow::StringPiece name() const override { return "tuple-simplifier"; } + absl::string_view name() const override { return "tuple-simplifier"; } // Run tuple simplification on the given computation. Returns whether the // computation was changed. 
diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking.h b/tensorflow/compiler/xla/service/while_loop_constant_sinking.h index 21fb8568a8..2dba7d7f75 100644 --- a/tensorflow/compiler/xla/service/while_loop_constant_sinking.h +++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking.h @@ -54,7 +54,7 @@ class WhileLoopConstantSinking : public HloPassInterface { public: ~WhileLoopConstantSinking() override = default; - tensorflow::StringPiece name() const override { + absl::string_view name() const override { return "while-loop-invariant-code-motion"; } diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h index 8e6cc87875..2cdf20ce80 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h @@ -38,7 +38,7 @@ class WhileLoopInvariantCodeMotion : public HloPassInterface { : hoist_constants_(hoist_constants) {} ~WhileLoopInvariantCodeMotion() override = default; - tensorflow::StringPiece name() const override { + absl::string_view name() const override { return "while-loop-invariant-code-motion"; } StatusOr<bool> Run(HloModule* module) override; diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc index 32e69c335b..e14014b961 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc @@ -28,6 +28,10 @@ namespace op = xla::testing::opcode_matchers; class WhileLoopInvariantCodeMotionTest : public HloVerifiedTestBase { public: + WhileLoopInvariantCodeMotionTest() + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false) {} + // Makes a computation which has one parameter, of the given shape, and always // returns 
PRED[]{true}. This is useful as a dummy loop condition. HloComputation* MakeAlwaysTrueComputation(const Shape& param_shape, diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.cc b/tensorflow/compiler/xla/service/while_loop_simplifier.cc index a24e2b0116..6a7bfe3f12 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier.cc @@ -14,12 +14,12 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/while_loop_simplifier.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "absl/types/optional.h" #include "tensorflow/compiler/xla/service/call_inliner.h" #include "tensorflow/compiler/xla/service/while_loop_analysis.h" #include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace xla { @@ -236,12 +236,11 @@ static StatusOr<bool> TryRemoveDeadWhileParams(HloInstruction* while_op) { << "Instruction " << user->ToString(print_no_metadata) << " should be unused (except by root of while body), but has " "users: {" - << tensorflow::str_util::Join( - user->users(), ", ", - [&](string* out, const HloInstruction* instr) { - tensorflow::strings::StrAppend( - out, instr->ToString(print_no_metadata)); - }) + << absl::StrJoin(user->users(), ", ", + [&](string* out, const HloInstruction* instr) { + absl::StrAppend( + out, instr->ToString(print_no_metadata)); + }) << "}"; replacements.emplace(user, nullptr); diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.h b/tensorflow/compiler/xla/service/while_loop_simplifier.h index 3d3e1d60f2..78024f14dc 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier.h +++ b/tensorflow/compiler/xla/service/while_loop_simplifier.h @@ -33,9 +33,7 @@ namespace xla { class WhileLoopSimplifier : public HloPassInterface { 
public: ~WhileLoopSimplifier() override {} - tensorflow::StringPiece name() const override { - return "simplify-while-loops"; - } + absl::string_view name() const override { return "simplify-while-loops"; } StatusOr<bool> Run(HloModule* module) override; }; diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc index 2e1571943e..cfe4104f6d 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc @@ -15,11 +15,12 @@ limitations under the License. #include "tensorflow/compiler/xla/service/while_loop_simplifier.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_replace.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/strings/str_util.h" namespace xla { namespace { @@ -27,6 +28,11 @@ namespace { namespace op = xla::testing::opcode_matchers; class WhileLoopSimplifierTest : public HloVerifiedTestBase { + public: + WhileLoopSimplifierTest() + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false) {} + protected: // Makes an HloModule that contains a loop with `num_iters` iteration. 
void MakeModuleWithSimpleLoop(int num_iters); @@ -64,10 +70,8 @@ void WhileLoopSimplifierTest::MakeModuleWithSimpleLoop(int num_iters) { } )"; - string hlo_string = tensorflow::str_util::StringReplace( - hlo_string_template, "{{LOOP_BOUND}}", - tensorflow::strings::StrCat(42 + num_iters), - /*replace_all=*/true); + string hlo_string = absl::StrReplaceAll( + hlo_string_template, {{"{{LOOP_BOUND}}", absl::StrCat(42 + num_iters)}}); ParseAndVerifyModule(hlo_string); } @@ -103,10 +107,8 @@ void WhileLoopSimplifierTest::MakeModuleWithSimpleLoopTupleElementLoopBound( } )"; - string hlo_string = tensorflow::str_util::StringReplace( - hlo_string_template, "{{LOOP_BOUND}}", - tensorflow::strings::StrCat(42 + num_iters), - /*replace_all=*/true); + string hlo_string = absl::StrReplaceAll( + hlo_string_template, {{"{{LOOP_BOUND}}", absl::StrCat(42 + num_iters)}}); ParseAndVerifyModule(hlo_string); } diff --git a/tensorflow/compiler/xla/service/while_util.cc b/tensorflow/compiler/xla/service/while_util.cc index 52d9c3e5ae..e8f76ff745 100644 --- a/tensorflow/compiler/xla/service/while_util.cc +++ b/tensorflow/compiler/xla/service/while_util.cc @@ -15,15 +15,15 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/while_util.h" #include "absl/algorithm/container.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_creation_utils.h" #include "tensorflow/compiler/xla/service/tuple_util.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace xla { -using tensorflow::strings::StrCat; +using absl::StrCat; static StatusOr<HloComputation*> WidenWhileCondition( HloComputation* narrow_condition, const Shape& wide_shape) { diff --git a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h index 8763e588c4..a7f0e207eb 100644 --- a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h +++ b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h @@ -24,7 +24,7 @@ namespace xla { class ZeroSizedHloElimination : public HloPassInterface { public: StatusOr<bool> Run(HloModule* module) override; - tensorflow::StringPiece name() const override { + absl::string_view name() const override { return "zero_sized_hlo_elimination"; } }; diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index 7244be80d9..31ddd57eef 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -22,6 +22,13 @@ limitations under the License. #include <utility> #include <vector> +#include "absl/strings/ascii.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "absl/strings/strip.h" #include "absl/types/optional.h" #include "tensorflow/compiler/xla/index_util.h" #include "tensorflow/compiler/xla/layout_util.h" @@ -31,25 +38,22 @@ limitations under the License. 
#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/iterator_range.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/strings/numbers.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/regexp.h" namespace xla { -using ::tensorflow::strings::StrAppend; -using ::tensorflow::strings::StrCat; +using absl::StrAppend; +using absl::StrCat; string ShapeIndex::ToString() const { return ShapeIndexView(*this).ToString(); } string ShapeIndexView::ToString() const { - return StrCat("{", tensorflow::str_util::Join(indices_, ","), "}"); + return StrCat("{", absl::StrJoin(indices_, ","), "}"); } bool ShapeIndexView::operator==(const ShapeIndexView& other) const { @@ -449,14 +453,14 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( namespace { // Class to memoize the computation of -// tensorflow::str_util::Lowercase(PrimitiveType_Name(p)) +// absl::AsciiStrToLower(PrimitiveType_Name(p)) // for all PrimitiveType values "p" class PrimitiveTypeNameGenerator { public: PrimitiveTypeNameGenerator() { for (int i = 0; i < PrimitiveType_ARRAYSIZE; i++) { if (PrimitiveType_IsValid(i)) { - lowercase_name_[i] = tensorflow::str_util::Lowercase( + lowercase_name_[i] = absl::AsciiStrToLower( PrimitiveType_Name(static_cast<PrimitiveType>(i))); } } @@ -507,7 +511,7 @@ StatusOr<PrimitiveType> StringToPrimitiveType(const string& name) { return text; } return StrCat(LowercasePrimitiveTypeName(shape.element_type()), "[", - tensorflow::str_util::Join(shape.dimensions(), ","), "]"); + absl::StrJoin(shape.dimensions(), ","), "]"); } /* static */ string ShapeUtil::HumanStringWithLayout(const Shape& shape) { @@ -543,30 +547,30 @@ 
StatusOr<PrimitiveType> StringToPrimitiveType(const string& name) { : "(unknown)", ": ", HumanString(shape))); } - return StrCat("(", tensorflow::str_util::Join(parameters, ", "), ") -> ", + return StrCat("(", absl::StrJoin(parameters, ", "), ") -> ", HumanString(program_shape.result())); } namespace { // Parses shapes with simple recursive descent structure -- consumes from the // front of s and passes that view recursively as required. -StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) { - tensorflow::str_util::RemoveLeadingWhitespace(s); +StatusOr<Shape> ParseShapeStringInternal(absl::string_view* s) { + *s = StripLeadingAsciiWhitespace(*s); - if (tensorflow::str_util::ConsumePrefix(s, "(")) { // Tuple. + if (absl::ConsumePrefix(s, "(")) { // Tuple. std::vector<Shape> shapes; bool must_end = false; while (true) { - if (tensorflow::str_util::ConsumePrefix(s, ")")) { + if (absl::ConsumePrefix(s, ")")) { break; } else if (must_end) { return InvalidArgument("Expected end of tuple; got: \"%s\"", - std::string(*s).c_str()); + string(*s).c_str()); } shapes.emplace_back(); TF_ASSIGN_OR_RETURN(shapes.back(), ParseShapeStringInternal(s)); - tensorflow::str_util::RemoveLeadingWhitespace(s); - must_end = !tensorflow::str_util::ConsumePrefix(s, ","); + *s = StripLeadingAsciiWhitespace(*s); + must_end = !absl::ConsumePrefix(s, ","); } return ShapeUtil::MakeTupleShape(shapes); } @@ -575,9 +579,9 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) { string dimensions_string; string format_string; string layout_string; - // tensorflow::StringPiece is not compatible with internal RE2 StringPiece, so + // absl::string_view is not compatible with internal RE2 StringPiece, so // we convert in to the RE2-consumable type and then consume the corresponding - // amount from our StringPiece type. + // amount from our string_view type. 
static LazyRE2 shape_pattern = { "^(\\w*\\d*)\\[([\\d,]*)\\](?:\\s*(dense|sparse)?\\s*{([\\d,]+)})?"}; tensorflow::RegexpStringPiece s_consumable(s->data(), s->size()); @@ -585,12 +589,12 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) { &dimensions_string, &format_string, &layout_string)) { size_t consumed = s->size() - s_consumable.size(); s->remove_prefix(consumed); - auto string_to_int64 = [&s](const string& input) -> StatusOr<int64> { + auto string_to_int64 = [&s](absl::string_view input) -> StatusOr<int64> { int64 element; - if (!tensorflow::strings::safe_strto64(input.c_str(), &element)) { + if (!absl::SimpleAtoi(input, &element)) { return InvalidArgument( "Invalid s64 value in parsed shape string: \"%s\" in \"%s\"", - input.c_str(), std::string(*s).c_str()); + string(input).c_str(), string(*s).c_str()); } return element; }; @@ -598,7 +602,7 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) { auto comma_list_to_int64s = [string_to_int64](const string& input) -> StatusOr<std::vector<int64>> { std::vector<int64> results; - for (const string& piece : tensorflow::str_util::Split(input, ',')) { + for (const auto& piece : absl::StrSplit(input, ',', absl::SkipEmpty())) { TF_ASSIGN_OR_RETURN(int64 element, string_to_int64(piece)); results.push_back(element); } @@ -645,16 +649,15 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) { } return InvalidArgument("Invalid shape string to parse: \"%s\"", - std::string(*s).c_str()); + string(*s).c_str()); } } // namespace -/* static */ StatusOr<Shape> ShapeUtil::ParseShapeString( - tensorflow::StringPiece s) { +/* static */ StatusOr<Shape> ShapeUtil::ParseShapeString(absl::string_view s) { TF_ASSIGN_OR_RETURN(Shape shape, ParseShapeStringInternal(&s)); if (!s.empty()) { return InvalidArgument("Invalid shape string to parse: \"%s\"", - std::string(s).c_str()); + string(s).c_str()); } return shape; } @@ -1172,8 +1175,7 @@ Status ForEachMutableSubshapeHelper( 
CHECK(TransposeIsBitcast(shape, new_shape, InversePermutation(permutation))) << "shape=" << HumanStringWithLayout(shape) << ", new_shape=" << HumanStringWithLayout(new_shape) - << ", permutation={" << tensorflow::str_util::Join(permutation, ",") - << "}"; + << ", permutation={" << absl::StrJoin(permutation, ",") << "}"; } return new_shape; } diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index cb72fbbb0e..84f36e48a0 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -228,7 +228,7 @@ class ShapeUtil { // Parses a ShapeUtil::HumanString-format shape string back into a shape // object. - static StatusOr<Shape> ParseShapeString(tensorflow::StringPiece s); + static StatusOr<Shape> ParseShapeString(absl::string_view s); // Returns whether the LHS and RHS shapes have the same dimensions; note: does // not check element type. diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc index e5dd62ae9a..7549ba9c78 100644 --- a/tensorflow/compiler/xla/shape_util_test.cc +++ b/tensorflow/compiler/xla/shape_util_test.cc @@ -16,6 +16,8 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include <numeric> +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/test.h" @@ -23,8 +25,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace xla { namespace { @@ -849,13 +849,13 @@ TEST(ShapeUtilTest, PermuteDimensionsLayout) { std::iota(layout.begin(), layout.end(), 0); do { Shape s = ShapeUtil::MakeShapeWithLayout(F32, {10, 100, 1000}, layout); - SCOPED_TRACE(tensorflow::strings::StrCat("s=", ShapeUtil::HumanString(s))); + SCOPED_TRACE(absl::StrCat("s=", ShapeUtil::HumanString(s))); std::vector<int64> permutation(3); std::iota(permutation.begin(), permutation.end(), 0); do { - SCOPED_TRACE(tensorflow::strings::StrCat( - "permutation=", tensorflow::str_util::Join(permutation, ","))); + SCOPED_TRACE( + absl::StrCat("permutation=", absl::StrJoin(permutation, ","))); // TransposeIsBitcast takes the inverse of the permutation that // PermuteDimensions takes. diff --git a/tensorflow/compiler/xla/status_macros.cc b/tensorflow/compiler/xla/status_macros.cc index a6b1f9004f..b88fe367d7 100644 --- a/tensorflow/compiler/xla/status_macros.cc +++ b/tensorflow/compiler/xla/status_macros.cc @@ -17,9 +17,8 @@ limitations under the License. 
#include <algorithm> +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stacktrace.h" @@ -37,8 +36,7 @@ static void LogError(const Status& status, const char* filename, int line, if (TF_PREDICT_TRUE(log_severity != tensorflow::NUM_SEVERITIES)) { string stack_trace; if (should_log_stack_trace) { - stack_trace = - tensorflow::strings::StrCat("\n", tensorflow::CurrentStackTrace()); + stack_trace = absl::StrCat("\n", tensorflow::CurrentStackTrace()); } switch (log_severity) { case tensorflow::INFO: @@ -142,17 +140,15 @@ Status MakeErrorStream::Impl::GetStatus() { is_done_ = true; const string& stream_str = stream_.str(); - const string str = - prior_message_handling_ == kAppendToPriorMessage - ? tensorflow::strings::StrCat(prior_message_, stream_str) - : tensorflow::strings::StrCat(stream_str, prior_message_); + const string str = prior_message_handling_ == kAppendToPriorMessage + ? 
absl::StrCat(prior_message_, stream_str) + : absl::StrCat(stream_str, prior_message_); if (TF_PREDICT_FALSE(str.empty())) { - return MakeError(file_, line_, code_, - tensorflow::strings::StrCat( - str, "Error without message at ", file_, ":", line_), - true /* should_log */, - tensorflow::ERROR /* log_severity */, - should_log_stack_trace_); + return MakeError( + file_, line_, code_, + absl::StrCat(str, "Error without message at ", file_, ":", line_), + true /* should_log */, tensorflow::ERROR /* log_severity */, + should_log_stack_trace_); } else { return MakeError(file_, line_, code_, str, should_log_, log_severity_, should_log_stack_trace_); diff --git a/tensorflow/compiler/xla/test_helpers.h b/tensorflow/compiler/xla/test_helpers.h index 8918350135..3ede5e6e38 100644 --- a/tensorflow/compiler/xla/test_helpers.h +++ b/tensorflow/compiler/xla/test_helpers.h @@ -19,9 +19,9 @@ limitations under the License. #include <list> #include <vector> +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/regexp.h" #include "tensorflow/core/platform/test.h" diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 6baf95d631..6b29d833da 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -43,6 +43,7 @@ cc_library( "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/strings", ], alwayslink = True, ) @@ -205,6 +206,7 @@ cc_library( "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:test", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -391,6 +393,7 @@ xla_test( "//tensorflow/core:regexp_internal", "//tensorflow/core:test", "@com_google_absl//absl/algorithm:container", + 
"@com_google_absl//absl/strings", ], ) @@ -557,6 +560,7 @@ xla_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/strings", ], ) @@ -671,6 +675,7 @@ xla_test( "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -689,7 +694,6 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", @@ -697,6 +701,7 @@ xla_test( "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/strings", ], ) @@ -746,7 +751,6 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", @@ -754,6 +758,7 @@ xla_test( "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/strings", ], ) @@ -829,7 +834,10 @@ xla_test( timeout = "long", srcs = ["convolution_test.cc"], shard_count = 25, - deps = CONVOLUTION_TEST_DEPS + ["@com_google_absl//absl/memory"], + deps = CONVOLUTION_TEST_DEPS + [ + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + ], ) xla_test( @@ -839,7 +847,10 @@ xla_test( backend_args = {"gpu": ["--xla_backend_extra_options=xla_gpu_experimental_conv_disable_layout_heuristic"]}, backends = ["gpu"], shard_count = 25, - deps = 
CONVOLUTION_TEST_DEPS + ["@com_google_absl//absl/memory"], + deps = CONVOLUTION_TEST_DEPS + [ + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + ], ) xla_test( @@ -924,6 +935,7 @@ xla_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/strings", ], ) @@ -1001,6 +1013,7 @@ xla_test( "//tensorflow/core:lib", "//tensorflow/core:test", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/strings", ], ) @@ -1128,6 +1141,7 @@ xla_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/strings", ], ) @@ -1157,6 +1171,7 @@ xla_test_library( "//tensorflow/core:lib", "//tensorflow/core:test", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -1231,12 +1246,12 @@ xla_test( "enable_for_xla_interpreter", ], deps = [ - ":client_library_test_base", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/strings", ], ) @@ -1247,12 +1262,12 @@ xla_test( "enable_for_xla_interpreter", ], deps = [ - ":client_library_test_base", "//tensorflow/compiler/xla/service:hlo_verifier", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/strings", ], ) @@ -1425,6 +1440,7 @@ xla_test( "//tensorflow/core:lib", "//tensorflow/core:test", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -1494,6 +1510,7 @@ xla_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/strings", ], ) @@ -1648,6 +1665,7 @@ xla_test( 
"//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/strings", ], ) @@ -1660,7 +1678,6 @@ xla_test( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:global_data", @@ -1671,6 +1688,7 @@ xla_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/strings", ], ) @@ -1851,13 +1869,9 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_runner", - "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -1866,6 +1880,7 @@ xla_test( "//tensorflow/core:lib", "//tensorflow/core:test", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -2026,6 +2041,7 @@ tf_cc_test( "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc index 84c5b6e549..577fd1ab3b 100644 --- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc +++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc @@ -296,6 +296,22 @@ XLA_TEST_F(ArrayElementwiseOpTest, 
SubTwoConstantS64s) { ComputeAndCompareR1<int64>(&b, expected, {lhs_data.get(), rhs_data.get()}); } +XLA_TEST_F(ArrayElementwiseOpTest, CmpTwoConstantU64s) { + XlaBuilder b(TestName()); + + std::vector<uint64> lhs{static_cast<uint64>(0x8000000000000000ULL)}; + std::unique_ptr<Literal> lhs_literal = LiteralUtil::CreateR1<uint64>({lhs}); + auto lhs_param = Parameter(&b, 0, lhs_literal->shape(), "lhs_param"); + + std::vector<uint64> rhs{static_cast<uint64>(0x7FFFFFFFFFFFFFFFULL)}; + std::unique_ptr<Literal> rhs_literal = LiteralUtil::CreateR1<uint64>({rhs}); + auto rhs_param = Parameter(&b, 1, rhs_literal->shape(), "rhs_param"); + + Lt(lhs_param, rhs_param); + + ComputeAndCompare(&b, {std::move(*lhs_literal), std::move(*rhs_literal)}); +} + TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) { const int count = GetParam(); XlaBuilder builder(TestName()); diff --git a/tensorflow/compiler/xla/tests/batch_normalization_test.cc b/tensorflow/compiler/xla/tests/batch_normalization_test.cc index 24b17b7100..ac90a3adb6 100644 --- a/tensorflow/compiler/xla/tests/batch_normalization_test.cc +++ b/tensorflow/compiler/xla/tests/batch_normalization_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include <memory> #include <vector> +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" @@ -41,7 +42,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/math/math_util.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -382,7 +382,7 @@ struct BatchNormTestParam { friend ::std::ostream& operator<<(::std::ostream& os, const BatchNormTestParam& p) { - os << "bounds={" << tensorflow::str_util::Join(p.bounds, ", ") << "}, "; + os << "bounds={" << absl::StrJoin(p.bounds, ", ") << "}, "; os << "feature_index=" << p.feature_index << ", "; os << "random_value_mean=" << p.random_value_mean << ", "; os << "random_value_var=" << p.random_value_var; diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc index 2cab3264a7..9cd974fd9b 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.cc +++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc @@ -18,6 +18,7 @@ limitations under the License. #include <string> #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" @@ -28,7 +29,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -196,8 +196,8 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts( AsInt64Slice(expected.shape().dimensions()), minor_to_major); TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments, &layout)); - verify_output(*actual, tensorflow::strings::StrCat( - "Test with output layout: ", + verify_output(*actual, + absl::StrCat("Test with output layout: ", ShapeUtil::HumanStringWithLayout(layout))); } while (std::next_permutation(minor_to_major.begin(), minor_to_major.end())); return Status::OK(); @@ -258,7 +258,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( output_with_layout)); string error_message = "Test with input layouts: "; for (const auto& str : layout_strings) { - tensorflow::strings::StrAppend(&error_message, str, " "); + absl::StrAppend(&error_message, str, " "); } verify_output(*actual, error_message); return Status::OK(); @@ -391,7 +391,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( } void ClientLibraryTestBase::ComputeAndCompareR1U8( - XlaBuilder* builder, tensorflow::StringPiece expected, + XlaBuilder* builder, absl::string_view expected, tensorflow::gtl::ArraySlice<GlobalData*> arguments) { auto actual_status = ExecuteAndTransfer(builder, arguments); EXPECT_IS_OK(actual_status.status()); diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index 24d0325929..ac96d3e325 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -22,6 +22,7 @@ limitations under the License. 
#include <vector> #include "absl/memory/memory.h" +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" @@ -36,7 +37,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/bitmap.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/test.h" @@ -202,7 +202,7 @@ class ClientLibraryTestBase : public ::testing::Test { // Compare the result of the computation to a strings. In XLA strings are // represented using rank-1 U8 shapes. void ComputeAndCompareR1U8( - XlaBuilder* builder, tensorflow::StringPiece expected, + XlaBuilder* builder, absl::string_view expected, tensorflow::gtl::ArraySlice<GlobalData*> arguments); // Convenience method for running a built computation, transferring the diff --git a/tensorflow/compiler/xla/tests/compute_constant_test.cc b/tensorflow/compiler/xla/tests/compute_constant_test.cc index 5a06d061f0..8226b6de3f 100644 --- a/tensorflow/compiler/xla/tests/compute_constant_test.cc +++ b/tensorflow/compiler/xla/tests/compute_constant_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include <utility> #include <vector> +#include "absl/strings/match.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/xla_builder.h" @@ -32,7 +33,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -145,8 +145,8 @@ TEST_F(ComputeConstantTest, DirectParamMissing) { EXPECT_FALSE(IsConstant(computation, &b)); auto value = ComputeConstantScalar<float>(client, computation, &b); - EXPECT_TRUE(tensorflow::str_util::StrContains(value.status().ToString(), - "depends on a parameter")) + EXPECT_TRUE( + absl::StrContains(value.status().ToString(), "depends on a parameter")) << value.status(); } } @@ -161,8 +161,8 @@ TEST_F(ComputeConstantTest, IndirectParamMissing) { EXPECT_FALSE(IsConstant(computation, &b)); auto value = ComputeConstantScalar<float>(client, computation, &b); - EXPECT_TRUE(tensorflow::str_util::StrContains(value.status().ToString(), - "depends on a parameter")) + EXPECT_TRUE( + absl::StrContains(value.status().ToString(), "depends on a parameter")) << value.status(); } } diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc index 40658c3b77..d2c6478b02 100644 --- a/tensorflow/compiler/xla/tests/convolution_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include <memory> #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/global_data.h" @@ -35,8 +36,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc index 0e9e92ed99..5873516442 100644 --- a/tensorflow/compiler/xla/tests/dot_operation_test.cc +++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include <memory> #include <vector> +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/client/local_client.h" @@ -261,16 +262,14 @@ string PrintDotTestParam( const ::testing::TestParamInfo<DotTestParam>& test_param) { const DotTestParam& param = test_param.param; if (param.has_addend) { - return tensorflow::strings::StrCat(param.m, "x", param.k, "x", param.n, - "_MajorToMinor", - param.dot_lhs_row_major ? "T" : "F", - param.dot_rhs_row_major ? "T" : "F", - param.addend_row_major ? "T" : "F"); + return absl::StrCat(param.m, "x", param.k, "x", param.n, "_MajorToMinor", + param.dot_lhs_row_major ? "T" : "F", + param.dot_rhs_row_major ? "T" : "F", + param.addend_row_major ? "T" : "F"); } else { - return tensorflow::strings::StrCat(param.m, "x", param.k, "x", param.n, - "_MajorToMinor", - param.dot_lhs_row_major ? "T" : "F", - param.dot_rhs_row_major ? "T" : "F"); + return absl::StrCat(param.m, "x", param.k, "x", param.n, "_MajorToMinor", + param.dot_lhs_row_major ? "T" : "F", + param.dot_rhs_row_major ? 
"T" : "F"); } } diff --git a/tensorflow/compiler/xla/tests/floor_ceil_test.cc b/tensorflow/compiler/xla/tests/floor_ceil_test.cc index 39cc6c5927..4a835a8e21 100644 --- a/tensorflow/compiler/xla/tests/floor_ceil_test.cc +++ b/tensorflow/compiler/xla/tests/floor_ceil_test.cc @@ -16,13 +16,13 @@ limitations under the License. #include <limits> #include <string> +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" @@ -39,8 +39,7 @@ class FloorCeilTest : public ClientLibraryTestBase { // Runs a computation and comparison on expected vs f(input) void TestR1F32(tensorflow::gtl::ArraySlice<float> input, tensorflow::gtl::ArraySlice<float> expected, Function f) { - LOG(INFO) << "input: {" << tensorflow::str_util::Join(expected, ", ") - << "}"; + LOG(INFO) << "input: {" << absl::StrJoin(expected, ", ") << "}"; XlaBuilder builder(TestName()); auto c = ConstantR1<float>(&builder, input); if (f == kCeil) { diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index 5635c3fe86..93ea144438 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -43,7 +43,7 @@ namespace xla { namespace { using absl::optional; -using tensorflow::StringPiece; +using absl::string_view; using tensorflow::gtl::ArraySlice; constexpr char kInterpreter[] = "interpreter"; @@ -86,16 +86,20 @@ ProgramShape GetProgramShapeWithLayout(const HloModule& module) { } // namespace -HloTestBase::HloTestBase(bool 
allow_mixed_precision_in_hlo_verifier) +HloTestBase::HloTestBase(bool verifier_layout_sensitive, + bool allow_mixed_precision_in_hlo_verifier) : HloTestBase(GetTestPlatform(), GetReferencePlatform(), + verifier_layout_sensitive, allow_mixed_precision_in_hlo_verifier) {} HloTestBase::HloTestBase(se::Platform* test_platform, se::Platform* reference_platform, + bool verifier_layout_sensitive, bool allow_mixed_precision_in_hlo_verifier) : test_runner_(test_platform), reference_runner_(reference_platform) { - hlo_verifier_ = - absl::make_unique<HloVerifier>(allow_mixed_precision_in_hlo_verifier); + hlo_verifier_ = absl::make_unique<HloVerifier>( + /*layout_sensitive=*/verifier_layout_sensitive, + /*allow_mixed_precision=*/allow_mixed_precision_in_hlo_verifier); } std::unique_ptr<HloModule> HloTestBase::CreateNewModule(const string& name) { @@ -239,7 +243,7 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( } ::testing::AssertionResult HloTestBase::RunAndCompare( - const StringPiece hlo_string, const absl::optional<ErrorSpec>& error, + string_view hlo_string, const absl::optional<ErrorSpec>& error, const std::function<void(HloModule*)>& reference_preprocessor) { auto module_or_status = HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()); @@ -252,7 +256,7 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( reference_preprocessor); } -::testing::AssertionResult HloTestBase::Run(const StringPiece hlo_string) { +::testing::AssertionResult HloTestBase::Run(string_view hlo_string) { auto module_or_status = HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()); if (!module_or_status.ok()) { @@ -289,7 +293,7 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( } ::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses( - const StringPiece hlo_string, const absl::optional<ErrorSpec>& error, + string_view hlo_string, const absl::optional<ErrorSpec>& error, const 
std::function<void(HloModule*)>& reference_preprocessor) { auto module_or_status = HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()); @@ -316,7 +320,7 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( } HloComputation* HloTestBase::FindComputation(HloModule* module, - tensorflow::StringPiece name) { + absl::string_view name) { auto computations = module->computations(); auto it = absl::c_find_if( computations, [&](HloComputation* c) { return c->name() == name; }); @@ -327,7 +331,7 @@ HloComputation* HloTestBase::FindComputation(HloModule* module, } HloInstruction* HloTestBase::FindInstruction(HloModule* module, - tensorflow::StringPiece name) { + absl::string_view name) { for (const HloComputation* c : module->computations()) { auto instructions = c->instructions(); auto it = absl::c_find_if( diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index d88abf561a..06bcc39741 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -85,12 +85,14 @@ class HloTestBase : public ::testing::Test { // automatically finds another supported backend as the test backend. If the // interpreter is the only supported backend, it will be both the test backend // and the reference backend. - HloTestBase(bool allow_mixed_precision_in_hlo_verifier = true); + HloTestBase(bool verifier_layout_sensitive = false, + bool allow_mixed_precision_in_hlo_verifier = true); // If your test doesn't use interpreter as the reference backend, you can use // this constructor. Note that your test target is responsible for linking in // both needed backends. HloTestBase(se::Platform* test_platform, se::Platform* reference_platform, + bool verifier_layout_sensitive = false, bool allow_mixed_precision_in_hlo_verifier = true); ~HloTestBase() override {} @@ -169,18 +171,18 @@ class HloTestBase : public ::testing::Test { // input. 
Module can be passed in directly, or parsed from an hlo_string, // or loaded from a file. ::testing::AssertionResult RunAndCompare( - const tensorflow::StringPiece hlo_string, + const absl::string_view hlo_string, const absl::optional<ErrorSpec>& error, const std::function<void(HloModule*)>& reference_preprocessor = nullptr) TF_MUST_USE_RESULT; - ::testing::AssertionResult Run(const tensorflow::StringPiece hlo_string) + ::testing::AssertionResult Run(const absl::string_view hlo_string) TF_MUST_USE_RESULT; ::testing::AssertionResult RunAndCompareFromFile( const string& filename, const absl::optional<ErrorSpec>& error, const std::function<void(HloModule*)>& reference_preprocessor = nullptr) TF_MUST_USE_RESULT; ::testing::AssertionResult RunAndCompareNoHloPasses( - const tensorflow::StringPiece hlo_string, + const absl::string_view hlo_string, const absl::optional<ErrorSpec>& error, const std::function<void(HloModule*)>& reference_preprocessor = nullptr) TF_MUST_USE_RESULT; @@ -228,10 +230,8 @@ class HloTestBase : public ::testing::Test { // // This is useful for tests which create HLOs from a string and then want to // inspect a particular computation or instruction. - HloComputation* FindComputation(HloModule* module, - tensorflow::StringPiece name); - HloInstruction* FindInstruction(HloModule* module, - tensorflow::StringPiece name); + HloComputation* FindComputation(HloModule* module, absl::string_view name); + HloInstruction* FindInstruction(HloModule* module, absl::string_view name); // Return an HLO verifier constructed for the test backend. HloVerifier& verifier() const { return *hlo_verifier_; } diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc index a509ee3207..8f86c528d0 100644 --- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc @@ -25,8 +25,11 @@ limitations under the License. 
namespace xla { -HloVerifiedTestBase::HloVerifiedTestBase() - : shape_verifier_(absl::make_unique<ShapeVerifier>()) {} +HloVerifiedTestBase::HloVerifiedTestBase(bool layout_sensitive, + bool allow_mixed_precision) + : HloTestBase( + /*verifier_layout_sensitive=*/layout_sensitive, + /*allow_mixed_precision_in_hlo_verifier=*/allow_mixed_precision) {} HloVerifiedTestBase::~HloVerifiedTestBase() { // We can't call the ASSERT or EXPECT test macros in destructors, so we @@ -51,8 +54,7 @@ void HloVerifiedTestBase::TearDown() { } void HloVerifiedTestBase::VerifyModule(HloModule* module) { - HloVerifier verifier(/*allow_mixed_precision=*/true); - xla::StatusOr<bool> mutated = verifier.Run(module); + xla::StatusOr<bool> mutated = verifier().Run(module); if (!mutated.ok()) { ADD_FAILURE() << "HloVerifier failed: " << mutated.status(); } else { @@ -73,7 +75,7 @@ HloModule* HloVerifiedTestBase::CreateNewModule(const string& name) { return modules_.back().get(); } -void HloVerifiedTestBase::ParseAndVerifyModule(tensorflow::StringPiece hlo_text, +void HloVerifiedTestBase::ParseAndVerifyModule(absl::string_view hlo_text, const HloModuleConfig& config) { CHECK(!module_) << "Called ParseModule when test already has a module."; TF_ASSERT_OK_AND_ASSIGN(module_, ParseHloString(hlo_text, config)); diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h index 5b28c01c36..cc6967feed 100644 --- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h @@ -29,7 +29,8 @@ namespace xla { // performs verification on that module on tear-down. class HloVerifiedTestBase : public HloTestBase { protected: - HloVerifiedTestBase(); + explicit HloVerifiedTestBase(bool layout_sensitive, + bool allow_mixed_precision); ~HloVerifiedTestBase() override; // Constructs a default shape verifier. 
@@ -44,32 +45,28 @@ class HloVerifiedTestBase : public HloTestBase { // Returns the default HloModule, lazily creating it if necessary via // HloTestBase::CreateNewModule(). HloModule& module(); - void ParseAndVerifyModule(tensorflow::StringPiece hlo_text, + void ParseAndVerifyModule(absl::string_view hlo_text, const HloModuleConfig& config = HloModuleConfig()); - // Sets the shape-size function used during hlo verification. If this isn't - // called, a default ShapeVerifier is used instead. - void SetShapeVerifier(std::unique_ptr<ShapeVerifier> shape_verifier) { - shape_verifier_ = std::move(shape_verifier); - } - // Creates a new module for a test, and stores it in modules_ so it can be // verified. Intentionally hides HloTestBase::CreateNewModule, to prevent // creation of unverified modules. HloModule* CreateNewModule(const string& name = TestName()); + private: + void VerifyModule(HloModule* module); + // It is confusing to store modules created by module() and CreateNewModule() // in different fields, but it allows us to migrate tests to // HloVerifiedTestBase more easily, so it's a win because we can verify more // modules. See b/80488902. - private: + // // Lazily populated. Access via module(). std::unique_ptr<HloModule> module_; // Populated by calls to CreateNewModule. std::vector<std::unique_ptr<HloModule>> modules_; - std::unique_ptr<ShapeVerifier> shape_verifier_; + bool tear_down_called_ = false; - static void VerifyModule(HloModule* module); }; } // namespace xla diff --git a/tensorflow/compiler/xla/tests/literal_test_util_test.cc b/tensorflow/compiler/xla/tests/literal_test_util_test.cc index f297b2b847..4151bfae03 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util_test.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util_test.cc @@ -20,9 +20,9 @@ limitations under the License. 
#include <vector> +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" @@ -80,7 +80,7 @@ TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) { std::vector<string> results; TF_CHECK_OK(env->GetMatchingPaths(pattern, &results)); - LOG(INFO) << "results: [" << tensorflow::str_util::Join(results, ", ") << "]"; + LOG(INFO) << "results: [" << absl::StrJoin(results, ", ") << "]"; EXPECT_EQ(3, results.size()); for (const string& result : results) { LiteralProto literal_proto; @@ -105,8 +105,10 @@ TEST(LiteralTestUtilTest, NotEqualHasValuesInMessage) { auto actual = LiteralUtil::CreateR1<int32>({4, 5, 6}); ::testing::AssertionResult result = LiteralTestUtil::Equal(*expected, *actual); - EXPECT_THAT(result.message(), ::testing::HasSubstr("expected: {1, 2, 3}")); - EXPECT_THAT(result.message(), ::testing::HasSubstr("actual: {4, 5, 6}")); + EXPECT_THAT(result.message(), + ::testing::HasSubstr("Expected literal:\n{1, 2, 3}")); + EXPECT_THAT(result.message(), + ::testing::HasSubstr("Actual literal:\n{4, 5, 6}")); } TEST(LiteralTestUtilTest, NearComparatorR1) { diff --git a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc index b6035a21a6..7956a034f8 100644 --- a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc +++ b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include <string> #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" @@ -158,7 +159,7 @@ class TestLinspaceMaxParametric string PrintTestLinspaceMaxParam( const ::testing::TestParamInfo<TestLinspaceMaxParam>& test_param) { const TestLinspaceMaxParam& param = test_param.param; - return tensorflow::strings::StrCat(param.rows, "r", param.cols, "c"); + return absl::StrCat(param.rows, "r", param.cols, "c"); } #ifndef XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16 diff --git a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc index cadf1c5523..16b77e965d 100644 --- a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc +++ b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include <utility> #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/primitive_util.h" @@ -52,12 +53,22 @@ class MultiOutputFusionTest : public HloTestBase { protected: MultiOutputFusionTest() { error_spec_ = ErrorSpec{0.0001, 1e-2}; } + // Layout assignment assumes that there are no fusions in the input graph. + // Since the purpose of this test is to send pre-fused graphs to XLA, we have + // to do layout assignment ourselves. 
+ DebugOptions GetDebugOptionsForTest() override { + auto opts = HloTestBase::GetDebugOptionsForTest(); + opts.add_xla_disable_hlo_passes("layout-assignment"); + return opts; + } + void RunTest2D(bool manual_fusion, int64 size) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); - const Shape elem_shape0 = ShapeUtil::MakeShape(F32, {}); - const Shape elem_shape2 = ShapeUtil::MakeShape(F32, {size, size}); + const Shape elem_shape0 = ShapeUtil::MakeShapeWithLayout(F32, {}, {}); + const Shape elem_shape2 = + ShapeUtil::MakeShapeWithLayout(F32, {size, size}, {1, 0}); auto const0 = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(8.0f))); @@ -100,10 +111,10 @@ class MultiOutputFusionTest : public HloTestBase { nullptr); } - Literal arg1(ShapeUtil::MakeShape(F32, {size, size})); + Literal arg1(ShapeUtil::MakeShapeWithDescendingLayout(F32, {size, size})); arg1.PopulateWithValue<float>(2.5f); - Literal expect(ShapeUtil::MakeShape(F32, {size, size})); + Literal expect(ShapeUtil::MakeShapeWithDescendingLayout(F32, {size, size})); expect.PopulateWithValue<float>(size * 1.5f * 3.5f); auto actual = ExecuteAndTransfer(std::move(hlo_module), @@ -115,8 +126,10 @@ class MultiOutputFusionTest : public HloTestBase { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); - const Shape elem_shape_F32 = ShapeUtil::MakeShape(F32, {size}); - const Shape elem_shape_U8 = ShapeUtil::MakeShape(F64, {size}); + const Shape elem_shape_F32 = + ShapeUtil::MakeShapeWithDescendingLayout(F32, {size}); + const Shape elem_shape_U8 = + ShapeUtil::MakeShapeWithDescendingLayout(F64, {size}); auto param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, elem_shape_F32, "0")); auto param1 = builder.AddInstruction( @@ -136,12 +149,13 @@ class MultiOutputFusionTest : public HloTestBase { HloInstruction* reshape = builder.AddInstruction(HloInstruction::CreateReshape( - 
ShapeUtil::MakeShape(F32, {size, 1}), add)); + ShapeUtil::MakeShapeWithDescendingLayout(F32, {size, 1}), add)); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(0); dot_dnums.add_rhs_contracting_dimensions(0); HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot( - ShapeUtil::MakeShape(F32, {1}), sub, reshape, dot_dnums)); + ShapeUtil::MakeShapeWithDescendingLayout(F32, {1}), sub, reshape, + dot_dnums)); auto computation = hlo_module->AddEntryComputation(builder.Build(dot)); if (manual_fusion) { @@ -161,9 +175,9 @@ class MultiOutputFusionTest : public HloTestBase { nullptr); } - Literal input0(ShapeUtil::MakeShape(F32, {size})); + Literal input0(ShapeUtil::MakeShapeWithDescendingLayout(F32, {size})); input0.PopulateWithValue(2.5f); - Literal input1(ShapeUtil::MakeShape(F64, {size})); + Literal input1(ShapeUtil::MakeShapeWithDescendingLayout(F64, {size})); input1.PopulateWithValue(1.); Literal expect = @@ -291,7 +305,7 @@ const char* const kScalarOps = R"( XLA_TEST_F(MultiOutputFusionTest, DISABLED_ON_CPU(MultiOutputReduceFusionMinor)) { - const string testcase = tensorflow::strings::StrCat(kScalarOps, R"( + const string testcase = absl::StrCat(kScalarOps, R"( fused_reduce { p0 = f32[2,2,2]{2,1,0} parameter(0) c0 = f32[] constant(0) @@ -323,7 +337,7 @@ XLA_TEST_F(MultiOutputFusionTest, XLA_TEST_F(MultiOutputFusionTest, DISABLED_ON_CPU(MultiOutputReduceFusionMajor)) { - const string testcase = tensorflow::strings::StrCat(kScalarOps, R"( + const string testcase = absl::StrCat(kScalarOps, R"( fused_reduce { p0 = f32[2,2,2]{2,1,0} parameter(0) c0 = f32[] constant(0) @@ -355,7 +369,7 @@ XLA_TEST_F(MultiOutputFusionTest, XLA_TEST_F(MultiOutputFusionTest, DISABLED_ON_CPU(MultiOutputReduceFusionScalar)) { - const string testcase = tensorflow::strings::StrCat(kScalarOps, R"( + const string testcase = absl::StrCat(kScalarOps, R"( fused_reduce { p0 = f32[2,2,2]{2,1,0} parameter(0) c0 = f32[] constant(0) @@ -388,7 +402,7 @@ 
XLA_TEST_F(MultiOutputFusionTest, XLA_TEST_F(MultiOutputFusionTest, DISABLED_ON_CPU(MultiOutputReduceFusionMinorWithExtraOutput)) { - const string testcase = tensorflow::strings::StrCat(kScalarOps, R"( + const string testcase = absl::StrCat(kScalarOps, R"( fused_reduce { p0 = f32[2,2,2]{2,1,0} parameter(0) c0 = f32[] constant(0) @@ -422,7 +436,7 @@ XLA_TEST_F(MultiOutputFusionTest, XLA_TEST_F(MultiOutputFusionTest, DISABLED_ON_CPU(MultiOutputReduceFusionMajorWithExtraOutput)) { - const string testcase = tensorflow::strings::StrCat(kScalarOps, R"( + const string testcase = absl::StrCat(kScalarOps, R"( fused_reduce { p0 = f32[2,2,2]{2,1,0} parameter(0) c0 = f32[] constant(0) @@ -457,7 +471,7 @@ XLA_TEST_F(MultiOutputFusionTest, XLA_TEST_F(MultiOutputFusionTest, DISABLED_ON_CPU(MultiOutputReduceFusionScalarWithExtraOutput)) { - const string testcase = tensorflow::strings::StrCat(kScalarOps, R"( + const string testcase = absl::StrCat(kScalarOps, R"( fused_reduce { p0 = f32[2,2,2]{2,1,0} parameter(0) c0 = f32[] constant(0) @@ -494,7 +508,7 @@ XLA_TEST_F(MultiOutputFusionTest, XLA_TEST_F(MultiOutputFusionTest, DISABLED_ON_CPU(MultiOutputReduceFusionNonConstInit)) { - const string testcase = tensorflow::strings::StrCat(kScalarOps, R"( + const string testcase = absl::StrCat(kScalarOps, R"( fused_reduce { p0 = f32[2,2,2]{2,1,0} parameter(0) init1 = f32[] parameter(1) @@ -529,7 +543,7 @@ XLA_TEST_F(MultiOutputFusionTest, XLA_TEST_F(MultiOutputFusionTest, DISABLED_ON_CPU(MultiOutputReduceFusionDifferentElementTypes)) { - const string testcase = tensorflow::strings::StrCat(kScalarOps, R"( + const string testcase = absl::StrCat(kScalarOps, R"( fused_reduce (p0: f16[2,2,2]) -> (f32[2,2], f32[2,2], f16[2,2,2]) { p0 = f16[2,2,2]{2,1,0} parameter(0) convert = f32[2,2,2]{2,1,0} convert(p0) diff --git a/tensorflow/compiler/xla/tests/reduce_hlo_test.cc b/tensorflow/compiler/xla/tests/reduce_hlo_test.cc index a080dd1732..9af9ea4a22 100644 --- 
a/tensorflow/compiler/xla/tests/reduce_hlo_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_hlo_test.cc @@ -15,11 +15,11 @@ limitations under the License. #include <array> +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -29,16 +29,13 @@ limitations under the License. namespace xla { namespace { -namespace str_util = tensorflow::str_util; -namespace strings = tensorflow::strings; - struct ReduceLayout { std::array<int64, 4> input_minor_to_major; std::array<int64, 3> output_minor_to_major; string ToString() const { - return strings::StrCat(str_util::Join(input_minor_to_major, "x"), "_", - str_util::Join(output_minor_to_major, "x")); + return absl::StrCat(absl::StrJoin(input_minor_to_major, "x"), "_", + absl::StrJoin(output_minor_to_major, "x")); } }; diff --git a/tensorflow/compiler/xla/tests/reduce_precision_test.cc b/tensorflow/compiler/xla/tests/reduce_precision_test.cc index 531648fe3e..0916a07f4f 100644 --- a/tensorflow/compiler/xla/tests/reduce_precision_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_precision_test.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include <numeric> #include <vector> +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" @@ -57,8 +58,8 @@ static const int mantissa_sizes[] = {23, 10, 23, 10}; string TestDataToString(const ::testing::TestParamInfo<int> data) { int i = data.param; - return tensorflow::strings::StrCat(exponent_sizes[i], "_exponent_bits_", - mantissa_sizes[i], "_mantissa_bits"); + return absl::StrCat(exponent_sizes[i], "_exponent_bits_", mantissa_sizes[i], + "_mantissa_bits"); } // The FPVAL macro allows us to write out the binary representation of the diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc index 2065271a7f..b93d838349 100644 --- a/tensorflow/compiler/xla/tests/reduce_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_test.cc @@ -32,6 +32,7 @@ limitations under the License. #include <utility> #include <vector> +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/global_data.h" @@ -559,9 +560,9 @@ void PrintTo(const BoundsLayout& spec, std::ostream* os) { *os << tensorflow::strings::Printf( "R%luToR%lu%s_%s_Reduce%s", spec.bounds.size(), spec.bounds.size() - spec.reduce_dims.size(), - tensorflow::str_util::Join(spec.bounds, "x").c_str(), - tensorflow::str_util::Join(spec.layout, "").c_str(), - tensorflow::str_util::Join(spec.reduce_dims, "").c_str()); + absl::StrJoin(spec.bounds, "x").c_str(), + absl::StrJoin(spec.layout, "").c_str(), + absl::StrJoin(spec.reduce_dims, "").c_str()); } // Add-reduces a broadcasted scalar matrix among dimension 1 and 0. 
diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc index ebf7fa30be..60167619a4 100644 --- a/tensorflow/compiler/xla/tests/reduce_window_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc @@ -19,6 +19,8 @@ limitations under the License. #include <memory> #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" @@ -579,21 +581,20 @@ string R4ReduceWindowTestDataToString( const ::testing::TestParamInfo< ::testing::tuple<R4ReduceWindowTestData, bool>>& data) { const auto& param = ::testing::get<0>(data.param); - string str = tensorflow::strings::StrCat( - "base_bounds_", tensorflow::str_util::Join(param.base_bounds, "x"), // - "__window_bounds_", - tensorflow::str_util::Join(param.window_bounds, "x"), // - "__strides_", tensorflow::str_util::Join(param.strides, "x"), // - "__pad_low_", tensorflow::str_util::Join(param.pad_low, "x"), // - "__pad_high_", tensorflow::str_util::Join(param.pad_high, "x"), // - "__layout_", tensorflow::str_util::Join(param.layout, "_"), // + string str = absl::StrCat( + "base_bounds_", absl::StrJoin(param.base_bounds, "x"), // + "__window_bounds_", absl::StrJoin(param.window_bounds, "x"), // + "__strides_", absl::StrJoin(param.strides, "x"), // + "__pad_low_", absl::StrJoin(param.pad_low, "x"), // + "__pad_high_", absl::StrJoin(param.pad_high, "x"), // + "__layout_", absl::StrJoin(param.layout, "_"), // (param.reducer == kAdd) ? "_add" : "_max"); CHECK(param.reducer == kAdd || param.reducer == kMax); // Test names are not allowed to contain the '-' character. 
std::replace(str.begin(), str.end(), '-', 'n'); if (::testing::get<1>(data.param)) { - str = tensorflow::strings::StrCat(str, "_bfloat16"); + str = absl::StrCat(str, "_bfloat16"); } return str; } @@ -935,15 +936,15 @@ string R3ReduceWindowTestDataToString( const ::testing::TestParamInfo< ::testing::tuple<R3ReduceWindowTestData, bool>>& data) { const auto& param = ::testing::get<0>(data.param); - string str = tensorflow::strings::StrCat( - "base_bounds_", tensorflow::str_util::Join(param.base_bounds, "x"), - "__window_bounds_", tensorflow::str_util::Join(param.window_bounds, "x"), - "__strides_", tensorflow::str_util::Join(param.strides, "x"), - "__padding_", param.padding == Padding::kSame ? "same" : "valid", - "__layout_", param.layout[0], "_", param.layout[1], "_", param.layout[2], - "__reducer_", param.reducer == kAdd ? "add" : "max"); + string str = absl::StrCat( + "base_bounds_", absl::StrJoin(param.base_bounds, "x"), "__window_bounds_", + absl::StrJoin(param.window_bounds, "x"), "__strides_", + absl::StrJoin(param.strides, "x"), "__padding_", + param.padding == Padding::kSame ? "same" : "valid", "__layout_", + param.layout[0], "_", param.layout[1], "_", param.layout[2], "__reducer_", + param.reducer == kAdd ? 
"add" : "max"); if (::testing::get<1>(data.param)) { - str = tensorflow::strings::StrCat(str, "_bfloat16"); + str = absl::StrCat(str, "_bfloat16"); } return str; } @@ -1069,17 +1070,16 @@ string R2ReduceWindowTestDataToString( const ::testing::TestParamInfo< ::testing::tuple<R2ReduceWindowTestData, bool>>& data) { const auto& param = ::testing::get<0>(data.param); - string str = tensorflow::strings::StrCat( - "base_bounds_", tensorflow::str_util::Join(param.base_bounds, "x"), // - "__window_bounds_", - tensorflow::str_util::Join(param.window_bounds, "x"), // - "__strides_", tensorflow::str_util::Join(param.strides, "x"), // - "__pad_low_", tensorflow::str_util::Join(param.pad_low, "x"), - "__pad_high_", tensorflow::str_util::Join(param.pad_high, "x"), - "__layout_", param.layout[0], "_", param.layout[1], // + string str = absl::StrCat( + "base_bounds_", absl::StrJoin(param.base_bounds, "x"), // + "__window_bounds_", absl::StrJoin(param.window_bounds, "x"), // + "__strides_", absl::StrJoin(param.strides, "x"), // + "__pad_low_", absl::StrJoin(param.pad_low, "x"), "__pad_high_", + absl::StrJoin(param.pad_high, "x"), "__layout_", param.layout[0], "_", + param.layout[1], // "__reducer_", param.reducer == kAdd ? 
"add" : "max"); if (::testing::get<1>(data.param)) { - str = tensorflow::strings::StrCat(str, "_bfloat16"); + str = absl::StrCat(str, "_bfloat16"); } return str; } @@ -1274,15 +1274,15 @@ string R1ReduceWindowTestDataToString( const ::testing::TestParamInfo< ::testing::tuple<R1ReduceWindowTestData, bool>>& data) { const auto& param = ::testing::get<0>(data.param); - string str = tensorflow::strings::StrCat( - "base_bounds_", tensorflow::str_util::Join(param.base_bounds, "x"), - "__window_bounds_", tensorflow::str_util::Join(param.window_bounds, "x"), - "__strides_", tensorflow::str_util::Join(param.strides, "x"), - "__pad_low_", tensorflow::str_util::Join(param.pad_low, "x"), - "__pad_high_", tensorflow::str_util::Join(param.pad_high, "x"), - "__reducer_", param.reducer == kAdd ? "add" : "max"); + string str = + absl::StrCat("base_bounds_", absl::StrJoin(param.base_bounds, "x"), + "__window_bounds_", absl::StrJoin(param.window_bounds, "x"), + "__strides_", absl::StrJoin(param.strides, "x"), + "__pad_low_", absl::StrJoin(param.pad_low, "x"), + "__pad_high_", absl::StrJoin(param.pad_high, "x"), + "__reducer_", param.reducer == kAdd ? "add" : "max"); if (::testing::get<1>(data.param)) { - str = tensorflow::strings::StrCat(str, "_bfloat16"); + str = absl::StrCat(str, "_bfloat16"); } return str; } diff --git a/tensorflow/compiler/xla/tests/reverse_test.cc b/tensorflow/compiler/xla/tests/reverse_test.cc index 41e49b4003..60084f143d 100644 --- a/tensorflow/compiler/xla/tests/reverse_test.cc +++ b/tensorflow/compiler/xla/tests/reverse_test.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include <memory> +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/local_client.h" @@ -43,10 +44,8 @@ struct ReverseSpec { string ToTestCaseName() const { return tensorflow::strings::Printf( - "reverse_%s_in_dims_%s_%s", - tensorflow::str_util::Join(input_dims, "x").c_str(), - tensorflow::str_util::Join(reversal, "x").c_str(), - use_bfloat16 ? "bf16" : "f32"); + "reverse_%s_in_dims_%s_%s", absl::StrJoin(input_dims, "x").c_str(), + absl::StrJoin(reversal, "x").c_str(), use_bfloat16 ? "bf16" : "f32"); } }; diff --git a/tensorflow/compiler/xla/tests/scalar_computations_test.cc b/tensorflow/compiler/xla/tests/scalar_computations_test.cc index e42c71eb28..cf2d453f43 100644 --- a/tensorflow/compiler/xla/tests/scalar_computations_test.cc +++ b/tensorflow/compiler/xla/tests/scalar_computations_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include <limits> #include <memory> +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" @@ -31,7 +32,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/compiler/xla/tests/slice_test.cc b/tensorflow/compiler/xla/tests/slice_test.cc index d865c414fd..c57bbbd1e4 100644 --- a/tensorflow/compiler/xla/tests/slice_test.cc +++ b/tensorflow/compiler/xla/tests/slice_test.cc @@ -19,6 +19,8 @@ limitations under the License. 
#include <vector> #include "absl/container/inlined_vector.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" @@ -34,8 +36,6 @@ limitations under the License. namespace xla { namespace { -using ::tensorflow::str_util::Join; - class SliceTest : public ClientLibraryTestBase {}; TEST_F(SliceTest, Slice3x3x3_To_3x3x1_F32) { @@ -449,13 +449,11 @@ struct R4Spec { string R4SpecToString(const ::testing::TestParamInfo<R4Spec>& data) { const R4Spec& spec = data.param; - return tensorflow::strings::StrCat( // - "input_", Join(spec.input_dims, "x"), // - "__layout_", Join(spec.input_layout, ""), // - "__starts_", Join(spec.slice_starts, "x"), // - "__limits_", Join(spec.slice_limits, "x"), // - "__strides_", Join(spec.slice_strides, "x") // - ); + return absl::StrCat("input_", absl::StrJoin(spec.input_dims, "x"), + "__layout_", absl::StrJoin(spec.input_layout, ""), + "__starts_", absl::StrJoin(spec.slice_starts, "x"), + "__limits_", absl::StrJoin(spec.slice_limits, "x"), + "__strides_", absl::StrJoin(spec.slice_strides, "x")); } class SliceR4Test : public ClientLibraryTestBase, diff --git a/tensorflow/compiler/xla/tests/test_macros.cc b/tensorflow/compiler/xla/tests/test_macros.cc index be35ec6c6e..a9874a9186 100644 --- a/tensorflow/compiler/xla/tests/test_macros.cc +++ b/tensorflow/compiler/xla/tests/test_macros.cc @@ -20,7 +20,9 @@ limitations under the License. 
#include <string> #include <unordered_map> -#include "tensorflow/core/lib/strings/str_util.h" +#include "absl/strings/ascii.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_split.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/regexp.h" @@ -44,7 +46,7 @@ ManifestT ReadManifest() { string contents((std::istreambuf_iterator<char>(file_stream)), std::istreambuf_iterator<char>()); - std::vector<string> lines = tensorflow::str_util::Split(contents, '\n'); + std::vector<string> lines = absl::StrSplit(contents, '\n'); for (string& line : lines) { auto comment = line.find("//"); if (comment != string::npos) { @@ -53,8 +55,8 @@ ManifestT ReadManifest() { if (line.empty()) { continue; } - tensorflow::str_util::StripTrailingWhitespace(&line); - std::vector<string> pieces = tensorflow::str_util::Split(line, ' '); + absl::StripTrailingAsciiWhitespace(&line); + std::vector<string> pieces = absl::StrSplit(line, ' '); CHECK_GE(pieces.size(), 1); auto& platforms = manifest[pieces[0]]; for (int64 i = 1; i < pieces.size(); ++i) { @@ -73,8 +75,7 @@ string PrependDisabledIfIndicated(const string& test_case_name, // First try full match: test_case_name.test_name // If that fails, try to find just the test_case_name; this would disable all // tests in the test case. 
- auto it = manifest.find( - tensorflow::strings::StrCat(test_case_name, ".", test_name)); + auto it = manifest.find(absl::StrCat(test_case_name, ".", test_name)); if (it == manifest.end()) { it = manifest.find(test_case_name); if (it == manifest.end()) { diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc index 2f1d97b25d..21c58e075e 100644 --- a/tensorflow/compiler/xla/tests/test_utils.cc +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -408,8 +408,12 @@ StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments( return std::move(arguments); } -Status VerifyHloModule(HloModule* const module, bool allow_mixed_precision) { - return HloVerifier(allow_mixed_precision).Run(module).status(); +Status VerifyHloModule(HloModule* const module, bool layout_sensitive, + bool allow_mixed_precision) { + return HloVerifier(/*layout_sensitive=*/layout_sensitive, + /*allow_mixed_precision=*/allow_mixed_precision) + .Run(module) + .status(); } } // namespace xla diff --git a/tensorflow/compiler/xla/tests/test_utils.h b/tensorflow/compiler/xla/tests/test_utils.h index 1aca1d8ef7..277d53d423 100644 --- a/tensorflow/compiler/xla/tests/test_utils.h +++ b/tensorflow/compiler/xla/tests/test_utils.h @@ -95,8 +95,8 @@ StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments( // Check that a given module satisfies various constraints before trying to // execute it. -Status VerifyHloModule(HloModule* const module, - bool allow_mixed_precision = false); +Status VerifyHloModule(HloModule* const module, bool layout_sensitive, + bool allow_mixed_precision); } // namespace xla diff --git a/tensorflow/compiler/xla/tests/token_hlo_test.cc b/tensorflow/compiler/xla/tests/token_hlo_test.cc index 2bdbd08309..c7eb9e2dbe 100644 --- a/tensorflow/compiler/xla/tests/token_hlo_test.cc +++ b/tensorflow/compiler/xla/tests/token_hlo_test.cc @@ -15,11 +15,10 @@ limitations under the License. 
#include <array> +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -67,7 +66,10 @@ XLA_TEST_F(TokenHloTest, InvalidTokenShapedEntryParameter) { HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(42))); module->AddEntryComputation(builder.Build()); - Status status = HloVerifier().Run(module.get()).status(); + Status status = + HloVerifier(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false) + .Run(module.get()) + .status(); ASSERT_IS_NOT_OK(status); EXPECT_THAT( status.error_message(), @@ -84,7 +86,10 @@ XLA_TEST_F(TokenHloTest, InvalidTupleTokenShapedEntryParameter) { "param")); module->AddEntryComputation(builder.Build()); - Status status = HloVerifier().Run(module.get()).status(); + Status status = + HloVerifier(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false) + .Run(module.get()) + .status(); ASSERT_IS_NOT_OK(status); EXPECT_THAT( status.error_message(), @@ -101,7 +106,10 @@ XLA_TEST_F(TokenHloTest, InvalidOperandToTokenInstruction) { HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(123))); module->AddEntryComputation(builder.Build()); - Status status = HloVerifier().Run(module.get()).status(); + Status status = + HloVerifier(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false) + .Run(module.get()) + .status(); ASSERT_IS_NOT_OK(status); EXPECT_THAT(status.error_message(), ::testing::HasSubstr( diff --git a/tensorflow/compiler/xla/tests/unary_op_test.cc b/tensorflow/compiler/xla/tests/unary_op_test.cc index 20ae68ab74..8f80a9f3e4 100644 --- a/tensorflow/compiler/xla/tests/unary_op_test.cc +++ b/tensorflow/compiler/xla/tests/unary_op_test.cc @@ -190,25 +190,6 @@ 
XLA_TEST_F(UnaryOpTest, SignAbsTestR1) { SignAbsTestHelper<complex64>(); } -XLA_TEST_F(UnaryOpTest, UnsignedAbsTestR1) { - XlaBuilder builder(TestName()); - auto arg = ConstantR1<unsigned int>( - &builder, {2, 25, 0, 123, std::numeric_limits<unsigned int>::max()}); - Abs(arg); - - ComputeAndCompareR1<unsigned int>( - &builder, {2, 25, 0, 123, std::numeric_limits<unsigned int>::max()}, {}); -} - -XLA_TEST_F(UnaryOpTest, UnsignedSignTestR1) { - XlaBuilder builder(TestName()); - auto arg = ConstantR1<unsigned int>( - &builder, {2, 25, 0, 123, std::numeric_limits<unsigned int>::max()}); - Sign(arg); - - ComputeAndCompareR1<unsigned int>(&builder, {1, 1, 0, 1, 1}, {}); -} - XLA_TEST_F(UnaryOpTest, SignAbsTestR2) { XlaBuilder builder(TestName()); auto arg = ConstantR2<float>(&builder, {{1.0, -2.0}, {-3.0, 4.0}}); diff --git a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc index e12e095ecd..6a7ddd9b55 100644 --- a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc +++ b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc @@ -17,6 +17,9 @@ limitations under the License. #include <vector> #include "absl/algorithm/container.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_split.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" @@ -30,7 +33,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/regexp.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -82,8 +84,7 @@ struct ParsedProfileOutputLine { Status ParseOneProfileOutputLine( const string& line, bool expect_hlo, gtl::FlatMap<string, ParsedProfileOutputLine>* parsed_results, - tensorflow::gtl::ArraySlice<tensorflow::StringPiece> opcodes_to_ignore = - {}) { + tensorflow::gtl::ArraySlice<absl::string_view> opcodes_to_ignore = {}) { string separator = "[^:]*:: +"; string match_percentage = R"(\d+\.\d*% +\d+Σ)"; string match_cycles = R"((\d+) cycles +\( *()" + match_percentage + R"()\))"; @@ -100,7 +101,7 @@ Status ParseOneProfileOutputLine( string match_opcode = expect_hlo ? "%[^=]+= [^ ]+ ([^(]+)\\(.*" : "(\\[total\\])"; - string regexp_pattern = tensorflow::strings::StrCat( + string regexp_pattern = absl::StrCat( " +", match_cycles, separator, match_usecs, separator, match_flops, separator, match_trops, separator, match_bytes_per_sec, separator, match_bytes_per_cycle, separator, match_opcode); @@ -205,7 +206,7 @@ XLA_TEST_F(HloProfileTest, ProfileSingleComputation) { rhs_shape); std::vector<string> profile_output_lines = - tensorflow::str_util::Split(profile_output, '\n'); + absl::StrSplit(profile_output, '\n'); gtl::FlatMap<string, ParsedProfileOutputLine> parsed_profile_lines; @@ -292,22 +293,20 @@ XLA_TEST_F(HloProfileTest, ProfileWhileComputation) { matrix_shape); std::vector<string> profile_output_lines = - tensorflow::str_util::Split(profile_output, '\n'); + absl::StrSplit(profile_output, '\n'); auto while_body_profile_start = - absl::c_find_if(profile_output_lines, [](tensorflow::StringPiece s) { - return tensorflow::str_util::StartsWith(s, - "Execution profile for body"); + absl::c_find_if(profile_output_lines, 
[](absl::string_view s) { + return absl::StartsWith(s, "Execution profile for body"); }); ASSERT_NE(while_body_profile_start, profile_output_lines.cend()); - auto while_body_profile_end = - std::find_if(while_body_profile_start, profile_output_lines.end(), - [](tensorflow::StringPiece s) { - return tensorflow::str_util::StartsWith( - s, "********** microseconds report **********"); - }); + auto while_body_profile_end = std::find_if( + while_body_profile_start, profile_output_lines.end(), + [](absl::string_view s) { + return absl::StartsWith(s, "********** microseconds report **********"); + }); // We emit a blank line before the "********** microseconds report **********" // line. diff --git a/tensorflow/compiler/xla/tests/xla_internal_test_main.cc b/tensorflow/compiler/xla/tests/xla_internal_test_main.cc index a075195618..15603619b6 100644 --- a/tensorflow/compiler/xla/tests/xla_internal_test_main.cc +++ b/tensorflow/compiler/xla/tests/xla_internal_test_main.cc @@ -13,9 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "absl/strings/match.h" +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test_benchmark.h" @@ -32,16 +32,14 @@ GTEST_API_ int main(int argc, char** argv) { // If the --benchmarks flag is passed in then only run the benchmarks, not the // tests. 
for (int i = 1; i < argc; i++) { - tensorflow::StringPiece arg(argv[i]); - if (arg == "--benchmarks" || - tensorflow::str_util::StartsWith(arg, "--benchmarks=")) { + absl::string_view arg(argv[i]); + if (arg == "--benchmarks" || absl::StartsWith(arg, "--benchmarks=")) { const char* pattern = nullptr; - if (tensorflow::str_util::StartsWith(arg, "--benchmarks=")) { + if (absl::StartsWith(arg, "--benchmarks=")) { pattern = argv[i] + strlen("--benchmarks="); } else { // Handle flag of the form '--benchmarks foo' (no '='). - if (i + 1 >= argc || - tensorflow::str_util::StartsWith(argv[i + 1], "--")) { + if (i + 1 >= argc || absl::StartsWith(argv[i + 1], "--")) { LOG(ERROR) << "--benchmarks flag requires an argument."; return 2; } diff --git a/tensorflow/compiler/xla/text_literal_reader.cc b/tensorflow/compiler/xla/text_literal_reader.cc index 7de2c39b38..9835e3d803 100644 --- a/tensorflow/compiler/xla/text_literal_reader.cc +++ b/tensorflow/compiler/xla/text_literal_reader.cc @@ -21,24 +21,27 @@ limitations under the License. 
#include <vector> #include "absl/memory/memory.h" +#include "absl/strings/match.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "absl/strings/strip.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/io/buffered_inputstream.h" #include "tensorflow/core/lib/io/random_inputstream.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/types.h" namespace xla { StatusOr<std::unique_ptr<Literal>> TextLiteralReader::ReadPath( - tensorflow::StringPiece path) { - CHECK(!tensorflow::str_util::EndsWith(path, ".gz")) + absl::string_view path) { + CHECK(!absl::EndsWith(path, ".gz")) << "TextLiteralReader no longer supports reading .gz files"; std::unique_ptr<tensorflow::RandomAccessFile> file; Status s = @@ -54,33 +57,6 @@ StatusOr<std::unique_ptr<Literal>> TextLiteralReader::ReadPath( TextLiteralReader::TextLiteralReader(tensorflow::RandomAccessFile* file) : file_(file) {} -namespace { -// This is an optimized version of tensorflow::str_util::Split which uses -// StringPiece for the delimited strings and uses an out parameter for the -// result to avoid vector creation/destruction. -void SplitByDelimToStringPieces(tensorflow::StringPiece text, char delim, - std::vector<tensorflow::StringPiece>* result) { - result->clear(); - - if (text.empty()) { - return; - } - - // The following loop is a little strange: its bound is text.size() + 1 - // instead of the more typical text.size(). - // The final iteration of the loop (when i is equal to text.size()) handles - // the trailing token. 
- size_t token_start = 0; - for (size_t i = 0; i < text.size() + 1; i++) { - if (i == text.size() || text[i] == delim) { - tensorflow::StringPiece token(text.data() + token_start, i - token_start); - result->push_back(token); - token_start = i + 1; - } - } -} -} // namespace - StatusOr<std::unique_ptr<Literal>> TextLiteralReader::ReadAllLines() { tensorflow::io::RandomAccessInputStream stream(file_.get()); tensorflow::io::BufferedInputStream buf(&stream, 65536); @@ -90,11 +66,7 @@ StatusOr<std::unique_ptr<Literal>> TextLiteralReader::ReadAllLines() { return s; } - tensorflow::StringPiece sp(shape_string); - if (tensorflow::str_util::RemoveWhitespaceContext(&sp) > 0) { - string tmp = std::string(sp); - shape_string = tmp; - } + absl::StripAsciiWhitespace(&shape_string); TF_ASSIGN_OR_RETURN(Shape shape, ShapeUtil::ParseShapeString(shape_string)); if (shape.element_type() != F32) { return Unimplemented( @@ -105,35 +77,33 @@ StatusOr<std::unique_ptr<Literal>> TextLiteralReader::ReadAllLines() { auto result = absl::make_unique<Literal>(shape); const float fill = std::numeric_limits<float>::quiet_NaN(); result->PopulateWithValue<float>(fill); - std::vector<tensorflow::StringPiece> pieces; - std::vector<tensorflow::StringPiece> coordinates; + std::vector<absl::string_view> pieces; + std::vector<absl::string_view> coordinates; std::vector<int64> coordinate_values; string line; while (buf.ReadLine(&line).ok()) { - SplitByDelimToStringPieces(line, ':', &pieces); - tensorflow::StringPiece coordinates_string = pieces[0]; - tensorflow::StringPiece value_string = pieces[1]; - tensorflow::str_util::RemoveWhitespaceContext(&coordinates_string); - tensorflow::str_util::RemoveWhitespaceContext(&value_string); - if (!tensorflow::str_util::ConsumePrefix(&coordinates_string, "(")) { + pieces = absl::StrSplit(line, ':'); + absl::string_view coordinates_string = + absl::StripAsciiWhitespace(pieces[0]); + absl::string_view value_string = absl::StripAsciiWhitespace(pieces[1]); + if 
(!absl::ConsumePrefix(&coordinates_string, "(")) { return InvalidArgument( "expected '(' at the beginning of coordinates: \"%s\"", line.c_str()); } - if (!tensorflow::str_util::ConsumeSuffix(&coordinates_string, ")")) { + if (!absl::ConsumeSuffix(&coordinates_string, ")")) { return InvalidArgument("expected ')' at the end of coordinates: \"%s\"", line.c_str()); } float value; - if (!tensorflow::strings::safe_strtof(std::string(value_string).c_str(), - &value)) { + if (!absl::SimpleAtof(absl::string_view(value_string), &value)) { return InvalidArgument("could not parse value as float: \"%s\"", - std::string(value_string).c_str()); + string(value_string).c_str()); } - SplitByDelimToStringPieces(coordinates_string, ',', &coordinates); + coordinates = absl::StrSplit(coordinates_string, ','); coordinate_values.clear(); - for (tensorflow::StringPiece piece : coordinates) { + for (absl::string_view piece : coordinates) { int64 coordinate_value; - if (!tensorflow::strings::safe_strto64(piece, &coordinate_value)) { + if (!absl::SimpleAtoi(piece, &coordinate_value)) { return InvalidArgument( "could not parse coordinate member as int64: \"%s\"", std::string(piece).c_str()); diff --git a/tensorflow/compiler/xla/text_literal_reader.h b/tensorflow/compiler/xla/text_literal_reader.h index 708e8c80d8..b265640802 100644 --- a/tensorflow/compiler/xla/text_literal_reader.h +++ b/tensorflow/compiler/xla/text_literal_reader.h @@ -18,11 +18,11 @@ limitations under the License. 
#include <memory> +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/macros.h" @@ -41,8 +41,7 @@ class TextLiteralReader { public: // See class comment -- reads a file in its entirety (there must be only one // literal in the text file path provided). - static StatusOr<std::unique_ptr<Literal>> ReadPath( - tensorflow::StringPiece path); + static StatusOr<std::unique_ptr<Literal>> ReadPath(absl::string_view path); private: // Ownership of file is transferred. diff --git a/tensorflow/compiler/xla/text_literal_writer.cc b/tensorflow/compiler/xla/text_literal_writer.cc index 24e0784741..00147015a6 100644 --- a/tensorflow/compiler/xla/text_literal_writer.cc +++ b/tensorflow/compiler/xla/text_literal_writer.cc @@ -17,23 +17,23 @@ limitations under the License. 
#include <string> +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/types.h" namespace xla { -/* static */ Status TextLiteralWriter::WriteToPath( - const Literal& literal, tensorflow::StringPiece path) { +/* static */ Status TextLiteralWriter::WriteToPath(const Literal& literal, + absl::string_view path) { std::unique_ptr<tensorflow::WritableFile> f; - auto s = tensorflow::Env::Default()->NewWritableFile(std::string(path), &f); + auto s = tensorflow::Env::Default()->NewWritableFile(string(path), &f); if (!s.ok()) { return s; } @@ -51,11 +51,10 @@ namespace xla { if (!status.ok()) { return; } - string coordinates = tensorflow::strings::StrCat( - "(", tensorflow::str_util::Join(indices, ", "), ")"); + string coordinates = + absl::StrCat("(", absl::StrJoin(indices, ", "), ")"); - status = f_ptr->Append( - tensorflow::strings::StrCat(coordinates, ": ", value, "\n")); + status = f_ptr->Append(absl::StrCat(coordinates, ": ", value, "\n")); }); auto ignored = f->Close(); return status; diff --git a/tensorflow/compiler/xla/text_literal_writer.h b/tensorflow/compiler/xla/text_literal_writer.h index 159ac1b7e1..34de8572d6 100644 --- a/tensorflow/compiler/xla/text_literal_writer.h +++ b/tensorflow/compiler/xla/text_literal_writer.h @@ -16,11 +16,11 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_XLA_TEXT_LITERAL_WRITER_H_ #define TENSORFLOW_COMPILER_XLA_TEXT_LITERAL_WRITER_H_ +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/macros.h" namespace xla { @@ -37,8 +37,7 @@ namespace xla { // This should be readable by xla::TextLiteralReader. class TextLiteralWriter { public: - static Status WriteToPath(const Literal& literal, - tensorflow::StringPiece path); + static Status WriteToPath(const Literal& literal, absl::string_view path); private: TF_DISALLOW_COPY_AND_ASSIGN(TextLiteralWriter); diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD index 40d28a57bf..1e45588148 100644 --- a/tensorflow/compiler/xla/tools/BUILD +++ b/tensorflow/compiler/xla/tools/BUILD @@ -24,6 +24,7 @@ tf_cc_binary( "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/strings", ], ) @@ -191,6 +192,7 @@ tf_cc_binary( "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service:interpreter_plugin", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc index f0af0580c1..7aedd1da98 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include <memory> #include <string> +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" @@ -30,7 +31,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/init_main.h" @@ -44,10 +44,9 @@ class OperationDumper : public DfsHloVisitorWithDefault { explicit OperationDumper(const string& path) : path_(path) {} Status DefaultAction(HloInstruction* hlo) override { - string params = tensorflow::str_util::Join( + string params = absl::StrJoin( hlo->operands(), ", ", [](string* out, const HloInstruction* operand) { - tensorflow::strings::StrAppend( - out, ShapeUtil::HumanString(operand->shape())); + absl::StrAppend(out, ShapeUtil::HumanString(operand->shape())); }); // Spit `op_name(params...) -> result_type :: path` to stdout. std::cout << tensorflow::strings::Printf( diff --git a/tensorflow/compiler/xla/tools/hex_floats_to_packed_literal.cc b/tensorflow/compiler/xla/tools/hex_floats_to_packed_literal.cc index eb7bff053b..75b63c3b84 100644 --- a/tensorflow/compiler/xla/tools/hex_floats_to_packed_literal.cc +++ b/tensorflow/compiler/xla/tools/hex_floats_to_packed_literal.cc @@ -17,10 +17,10 @@ limitations under the License. 
#include <string> #include <vector> +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/io/buffered_inputstream.h" #include "tensorflow/core/lib/io/random_inputstream.h" #include "tensorflow/core/platform/env.h" @@ -67,7 +67,7 @@ int main(int argc, char** argv) { floats.push_back(value); } - tensorflow::StringPiece content( + tensorflow::StringPiece content( // non-absl ok tensorflow::bit_cast<const char*>(floats.data()), floats.size() * sizeof(float)); TF_CHECK_OK(tensorflow::WriteStringToFile(tensorflow::Env::Default(), diff --git a/tensorflow/compiler/xla/util.cc b/tensorflow/compiler/xla/util.cc index e43498e381..85f05b7b8d 100644 --- a/tensorflow/compiler/xla/util.cc +++ b/tensorflow/compiler/xla/util.cc @@ -18,11 +18,13 @@ limitations under the License. #include <stdarg.h> #include <numeric> +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/strings/numbers.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/mutex.h" @@ -54,16 +56,16 @@ ScopedLoggingTimer::~ScopedLoggingTimer() { } } -Status AddStatus(Status prior, tensorflow::StringPiece context) { +Status AddStatus(Status prior, absl::string_view context) { CHECK(!prior.ok()); - return Status{prior.code(), tensorflow::strings::StrCat( - context, ": ", prior.error_message())}; + return Status{prior.code(), + absl::StrCat(context, ": ", prior.error_message())}; } -Status AppendStatus(Status prior, tensorflow::StringPiece context) 
{ +Status AppendStatus(Status prior, absl::string_view context) { CHECK(!prior.ok()); - return Status{prior.code(), tensorflow::strings::StrCat(prior.error_message(), - ": ", context)}; + return Status{prior.code(), + absl::StrCat(prior.error_message(), ": ", context)}; } // Implementation note: we can't common these out (without using macros) because @@ -146,16 +148,13 @@ Status Unavailable(const char* format, ...) { return WithLogBacktrace(tensorflow::errors::Unavailable(message)); } -string Reindent(tensorflow::StringPiece original, - const tensorflow::StringPiece indentation) { - std::vector<string> pieces = tensorflow::str_util::Split( - tensorflow::StringPiece(original.data(), original.size()), '\n'); - return tensorflow::str_util::Join( - pieces, "\n", [indentation](string* out, string s) { - tensorflow::StringPiece piece(s); - tensorflow::str_util::RemoveWhitespaceContext(&piece); - tensorflow::strings::StrAppend(out, indentation, piece); - }); +string Reindent(absl::string_view original, + const absl::string_view indentation) { + std::vector<string> pieces = + absl::StrSplit(absl::string_view(original.data(), original.size()), '\n'); + return absl::StrJoin(pieces, "\n", [indentation](string* out, string s) { + absl::StrAppend(out, indentation, absl::StripAsciiWhitespace(s)); + }); } bool IsPermutation(tensorflow::gtl::ArraySlice<int64> permutation, int64 rank) { @@ -234,20 +233,20 @@ bool HasInteriorPadding(const PaddingConfig& config) { namespace { string HumanReadableNumOps(double flops, double nanoseconds, - tensorflow::StringPiece op_prefix) { + absl::string_view op_prefix) { if (nanoseconds == 0) { - return tensorflow::strings::StrCat("NaN ", op_prefix, "OP/s"); + return absl::StrCat("NaN ", op_prefix, "OP/s"); } double nano_flops = flops / nanoseconds; string throughput = tensorflow::strings::HumanReadableNum( static_cast<int64>(nano_flops * 1e9)); - tensorflow::StringPiece sp(throughput); + absl::string_view sp(throughput); // Use the more common 
"G(FLOPS)", rather than "B(FLOPS)" - if (tensorflow::str_util::EndsWith(sp, "B") || // Ends in 'B', ignoring case - tensorflow::str_util::EndsWith(sp, "b")) { + if (absl::EndsWith(sp, "B") || // Ends in 'B', ignoring case + absl::EndsWith(sp, "b")) { *throughput.rbegin() = 'G'; } - throughput += tensorflow::strings::StrCat(op_prefix, "OP/s"); + throughput += absl::StrCat(op_prefix, "OP/s"); return throughput; } } // namespace @@ -260,8 +259,7 @@ string HumanReadableNumTranscendentalOps(double trops, double nanoseconds) { return HumanReadableNumOps(trops, nanoseconds, "TR"); } -void LogLines(int sev, tensorflow::StringPiece text, const char* fname, - int lineno) { +void LogLines(int sev, absl::string_view text, const char* fname, int lineno) { const int orig_sev = sev; if (sev == tensorflow::FATAL) { sev = tensorflow::ERROR; @@ -275,7 +273,7 @@ void LogLines(int sev, tensorflow::StringPiece text, const char* fname, size_t cur = 0; while (cur < text.size()) { size_t eol = text.find('\n', cur); - if (eol == tensorflow::StringPiece::npos) { + if (eol == absl::string_view::npos) { eol = text.size(); } auto msg = text.substr(cur, eol - cur); diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h index efeafbc53a..671ef17f36 100644 --- a/tensorflow/compiler/xla/util.h +++ b/tensorflow/compiler/xla/util.h @@ -26,16 +26,16 @@ limitations under the License. 
#include "absl/algorithm/container.h" #include "absl/container/inlined_vector.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/math/math_util.h" #include "tensorflow/core/lib/strings/numbers.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/protobuf.h" @@ -202,8 +202,8 @@ void StridedCopy(tensorflow::gtl::MutableArraySlice<D> dest, int64 dest_base, // Adds some context information to the error message in a // Status. This is useful as Statuses are // propagated upwards. -Status AddStatus(Status prior, tensorflow::StringPiece context); -Status AppendStatus(Status prior, tensorflow::StringPiece context); +Status AddStatus(Status prior, absl::string_view context); +Status AppendStatus(Status prior, absl::string_view context); // Status error shorthands -- printfs the arguments to be // used as an error message and returns a status in the canonical @@ -222,26 +222,26 @@ Status InvalidArgumentV(const char* format, va_list args); template <typename... Args> Status InvalidArgumentStrCat(Args&&... concat) { - return InvalidArgument( - "%s", tensorflow::strings::StrCat(std::forward<Args>(concat)...).c_str()); + return InvalidArgument("%s", + absl::StrCat(std::forward<Args>(concat)...).c_str()); } template <typename... Args> Status UnimplementedStrCat(Args&&... 
concat) { - return Unimplemented( - "%s", tensorflow::strings::StrCat(std::forward<Args>(concat)...).c_str()); + return Unimplemented("%s", + absl::StrCat(std::forward<Args>(concat)...).c_str()); } template <typename... Args> Status InternalErrorStrCat(Args&&... concat) { - return InternalError( - "%s", tensorflow::strings::StrCat(std::forward<Args>(concat)...).c_str()); + return InternalError("%s", + absl::StrCat(std::forward<Args>(concat)...).c_str()); } template <typename... Args> Status ResourceExhaustedStrCat(Args&&... concat) { - return ResourceExhausted( - "%s", tensorflow::strings::StrCat(std::forward<Args>(concat)...).c_str()); + return ResourceExhausted("%s", + absl::StrCat(std::forward<Args>(concat)...).c_str()); } // Splits the lines of the original, replaces leading whitespace with the prefix @@ -250,8 +250,7 @@ Status ResourceExhaustedStrCat(Args&&... concat) { // // Note: even different amounts of leading whitespace on different lines will be // uniformly replaced with "indentation". -string Reindent(tensorflow::StringPiece original, - tensorflow::StringPiece indentation); +string Reindent(absl::string_view original, absl::string_view indentation); // Checks whether permutation is a permutation of the [0, rank) integer range. bool IsPermutation(tensorflow::gtl::ArraySlice<int64> permutation, int64 rank); @@ -313,7 +312,7 @@ string CommaSeparatedString(const Container& c, const char* prefix = "", string comma_separated = prefix; const char* separator = ""; for (const auto& entry : c) { - tensorflow::strings::StrAppend(&comma_separated, separator, entry); + absl::StrAppend(&comma_separated, separator, entry); separator = ", "; } comma_separated += suffix; @@ -395,8 +394,7 @@ string HumanReadableNumTranscendentalOps(double trops, double nanoseconds); // Split the text into multiple lines and log each line with the given // severity, filename, and line number. 
-void LogLines(int sev, tensorflow::StringPiece text, const char* fname, - int lineno); +void LogLines(int sev, absl::string_view text, const char* fname, int lineno); template <typename T> inline bool IsPowerOfTwo(T x) { diff --git a/tensorflow/compiler/xla/window_util.cc b/tensorflow/compiler/xla/window_util.cc index f11123ca24..44fb1bdc38 100644 --- a/tensorflow/compiler/xla/window_util.cc +++ b/tensorflow/compiler/xla/window_util.cc @@ -17,10 +17,9 @@ limitations under the License. #include <vector> +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/stringprintf.h" namespace xla { @@ -49,8 +48,8 @@ PaddingConfig MakeSymmetricPadding(tensorflow::gtl::ArraySlice<int64> sizes) { } /* static */ string ToString(const WindowDimension& dim) { - using tensorflow::strings::StrAppend; - using tensorflow::strings::StrCat; + using absl::StrAppend; + using absl::StrCat; string str = StrCat("(size=", dim.size()); if (dim.stride() != 1) { StrAppend(&str, ",stride=", dim.stride()); @@ -75,8 +74,8 @@ PaddingConfig MakeSymmetricPadding(tensorflow::gtl::ArraySlice<int64> sizes) { } string ToString(const Window& window) { - using tensorflow::strings::StrAppend; - using tensorflow::strings::StrCat; + using absl::StrAppend; + using absl::StrCat; string str; const auto add_field = diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index ab4328d459..66983801bf 100644 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -181,6 +181,7 @@ cc_library( "//tensorflow/contrib/boosted_trees:boosted_trees_ops_op_lib", "//tensorflow/contrib/coder:all_ops", "//tensorflow/contrib/data:dataset_ops_op_lib", + "//tensorflow/contrib/data:indexed_dataset_ops_op_lib", "//tensorflow/contrib/factorization:all_ops", "//tensorflow/contrib/framework:all_ops", 
"//tensorflow/contrib/hadoop:dataset_ops_op_lib", diff --git a/tensorflow/contrib/autograph/converters/control_flow.py b/tensorflow/contrib/autograph/converters/control_flow.py index f7dd3183b0..8d314250a0 100644 --- a/tensorflow/contrib/autograph/converters/control_flow.py +++ b/tensorflow/contrib/autograph/converters/control_flow.py @@ -310,7 +310,9 @@ class ControlFlowTransformer(converter.Base): template = """ def extra_test_name(state_ssf): return extra_test_expr - def body_name(iterate, state_ssf): + def body_name(loop_vars, state_ssf): + # Workaround for PEP-3113 + iterate = loop_vars body return state_ssf, state_ast_tuple = ag__.for_stmt( diff --git a/tensorflow/contrib/autograph/converters/control_flow_test.py b/tensorflow/contrib/autograph/converters/control_flow_test.py index 02bc00dbc8..2a6f3cb395 100644 --- a/tensorflow/contrib/autograph/converters/control_flow_test.py +++ b/tensorflow/contrib/autograph/converters/control_flow_test.py @@ -217,5 +217,13 @@ class ControlFlowTest(converter_testing.TestCase): with self.assertRaises(transformer.AutographParseError): control_flow.transform(node, ctx) + def test_for_tuple_unpacking(self): + def test_fn(x_list): + z = tf.constant(0) # pylint:disable=undefined-variable + for i, x in enumerate(x_list): + z = z + x + i + return z + + self.assertTransformedResult(test_fn, [3, 3], 7) if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/autograph/pyct/testing/BUILD b/tensorflow/contrib/autograph/pyct/testing/BUILD index 9ef1ac9663..29a92444bb 100644 --- a/tensorflow/contrib/autograph/pyct/testing/BUILD +++ b/tensorflow/contrib/autograph/pyct/testing/BUILD @@ -34,8 +34,10 @@ py_test( srcs = ["codegen_test.py"], srcs_version = "PY2AND3", tags = [ + "manual", "no_windows", "nomsan", + "notap", ], deps = [ ":testing", diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py index 5821d51bca..5e6c1520a2 100644 --- a/tensorflow/contrib/data/__init__.py +++ 
b/tensorflow/contrib/data/__init__.py @@ -25,6 +25,7 @@ See [Importing Data](https://tensorflow.org/guide/datasets) for an overview. @@Counter @@CheckpointInputPipelineHook @@CsvDataset +@@LMDBDataset @@RandomDataset @@Reducer @@SqlDataset @@ -49,6 +50,7 @@ See [Importing Data](https://tensorflow.org/guide/datasets) for an overview. @@map_and_batch @@padded_batch_and_drop_remainder @@parallel_interleave +@@parse_example_dataset @@prefetch_to_device @@read_batch_features @@rejection_resample @@ -89,10 +91,12 @@ from tensorflow.contrib.data.python.ops.interleave_ops import sample_from_datase from tensorflow.contrib.data.python.ops.interleave_ops import sloppy_interleave from tensorflow.contrib.data.python.ops.iterator_ops import CheckpointInputPipelineHook from tensorflow.contrib.data.python.ops.iterator_ops import make_saveable_from_iterator +from tensorflow.contrib.data.python.ops.parsing_ops import parse_example_dataset from tensorflow.contrib.data.python.ops.prefetching_ops import copy_to_device from tensorflow.contrib.data.python.ops.prefetching_ops import prefetch_to_device from tensorflow.contrib.data.python.ops.random_ops import RandomDataset from tensorflow.contrib.data.python.ops.readers import CsvDataset +from tensorflow.contrib.data.python.ops.readers import LMDBDataset from tensorflow.contrib.data.python.ops.readers import make_batched_features_dataset from tensorflow.contrib.data.python.ops.readers import make_csv_dataset from tensorflow.contrib.data.python.ops.readers import read_batch_features diff --git a/tensorflow/contrib/data/kernels/BUILD b/tensorflow/contrib/data/kernels/BUILD index 4d1603a561..ec6cb37193 100644 --- a/tensorflow/contrib/data/kernels/BUILD +++ b/tensorflow/contrib/data/kernels/BUILD @@ -77,6 +77,17 @@ cc_library( ) cc_library( + name = "lmdb_dataset_op", + srcs = ["lmdb_dataset_op.cc"], + deps = [ + "//tensorflow/core:framework_headers_lib", + "//third_party/eigen3", + "@lmdb", + "@protobuf_archive//:protobuf_headers", + ], +) + 
+cc_library( name = "threadpool_dataset_op", srcs = ["threadpool_dataset_op.cc"], deps = [ @@ -117,6 +128,7 @@ cc_library( ":directed_interleave_dataset_op", ":ignore_errors_dataset_op", ":indexed_dataset", + ":lmdb_dataset_op", ":prefetching_kernels", ":threadpool_dataset_op", ":unique_dataset_op", diff --git a/tensorflow/contrib/data/kernels/lmdb_dataset_op.cc b/tensorflow/contrib/data/kernels/lmdb_dataset_op.cc new file mode 100644 index 0000000000..80f39992fb --- /dev/null +++ b/tensorflow/contrib/data/kernels/lmdb_dataset_op.cc @@ -0,0 +1,215 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include <sys/stat.h> + +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/lib/io/buffered_inputstream.h" +#include "tensorflow/core/platform/file_system.h" + +#include "lmdb.h" // NOLINT(build/include) + +namespace tensorflow { +namespace { + +class LMDBDatasetOp : public DatasetOpKernel { + public: + using DatasetOpKernel::DatasetOpKernel; + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { + const Tensor* filenames_tensor; + OP_REQUIRES_OK(ctx, ctx->input("filenames", &filenames_tensor)); + OP_REQUIRES( + ctx, filenames_tensor->dims() <= 1, + errors::InvalidArgument("`filenames` must be a scalar or a vector.")); + + std::vector<string> filenames; + filenames.reserve(filenames_tensor->NumElements()); + for (int i = 0; i < filenames_tensor->NumElements(); ++i) { + filenames.push_back(filenames_tensor->flat<string>()(i)); + } + + *output = new Dataset(ctx, filenames); + } + + private: + class Dataset : public DatasetBase { + public: + Dataset(OpKernelContext* ctx, const std::vector<string>& filenames) + : DatasetBase(DatasetContext(ctx)), filenames_(filenames) {} + + std::unique_ptr<IteratorBase> MakeIteratorInternal( + const string& prefix) const override { + return std::unique_ptr<IteratorBase>( + new Iterator({this, strings::StrCat(prefix, "::LMDB")})); + } + + const DataTypeVector& output_dtypes() const override { + static DataTypeVector* dtypes = + new DataTypeVector({DT_STRING, DT_STRING}); + return *dtypes; + } + + const std::vector<PartialTensorShape>& output_shapes() const override { + static std::vector<PartialTensorShape>* shapes = + new std::vector<PartialTensorShape>({{}, {}}); + return *shapes; + } + + string DebugString() const override { return "LMDBDatasetOp::Dataset"; } + + protected: + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { + Node* filenames 
= nullptr; + TF_RETURN_IF_ERROR(b->AddVector(filenames_, &filenames)); + TF_RETURN_IF_ERROR(b->AddDataset(this, {filenames}, output)); + return Status::OK(); + } + + private: + class Iterator : public DatasetIterator<Dataset> { + public: + explicit Iterator(const Params& params) + : DatasetIterator<Dataset>(params) {} + + Status GetNextInternal(IteratorContext* ctx, + std::vector<Tensor>* out_tensors, + bool* end_of_sequence) override { + mutex_lock l(mu_); + do { + if (mdb_cursor_) { + Tensor key_tensor(ctx->allocator({}), DT_STRING, {}); + key_tensor.scalar<string>()() = string( + static_cast<const char*>(mdb_key_.mv_data), mdb_key_.mv_size); + out_tensors->emplace_back(std::move(key_tensor)); + + Tensor value_tensor(ctx->allocator({}), DT_STRING, {}); + value_tensor.scalar<string>()() = + string(static_cast<const char*>(mdb_value_.mv_data), + mdb_value_.mv_size); + out_tensors->emplace_back(std::move(value_tensor)); + + int val; + val = mdb_cursor_get(mdb_cursor_, &mdb_key_, &mdb_value_, MDB_NEXT); + if (val != MDB_SUCCESS && val != MDB_NOTFOUND) { + return errors::InvalidArgument(mdb_strerror(val)); + } + if (val == MDB_NOTFOUND) { + ResetStreamsLocked(); + ++current_file_index_; + } + *end_of_sequence = false; + return Status::OK(); + } + if (current_file_index_ == dataset()->filenames_.size()) { + *end_of_sequence = true; + return Status::OK(); + } + + TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env())); + } while (true); + } + + protected: + Status SaveInternal(IteratorStateWriter* writer) override { + return errors::Unimplemented( + "Checkpointing is currently not supported for LMDBDataset."); + } + + Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { + return errors::Unimplemented( + "Checkpointing is currently not supported for LMDBDataset."); + } + + private: + Status SetupStreamsLocked(Env* env) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (current_file_index_ >= dataset()->filenames_.size()) { + return 
errors::InvalidArgument( + "current_file_index_:", current_file_index_, + " >= filenames_.size():", dataset()->filenames_.size()); + } + const string& filename = dataset()->filenames_[current_file_index_]; + + int val = mdb_env_create(&mdb_env_); + if (val != MDB_SUCCESS) { + return errors::InvalidArgument(mdb_strerror(val)); + } + int flags = MDB_RDONLY | MDB_NOTLS | MDB_NOLOCK; + + struct stat source_stat; + if (stat(filename.c_str(), &source_stat) == 0 && + (source_stat.st_mode & S_IFREG)) { + flags |= MDB_NOSUBDIR; + } + val = mdb_env_open(mdb_env_, filename.c_str(), flags, 0664); + if (val != MDB_SUCCESS) { + return errors::InvalidArgument(mdb_strerror(val)); + } + val = mdb_txn_begin(mdb_env_, nullptr, MDB_RDONLY, &mdb_txn_); + if (val != MDB_SUCCESS) { + return errors::InvalidArgument(mdb_strerror(val)); + } + val = mdb_dbi_open(mdb_txn_, nullptr, 0, &mdb_dbi_); + if (val != MDB_SUCCESS) { + return errors::InvalidArgument(mdb_strerror(val)); + } + val = mdb_cursor_open(mdb_txn_, mdb_dbi_, &mdb_cursor_); + if (val != MDB_SUCCESS) { + return errors::InvalidArgument(mdb_strerror(val)); + } + val = mdb_cursor_get(mdb_cursor_, &mdb_key_, &mdb_value_, MDB_FIRST); + if (val != MDB_SUCCESS && val != MDB_NOTFOUND) { + return errors::InvalidArgument(mdb_strerror(val)); + } + if (val == MDB_NOTFOUND) { + ResetStreamsLocked(); + } + return Status::OK(); + } + void ResetStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (mdb_env_ != nullptr) { + if (mdb_cursor_) { + mdb_cursor_close(mdb_cursor_); + mdb_cursor_ = nullptr; + } + mdb_dbi_close(mdb_env_, mdb_dbi_); + mdb_txn_abort(mdb_txn_); + mdb_env_close(mdb_env_); + mdb_txn_ = nullptr; + mdb_dbi_ = 0; + mdb_env_ = nullptr; + } + } + mutex mu_; + size_t current_file_index_ GUARDED_BY(mu_) = 0; + MDB_env* mdb_env_ GUARDED_BY(mu_) = nullptr; + MDB_txn* mdb_txn_ GUARDED_BY(mu_) = nullptr; + MDB_dbi mdb_dbi_ GUARDED_BY(mu_) = 0; + MDB_cursor* mdb_cursor_ GUARDED_BY(mu_) = nullptr; + + MDB_val mdb_key_ GUARDED_BY(mu_); + 
MDB_val mdb_value_ GUARDED_BY(mu_); + }; + + const std::vector<string> filenames_; + }; +}; + +REGISTER_KERNEL_BUILDER(Name("LMDBDataset").Device(DEVICE_CPU), LMDBDatasetOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/contrib/data/ops/dataset_ops.cc b/tensorflow/contrib/data/ops/dataset_ops.cc index cc5e250ea1..ae104d55bd 100644 --- a/tensorflow/contrib/data/ops/dataset_ops.cc +++ b/tensorflow/contrib/data/ops/dataset_ops.cc @@ -266,4 +266,13 @@ REGISTER_OP("AssertNextDataset") return shape_inference::ScalarShape(c); }); +REGISTER_OP("LMDBDataset") + .Input("filenames: string") + .Output("handle: variant") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked + // stateful to inhibit constant folding. + .SetShapeFn(shape_inference::ScalarShape); + } // namespace tensorflow diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD index 803a3b33fa..9e2697534c 100644 --- a/tensorflow/contrib/data/python/kernel_tests/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/BUILD @@ -4,7 +4,8 @@ licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) -load("//tensorflow:tensorflow.bzl", "cuda_py_test", "py_test") +load("//tensorflow:tensorflow.bzl", "cuda_py_test") +load("//tensorflow:tensorflow.bzl", "py_test") py_test( name = "batch_dataset_op_test", @@ -194,6 +195,31 @@ py_test( ) py_test( + name = "lmdb_dataset_op_test", + size = "medium", + srcs = ["lmdb_dataset_op_test.py"], + data = ["//tensorflow/core:lmdb_testdata"], + srcs_version = "PY2AND3", + tags = [ + "no_pip", + "no_windows", + ], + deps = [ + "//tensorflow/contrib/data/python/ops:readers", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", + "//tensorflow/python:parsing_ops", + 
"//tensorflow/python:platform", + "//tensorflow/python:platform_test", + "//tensorflow/python:session", + "//third_party/py/numpy", + ], +) + +py_test( name = "map_dataset_op_test", size = "medium", srcs = ["map_dataset_op_test.py"], diff --git a/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py new file mode 100644 index 0000000000..7bc582ebaa --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py @@ -0,0 +1,66 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for LMDBDatasetOp.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import shutil + +from tensorflow.contrib.data.python.ops import readers +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.platform import test +from tensorflow.python.util import compat + +prefix_path = "tensorflow/core/lib" + + +class LMDBDatasetTest(test.TestCase): + + def setUp(self): + super(LMDBDatasetTest, self).setUp() + # Copy database out because we need the path to be writable to use locks. 
+ path = os.path.join(prefix_path, "lmdb", "testdata", "data.mdb") + self.db_path = os.path.join(self.get_temp_dir(), "data.mdb") + shutil.copy(path, self.db_path) + + def testReadFromFile(self): + filename = self.db_path + + filenames = constant_op.constant([filename], dtypes.string) + num_repeats = 2 + + dataset = readers.LMDBDataset(filenames).repeat(num_repeats) + iterator = dataset.make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.test_session() as sess: + sess.run(init_op) + for _ in range(num_repeats): # Dataset is repeated. + for i in range(10): # 10 records. + k = compat.as_bytes(str(i)) + v = compat.as_bytes(str(chr(ord("a") + i))) + self.assertEqual((k, v), sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD b/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD index 7b9ea191a4..4881f63ab9 100644 --- a/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD @@ -318,6 +318,19 @@ py_test( ) py_test( + name = "parse_example_dataset_serialization_test", + size = "medium", + srcs = ["parse_example_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/contrib/data/python/kernel_tests:reader_dataset_ops_test_base", + "//tensorflow/python:client_testlib", + ], +) + +py_test( name = "prefetch_dataset_serialization_test", size = "small", srcs = ["prefetch_dataset_serialization_test.py"], diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_serialization_test_base.py b/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_serialization_test_base.py index 9fdbcb66bf..595cecef4d 100644 --- 
a/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_serialization_test_base.py +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_serialization_test_base.py @@ -510,7 +510,6 @@ class DatasetSerializationTestBase(test.TestCase): else: init_op, get_next_op, saver = self._build_graph( ds_fn, sparse_tensors=sparse_tensors) - get_next_op = remove_variants(get_next_op) return init_op, get_next_op, saver for i in range(len(break_points) + 1): @@ -616,29 +615,40 @@ class DatasetSerializationTestBase(test.TestCase): # `get_next` may be a tuple e.g. in TensorSliceDataset. Since Collections # do not support tuples we flatten the tensors and restore the shape in # `_get_iterator_ops_from_collection`. - - # TODO(shivaniagrwal): `output_classes` is a nested structure of classes, - # this base class is specific to current test cases. Update when tests are - # added with `output_classes` as a nested structure with at least one of the - # component being `tf.SparseTensor`. - if (sparse_tensors or - self._get_output_classes(ds_fn) is sparse_tensor.SparseTensor): + if sparse_tensors: # specific for deprecated `from_sparse_tensor_slices`. 
ops.add_to_collection("iterator_ops", get_next.indices) ops.add_to_collection("iterator_ops", get_next.values) ops.add_to_collection("iterator_ops", get_next.dense_shape) - else: - for el in nest.flatten(get_next): - ops.add_to_collection("iterator_ops", el) + return + + get_next_list = nest.flatten(get_next) + for i, output_class in enumerate( + nest.flatten(self._get_output_classes(ds_fn))): + if output_class is sparse_tensor.SparseTensor: + ops.add_to_collection("iterator_ops", get_next_list[i].indices) + ops.add_to_collection("iterator_ops", get_next_list[i].values) + ops.add_to_collection("iterator_ops", get_next_list[i].dense_shape) + else: + ops.add_to_collection("iterator_ops", get_next_list[i]) def _get_iterator_ops_from_collection(self, ds_fn, sparse_tensors=False): all_ops = ops.get_collection("iterator_ops") - if (sparse_tensors or - self._get_output_classes(ds_fn) is sparse_tensor.SparseTensor): + if sparse_tensors: # specific for deprecated `from_sparse_tensor_slices`. init_op, indices, values, dense_shape = all_ops return init_op, sparse_tensor.SparseTensor(indices, values, dense_shape) - else: - return all_ops[0], nest.pack_sequence_as( - self._get_output_types(ds_fn), all_ops[1:]) + get_next_list = [] + i = 1 + for output_class in nest.flatten(self._get_output_classes(ds_fn)): + if output_class is sparse_tensor.SparseTensor: + indices, values, dense_shape = all_ops[i:i + 3] + i += 3 + get_next_list.append( + sparse_tensor.SparseTensor(indices, values, dense_shape)) + else: + get_next_list.append(all_ops[i]) + i += 1 + return all_ops[0], nest.pack_sequence_as( + self._get_output_types(ds_fn), get_next_list) def _get_output_types(self, ds_fn): with ops.Graph().as_default(): diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/parse_example_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/parse_example_dataset_serialization_test.py new file mode 100644 index 0000000000..d3fa84e74c --- 
/dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/parse_example_dataset_serialization_test.py @@ -0,0 +1,50 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the ParseExampleDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.kernel_tests import reader_dataset_ops_test_base +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.platform import test + + +class ParseExampleDatasetSerializationTest( + reader_dataset_ops_test_base.ReadBatchFeaturesTestBase, + dataset_serialization_test_base.DatasetSerializationTestBase): + + def ParseExampleDataset(self, num_repeat, batch_size): + return self.make_batch_feature( + filenames=self.test_filenames, + num_epochs=num_repeat, + batch_size=batch_size, + reader_num_threads=5, + parser_num_threads=10) + + def testSerializationCore(self): + num_repeat = 5 + batch_size = 2 + num_outputs = self._num_records * self._num_files * num_repeat // batch_size + # pylint: disable=g-long-lambda + self.run_core_tests( + lambda: self.ParseExampleDataset( + num_repeat=num_repeat, batch_size=batch_size), + lambda: self.ParseExampleDataset(num_repeat=10, batch_size=4), + 
num_outputs) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD index 0bd5b403e2..4b45cc7e36 100644 --- a/tensorflow/contrib/data/python/ops/BUILD +++ b/tensorflow/contrib/data/python/ops/BUILD @@ -82,7 +82,6 @@ py_library( ":interleave_ops", ":parsing_ops", ":shuffle_ops", - ":stats_ops", "//tensorflow/python:constant_op", "//tensorflow/python:dataset_ops_gen", "//tensorflow/python:dtypes", diff --git a/tensorflow/contrib/data/python/ops/interleave_ops.py b/tensorflow/contrib/data/python/ops/interleave_ops.py index cca9bf6742..54a92ab185 100644 --- a/tensorflow/contrib/data/python/ops/interleave_ops.py +++ b/tensorflow/contrib/data/python/ops/interleave_ops.py @@ -216,25 +216,46 @@ def sample_from_datasets(datasets, weights=None, seed=None): length of the `datasets` element. """ num_datasets = len(datasets) - if weights is None: - weights = dataset_ops.Dataset.from_tensors([1.0] * num_datasets).repeat() - elif not isinstance(weights, dataset_ops.Dataset): - weights = ops.convert_to_tensor(weights, name="weights") - if weights.dtype not in (dtypes.float32, dtypes.float64): - raise TypeError("`weights` must be convertible to a tensor of " - "`tf.float32` or `tf.float64` elements.") - if not weights.shape.is_compatible_with([num_datasets]): - raise ValueError("`weights` must be a vector of length `len(datasets)`.") - weights = dataset_ops.Dataset.from_tensors(weights).repeat() - - # The `stateless_multinomial()` op expects log-probabilities, as opposed to - # weights. 
- logits_ds = weights.map(lambda *p: math_ops.log(p, name="logits")) - def select_dataset(logits, seed): - return array_ops.squeeze( - stateless.stateless_multinomial(logits, 1, seed=seed), axis=[0, 1]) - selector_input = dataset_ops.Dataset.zip( - (logits_ds, random_ops.RandomDataset(seed).batch(2))).map(select_dataset) + if not isinstance(weights, dataset_ops.Dataset): + if weights is None: + # Select inputs with uniform probability. + logits = [[1.0] * num_datasets] + else: + # Use the given `weights` as the probability of choosing the respective + # input. + weights = ops.convert_to_tensor(weights, name="weights") + if weights.dtype not in (dtypes.float32, dtypes.float64): + raise TypeError("`weights` must be convertible to a tensor of " + "`tf.float32` or `tf.float64` elements.") + if not weights.shape.is_compatible_with([num_datasets]): + raise ValueError( + "`weights` must be a vector of length `len(datasets)`.") + + # The `stateless_multinomial()` op expects log-probabilities, as opposed + # to weights. + logits = array_ops.expand_dims(math_ops.log(weights, name="logits"), 0) + + def select_dataset_constant_logits(seed): + return array_ops.squeeze( + stateless.stateless_multinomial(logits, 1, seed=seed), axis=[0, 1]) + + selector_input = random_ops.RandomDataset(seed).batch(2).map( + select_dataset_constant_logits) + else: + # Use each element of the given `weights` dataset as the probability of + # choosing the respective input. + + # The `stateless_multinomial()` op expects log-probabilities, as opposed to + # weights. 
+ logits_ds = weights.map(lambda *p: math_ops.log(p, name="logits")) + + def select_dataset_varying_logits(logits, seed): + return array_ops.squeeze( + stateless.stateless_multinomial(logits, 1, seed=seed), axis=[0, 1]) + + selector_input = dataset_ops.Dataset.zip( + (logits_ds, random_ops.RandomDataset(seed).batch(2) + )).map(select_dataset_varying_logits) return _DirectedInterleaveDataset(selector_input, datasets) diff --git a/tensorflow/contrib/data/python/ops/parsing_ops.py b/tensorflow/contrib/data/python/ops/parsing_ops.py index f868653554..2701605e64 100644 --- a/tensorflow/contrib/data/python/ops/parsing_ops.py +++ b/tensorflow/contrib/data/python/ops/parsing_ops.py @@ -102,8 +102,6 @@ class _ParseExampleDataset(dataset_ops.Dataset): return self._output_classes -# TODO(b/38416882): Properly export in the `tf.contrib.data` API when stable -# or make private / remove. # TODO(b/111553342): add arguments names and example names as well. def parse_example_dataset(features, num_parallel_calls=1): """A transformation that parses `Example` protos into a `dict` of tensors. 
diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py index cafe0a4091..29005859d7 100644 --- a/tensorflow/contrib/data/python/ops/readers.py +++ b/tensorflow/contrib/data/python/ops/readers.py @@ -27,7 +27,6 @@ from tensorflow.contrib.data.python.ops import gen_dataset_ops as contrib_gen_da from tensorflow.contrib.data.python.ops import interleave_ops from tensorflow.contrib.data.python.ops import parsing_ops from tensorflow.contrib.data.python.ops import shuffle_ops -from tensorflow.contrib.data.python.ops import stats_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import readers as core_readers from tensorflow.python.data.util import convert @@ -326,7 +325,6 @@ def make_csv_dataset( shuffle_seed=None, prefetch_buffer_size=1, num_parallel_reads=1, - num_parallel_parser_calls=2, sloppy=False, num_rows_for_inference=100, compression_type=None, @@ -393,8 +391,6 @@ def make_csv_dataset( batches consumed per training step. num_parallel_reads: Number of threads used to read CSV records from files. If >1, the results will be interleaved. - num_parallel_parser_calls: Number of parallel invocations of the CSV parsing - function on CSV records. sloppy: If `True`, reading performance will be improved at the cost of non-deterministic ordering. If `False`, the order of elements produced is deterministic prior to shuffling (elements are still @@ -503,7 +499,7 @@ def make_csv_dataset( # indefinitely, and all batches will be full-sized. 
dataset = dataset.batch(batch_size=batch_size, drop_remainder=num_epochs is None) - dataset = dataset.map(map_fn, num_parallel_calls=num_parallel_parser_calls) + dataset = dataset.map(map_fn) dataset = dataset.prefetch(prefetch_buffer_size) return dataset @@ -972,3 +968,49 @@ class SqlDataset(dataset_ops.Dataset): @property def output_types(self): return self._output_types + + +class LMDBDataset(dataset_ops.Dataset): + """A LMDB Dataset that reads the lmdb file.""" + + def __init__(self, filenames): + """Create a `LMDBDataset`. + + `LMDBDataset` allows a user to read data from a mdb file as + (key value) pairs sequentially. + For example: + ```python + dataset = tf.contrib.lmdb.LMDBDataset("/foo/bar.mdb") + iterator = dataset.make_one_shot_iterator() + next_element = iterator.get_next() + # Prints the (key, value) pairs inside a lmdb file. + while True: + try: + print(sess.run(next_element)) + except tf.errors.OutOfRangeError: + break + ``` + Args: + filenames: A `tf.string` tensor containing one or more filenames. 
+ """ + super(LMDBDataset, self).__init__() + self._filenames = ops.convert_to_tensor( + filenames, dtype=dtypes.string, name="filenames") + + def _as_variant_tensor(self): + return contrib_gen_dataset_ops.lmdb_dataset( + self._filenames, + output_types=nest.flatten(self.output_types), + output_shapes=nest.flatten(self.output_shapes)) + + @property + def output_classes(self): + return ops.Tensor, ops.Tensor + + @property + def output_shapes(self): + return (tensor_shape.TensorShape([]), tensor_shape.TensorShape([])) + + @property + def output_types(self): + return dtypes.string, dtypes.string diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD index bcf9b3c568..8173b5d4ba 100644 --- a/tensorflow/contrib/distribute/python/BUILD +++ b/tensorflow/contrib/distribute/python/BUILD @@ -106,6 +106,38 @@ py_library( "//tensorflow/python:training", "//tensorflow/python:util", "//tensorflow/python/distribute:multi_worker_util", + "//tensorflow/python/eager:context", + ], +) + +cuda_py_test( + name = "parameter_server_strategy_test", + srcs = ["parameter_server_strategy_test.py"], + additional_deps = [ + ":combinations", + ":multi_worker_test_base", + ":parameter_server_strategy", + ":values", + "@absl_py//absl/testing:parameterized", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:gradients", + "//tensorflow/python:layers", + "//tensorflow/python:session", + "//tensorflow/python:training", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + "//tensorflow/python/distribute:multi_worker_util", + "//tensorflow/python/eager:context", + "//tensorflow/python/estimator:estimator_py", + ], + tags = [ + "multi_and_single_gpu", + "no_pip", ], ) @@ -239,35 +271,6 @@ py_test( ], ) -py_test( - name = 
"parameter_server_strategy_test", - srcs = ["parameter_server_strategy_test.py"], - srcs_version = "PY2AND3", - tags = [ - "no_pip", - ], - deps = [ - ":combinations", - ":multi_worker_test_base", - ":parameter_server_strategy", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:framework_ops", - "//tensorflow/python:gradients", - "//tensorflow/python:layers", - "//tensorflow/python:session", - "//tensorflow/python:training", - "//tensorflow/python:variable_scope", - "//tensorflow/python:variables", - "//tensorflow/python/eager:context", - "//tensorflow/python/estimator:estimator_py", - "@absl_py//absl/testing:parameterized", - ], -) - cuda_py_test( name = "mirrored_strategy_multigpu_test", srcs = ["mirrored_strategy_multigpu_test.py"], diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py index a411ca870e..2331444261 100644 --- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py +++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py @@ -68,11 +68,11 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy): self._cluster_spec = multi_worker_util.normalize_cluster_spec( cluster_spec) worker_device = "/job:%s/task:%d" % (task_type, task_id) - num_workers = len(self._cluster_spec.as_dict().get(task_type, [])) - if "chief" in self._cluster_spec.as_dict(): - num_workers += 1 + num_workers = len(self._cluster_spec.as_dict().get("worker", [])) + len( + self._cluster_spec.as_dict().get("chief", [])) if not num_workers: - raise ValueError("`task_type` shoud be in `cluster_spec`.") + raise ValueError("No `worker` or `chief` tasks can be found in " + "`cluster_spec`.") self._is_chief = multi_worker_util.is_chief(cluster_spec, task_type, task_id) diff --git 
a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py index c679fc8810..0d966d0e90 100644 --- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py +++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py @@ -25,10 +25,8 @@ from tensorflow.contrib.distribute.python import collective_all_reduce_strategy from tensorflow.contrib.distribute.python import combinations from tensorflow.contrib.distribute.python import cross_tower_utils from tensorflow.contrib.distribute.python import multi_worker_test_base -from tensorflow.contrib.distribute.python import strategy_test_lib from tensorflow.core.protobuf import config_pb2 from tensorflow.python.eager import context -from tensorflow.python.estimator import run_config from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -41,52 +39,43 @@ from tensorflow.python.ops import variables from tensorflow.python.platform import test -class DistributedCollectiveAllReduceStrategyTest( - multi_worker_test_base.MultiWorkerTestBase, parameterized.TestCase): +class CollectiveAllReduceStrategyTestBase( + multi_worker_test_base.MultiWorkerTestBase): collective_key_base = 0 - @classmethod - def setUpClass(cls): - """Create a local cluster with 2 workers.""" - cls._workers, cls._ps = multi_worker_test_base.create_in_process_cluster( - num_workers=3, num_ps=0) - cls._cluster_spec = { - run_config.TaskType.WORKER: [ - 'fake_worker_0', 'fake_worker_1', 'fake_worker_2' - ] - } - def setUp(self): self._run_options = config_pb2.RunOptions() self._run_options.experimental.collective_graph_key = 6 self._sess_config = config_pb2.ConfigProto() - self._sess_config.experimental.collective_group_leader = ( - '/job:worker/replica:0/task:0') # We use a different key_base for each test so that collective keys won't be # 
reused. # TODO(yuefengz, tucker): enable it to reuse collective keys in different # tests. - DistributedCollectiveAllReduceStrategyTest.collective_key_base += 100000 - super(DistributedCollectiveAllReduceStrategyTest, self).setUp() + CollectiveAllReduceStrategyTestBase.collective_key_base += 100000 + super(CollectiveAllReduceStrategyTestBase, self).setUp() def _get_test_object(self, task_type, task_id, num_gpus=0): distribution = collective_all_reduce_strategy.CollectiveAllReduceStrategy( num_gpus_per_worker=num_gpus) - distribution.configure( - cluster_spec=self._cluster_spec, task_type=task_type, task_id=task_id) + if task_type and task_id is not None: + distribution.configure( + cluster_spec=self._cluster_spec, task_type=task_type, task_id=task_id) collective_keys = cross_tower_utils.CollectiveKeys( group_key_start=10 * num_gpus + - DistributedCollectiveAllReduceStrategyTest.collective_key_base, + CollectiveAllReduceStrategyTestBase.collective_key_base, instance_key_start=num_gpus * 100 + - DistributedCollectiveAllReduceStrategyTest.collective_key_base, + CollectiveAllReduceStrategyTestBase.collective_key_base, instance_key_with_id_start=num_gpus * 10000 + - DistributedCollectiveAllReduceStrategyTest.collective_key_base) + CollectiveAllReduceStrategyTestBase.collective_key_base) distribution._collective_keys = collective_keys distribution._cross_tower_ops._collective_keys = collective_keys - return distribution, self._workers[task_id].target + if task_type and task_id is not None: + return distribution, 'grpc://' + self._cluster_spec[task_type][task_id] + else: + return distribution, '' def _test_minimize_loss_graph(self, task_type, task_id, num_gpus): d, master_target = self._get_test_object(task_type, task_id, num_gpus) @@ -154,12 +143,6 @@ class DistributedCollectiveAllReduceStrategyTest( self.assertLess(error_after, error_before) return error_after < error_before - @combinations.generate( - combinations.combine(mode=['graph'], num_gpus=[0, 1, 2])) - def 
testMinimizeLossGraph(self, num_gpus): - self._run_between_graph_clients(self._test_minimize_loss_graph, - self._cluster_spec, num_gpus) - def _test_variable_initialization(self, task_type, task_id, num_gpus): distribution, master_target = self._get_test_object(task_type, task_id, num_gpus) @@ -184,13 +167,35 @@ class DistributedCollectiveAllReduceStrategyTest( sess.run( variables.global_variables_initializer(), options=self._run_options) + x_value, reduced_x_value = sess.run( [x, reduced_x], options=self._run_options) self.assertTrue(np.array_equal(x_value, reduced_x_value)) return np.array_equal(x_value, reduced_x_value) + +class DistributedCollectiveAllReduceStrategyTest( + CollectiveAllReduceStrategyTestBase, parameterized.TestCase): + + @classmethod + def setUpClass(cls): + """Create a local cluster with 3 workers.""" + cls._cluster_spec = multi_worker_test_base.create_in_process_cluster( + num_workers=3, num_ps=0) + + def setUp(self): + super(DistributedCollectiveAllReduceStrategyTest, self).setUp() + self._sess_config.experimental.collective_group_leader = ( + '/job:worker/replica:0/task:0') + @combinations.generate( - combinations.combine(mode=['graph'], num_gpus=[0, 1, 2])) + combinations.combine(mode=['graph'], num_gpus=[0, 1, 2], required_gpus=1)) + def testMinimizeLossGraph(self, num_gpus): + self._run_between_graph_clients(self._test_minimize_loss_graph, + self._cluster_spec, num_gpus) + + @combinations.generate( + combinations.combine(mode=['graph'], num_gpus=[0, 1, 2], required_gpus=1)) def testVariableInitialization(self, num_gpus): if context.num_gpus() < num_gpus: return @@ -200,16 +205,46 @@ class DistributedCollectiveAllReduceStrategyTest( num_gpus=num_gpus) -class LocalCollectiveAllReduceStrategy(strategy_test_lib.DistributionTestBase, - parameterized.TestCase): +class DistributedCollectiveAllReduceStrategyTestWithChief( + CollectiveAllReduceStrategyTestBase, parameterized.TestCase): + + @classmethod + def setUpClass(cls): + """Create a local 
cluster with 3 workers and 1 chief.""" + cls._cluster_spec = multi_worker_test_base.create_in_process_cluster( + num_workers=3, num_ps=0, has_chief=True) + + def setUp(self): + super(DistributedCollectiveAllReduceStrategyTestWithChief, self).setUp() + self._run_options.experimental.collective_graph_key = 7 + self._sess_config.experimental.collective_group_leader = ( + '/job:chief/replica:0/task:0') + + @combinations.generate( + combinations.combine(mode=['graph'], num_gpus=[0, 1, 2], required_gpus=1)) + def testMinimizeLossGraph(self, num_gpus): + self._run_between_graph_clients(self._test_minimize_loss_graph, + self._cluster_spec, num_gpus) + + @combinations.generate( + combinations.combine(mode=['graph'], num_gpus=[0, 1, 2], required_gpus=1)) + def testVariableInitialization(self, num_gpus): + if context.num_gpus() < num_gpus: + return + self._run_between_graph_clients( + self._test_variable_initialization, + self._cluster_spec, + num_gpus=num_gpus) + + +class LocalCollectiveAllReduceStrategy( + CollectiveAllReduceStrategyTestBase, parameterized.TestCase): def testMinimizeLossGraph(self, num_gpus=2): # Collective ops doesn't support strategy with one device. 
if context.num_gpus() < num_gpus: return - distribution = collective_all_reduce_strategy.CollectiveAllReduceStrategy( - num_gpus_per_worker=num_gpus) - self._test_minimize_loss_graph(distribution) + self._test_minimize_loss_graph(None, None, num_gpus) if __name__ == '__main__': diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py index 97c4778f0d..2ad91d56e9 100644 --- a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py +++ b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py @@ -32,7 +32,6 @@ from tensorflow.contrib.distribute.python import values as value_lib from tensorflow.core.protobuf import config_pb2 from tensorflow.python.eager import context from tensorflow.python.eager import test -from tensorflow.python.estimator import run_config from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -379,12 +378,16 @@ class MultiWorkerCrossTowerOpsTest(multi_worker_test_base.MultiWorkerTestBase, distribution=[ combinations.NamedDistribution( "MirroredCPU", - lambda: mirrored_strategy.MirroredStrategy(["/cpu:0"]), - required_gpus=2), + lambda: mirrored_strategy.MirroredStrategy(num_gpus=0), + required_gpus=0), combinations.NamedDistribution( "Mirrored1GPU", - lambda: mirrored_strategy.MirroredStrategy(["/gpu:1"]), - required_gpus=2), combinations.mirrored_strategy_with_two_gpus + lambda: mirrored_strategy.MirroredStrategy(num_gpus=1), + required_gpus=1), + combinations.NamedDistribution( + "Mirrored2GPUs", + lambda: mirrored_strategy.MirroredStrategy(num_gpus=2), + required_gpus=2), ], mode=["graph"]) @@ -406,13 +409,8 @@ class MultiWorkerCollectiveAllReduceTest( @classmethod def setUpClass(cls): """Create a local cluster with 2 workers.""" - cls._workers, cls._ps = multi_worker_test_base.create_in_process_cluster( + cls._cluster_spec = 
multi_worker_test_base.create_in_process_cluster( num_workers=3, num_ps=0) - cls._cluster_spec = { - run_config.TaskType.WORKER: [ - "fake_worker_0", "fake_worker_1", "fake_worker_2" - ] - } def setUp(self): super(MultiWorkerCollectiveAllReduceTest, self).setUp() @@ -446,7 +444,8 @@ class MultiWorkerCollectiveAllReduceTest( ] else: devices = ["/job:%s/task:%d" % (task_type, task_id)] - return collective_all_reduce_ops, devices, self._workers[task_id].target + return (collective_all_reduce_ops, devices, + "grpc://" + self._cluster_spec[task_type][task_id]) def _assert_values_equal(self, left, right, sess): if isinstance(left, list): @@ -473,7 +472,8 @@ class MultiWorkerCollectiveAllReduceTest( num_workers = 1 worker_device = None else: - num_workers = len(self._workers) + num_workers = len(self._cluster_spec.get("chief", [])) + len( + self._cluster_spec.get("worker", [])) worker_device = "/job:%s/task:%d" % (task_type, task_id) with ops.Graph().as_default(), \ ops.device(worker_device), \ @@ -551,7 +551,7 @@ class MultiWorkerCollectiveAllReduceTest( return True @combinations.generate( - combinations.combine(mode=["graph"], num_gpus=[0, 1, 2])) + combinations.combine(mode=["graph"], num_gpus=[0, 1, 2], required_gpus=1)) def testReductionDistributed(self, num_gpus): if context.num_gpus() < num_gpus: return diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py index ecaf60f350..e87b48ba41 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py @@ -276,6 +276,9 @@ def _create_mirrored_variable(devices, real_mirrored_creator, *args, **kwargs): else: result = values.MirroredVariable(index, index[devices[0]], aggregation) + # Add the wrapped variable to the requested collections. + # The handling of eager mode and the global step matches + # ResourceVariable._init_from_args(). 
if not context.executing_eagerly(): g = ops.get_default_graph() # If "trainable" is True, next_creator() will add the member variables @@ -289,6 +292,9 @@ def _create_mirrored_variable(devices, real_mirrored_creator, *args, **kwargs): for v in index.values(): l.remove(v) g.add_to_collections(collections, result) + elif ops.GraphKeys.GLOBAL_STEP in collections: + ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, result) + return result diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py index 612655a38a..ac2697958d 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py @@ -888,8 +888,18 @@ class MirroredVariableUpdateTest(test.TestCase): self.assertIsInstance(mirrored_var, values.MirroredVariable) self.evaluate(variables.global_variables_initializer()) self.assertEquals(1.0, self.evaluate(mirrored_var)) - mirrored_var_result = self.evaluate(mirrored_var.assign_add(6.0)) + + # read_value == True + mirrored_var_result = self.evaluate( + mirrored_var.assign_add(6.0, read_value=True)) self.assertEquals(7.0, mirrored_var_result) + self.assertEquals(7.0, self.evaluate(mirrored_var.get("/device:CPU:0"))) + self.assertEquals(7.0, self.evaluate(mirrored_var.get("/device:GPU:0"))) + + # read_value == False + self.evaluate(mirrored_var.assign_add(2.0, read_value=False)) + self.assertEquals(9.0, self.evaluate(mirrored_var.get("/device:CPU:0"))) + self.assertEquals(9.0, self.evaluate(mirrored_var.get("/device:GPU:0"))) @test_util.run_in_graph_and_eager_modes(config=config) def testAssignAddMirroredVarTowerContext(self): @@ -956,6 +966,8 @@ class MirroredVariableUpdateTest(test.TestCase): self.assertEquals(5.0, self.evaluate(mirrored_var)) mirrored_var_result = self.evaluate(mirrored_var.assign_sub(2.0)) self.assertEquals(3.0, mirrored_var_result) + 
self.assertEquals(3.0, self.evaluate(mirrored_var.get("/device:GPU:0"))) + self.assertEquals(3.0, self.evaluate(mirrored_var.get("/device:CPU:0"))) @test_util.run_in_graph_and_eager_modes(config=config) def testAssignSubMirroredVarTowerContext(self): @@ -1262,5 +1274,22 @@ class MultiWorkerMirroredStrategyTest( self._test_minimize_loss_graph(self._get_distribution_strategy()) +class MultiWorkerMirroredStrategyTestWithChief( + multi_worker_test_base.MultiWorkerTestBase, + strategy_test_lib.DistributionTestBase): + + @classmethod + def setUpClass(cls): + """Create a local cluster with 2 workers and 1 chief.""" + cls._cluster_spec = multi_worker_test_base.create_in_process_cluster( + num_workers=2, num_ps=0, has_chief=True) + cls._default_target = "grpc://" + cls._cluster_spec["chief"][0] + + def testMinimizeLossGraph(self): + strategy = mirrored_strategy.MirroredStrategy(num_gpus=context.num_gpus()) + strategy.configure(cluster_spec=self._cluster_spec) + self._test_minimize_loss_graph(strategy) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_test.py index 3f44ab7700..969e126956 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy_test.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy_test.py @@ -62,6 +62,7 @@ class VariableCreatorStackTest(test.TestCase): def model_fn(device_id): assert isinstance(device_id, int) + def thread_creator_fn(next_creator, *args, **kwargs): return next_creator(*args, **kwargs) + ":thread_" + str(device_id) @@ -93,16 +94,15 @@ class MultiWorkerMirroredStrategyTest(test.TestCase): def testDeviceScope(self): """Test the device scope of multi-worker MirroredStrategy.""" with context.graph_mode(): - strategy = mirrored_strategy.MirroredStrategy( - num_gpus=context.num_gpus()) + strategy = mirrored_strategy.MirroredStrategy(num_gpus=context.num_gpus()) strategy.configure( 
cluster_spec={"worker": ["/job:worker/task:0", "/job:worker/task:1"]}) with strategy.scope(): a = constant_op.constant(1.) - with ops.device('/cpu:0'): + with ops.device("/cpu:0"): b = constant_op.constant(1.) - self.assertEqual(a.device, '/job:worker/task:0') - self.assertEqual(b.device, '/job:worker/task:0/device:CPU:0') + self.assertEqual(a.device, "/job:worker/task:0") + self.assertEqual(b.device, "/job:worker/task:0/device:CPU:0") if __name__ == "__main__": diff --git a/tensorflow/contrib/distribute/python/multi_worker_test_base.py b/tensorflow/contrib/distribute/python/multi_worker_test_base.py index 249de01f08..18b4503eff 100644 --- a/tensorflow/contrib/distribute/python/multi_worker_test_base.py +++ b/tensorflow/contrib/distribute/python/multi_worker_test_base.py @@ -23,26 +23,105 @@ import copy import threading import numpy as np +_portpicker_import_error = None +try: + import portpicker # pylint: disable=g-import-not-at-top +except ImportError as _error: # pylint: disable=invalid-name + _portpicker_import_error = _error + portpicker = None + +# pylint: disable=g-import-not-at-top from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python.client import session from tensorflow.python.estimator import run_config from tensorflow.python.platform import test -from tensorflow.python.framework import test_util - - -def create_in_process_cluster(num_workers, num_ps): +from tensorflow.python.training import server_lib + + +def _create_cluster(num_workers, + num_ps, + has_chief=False, + has_eval=False, + protocol='grpc', + worker_config=None, + ps_config=None): + """Creates and starts local servers and returns the cluster_spec dict.""" + if _portpicker_import_error: + raise _portpicker_import_error # pylint: disable=raising-bad-type + worker_ports = [portpicker.pick_unused_port() for _ in range(num_workers)] + ps_ports = [portpicker.pick_unused_port() for _ in range(num_ps)] + + cluster_dict = {} + 
if num_workers > 0: + cluster_dict['worker'] = ['localhost:%s' % port for port in worker_ports] + if num_ps > 0: + cluster_dict['ps'] = ['localhost:%s' % port for port in ps_ports] + if has_eval: + cluster_dict['evaluator'] = ['localhost:%s' % portpicker.pick_unused_port()] + if has_chief: + cluster_dict['chief'] = ['localhost:%s' % portpicker.pick_unused_port()] + + cs = server_lib.ClusterSpec(cluster_dict) + + for i in range(num_workers): + server_lib.Server( + cs, + job_name='worker', + protocol=protocol, + task_index=i, + config=worker_config, + start=True) + + for i in range(num_ps): + server_lib.Server( + cs, + job_name='ps', + protocol=protocol, + task_index=i, + config=ps_config, + start=True) + + if has_chief: + server_lib.Server( + cs, + job_name='chief', + protocol=protocol, + task_index=0, + config=worker_config, + start=True) + + if has_eval: + server_lib.Server( + cs, + job_name='evaluator', + protocol=protocol, + task_index=0, + config=worker_config, + start=True) + + return cluster_dict + + +def create_in_process_cluster(num_workers, + num_ps, + has_chief=False, + has_eval=False): """Create an in-process cluster that consists of only standard server.""" # Leave some memory for cuda runtime. - gpu_mem_frac = 0.7 / num_workers + gpu_mem_frac = 0.7 / (num_workers + int(has_chief) + int(has_eval)) worker_config = config_pb2.ConfigProto() worker_config.gpu_options.per_process_gpu_memory_fraction = gpu_mem_frac # Enable collective ops which has no impact on non-collective ops. # TODO(yuefengz, tucker): removing this after we move the initialization of # collective mgr to the session level. 
- worker_config.experimental.collective_group_leader = ( - '/job:worker/replica:0/task:0') + if has_chief: + worker_config.experimental.collective_group_leader = ( + '/job:chief/replica:0/task:0') + else: + worker_config.experimental.collective_group_leader = ( + '/job:worker/replica:0/task:0') ps_config = config_pb2.ConfigProto() ps_config.device_count['GPU'] = 0 @@ -56,9 +135,10 @@ def create_in_process_cluster(num_workers, num_ps): # 2) there is something global in CUDA such that if we initialize CUDA in the # parent process, the child process cannot initialize it again and thus cannot # use GPUs (https://stackoverflow.com/questions/22950047). - return test_util.create_local_cluster( + return _create_cluster( num_workers, num_ps=num_ps, + has_chief=has_chief, worker_config=worker_config, ps_config=ps_config, protocol='grpc') @@ -70,7 +150,8 @@ class MultiWorkerTestBase(test.TestCase): @classmethod def setUpClass(cls): """Create a local cluster with 2 workers.""" - cls._workers, cls._ps = create_in_process_cluster(num_workers=2, num_ps=0) + cls._cluster_spec = create_in_process_cluster(num_workers=2, num_ps=0) + cls._default_target = 'grpc://' + cls._cluster_spec['worker'][0] def setUp(self): # We only cache the session in one test because another test may have a @@ -111,17 +192,17 @@ class MultiWorkerTestBase(test.TestCase): config.graph_options.rewrite_options.constant_folding = ( rewriter_config_pb2.RewriterConfig.OFF) + if target is None: + target = self._default_target if graph is None: if getattr(self._thread_local, 'cached_session', None) is None: self._thread_local.cached_session = session.Session( - graph=None, config=config, target=target or self._workers[0].target) + graph=None, config=config, target=target) sess = self._thread_local.cached_session with sess.graph.as_default(), sess.as_default(): yield sess else: - with session.Session( - graph=graph, config=config, target=target or - self._workers[0].target) as sess: + with 
session.Session(graph=graph, config=config, target=target) as sess: yield sess def _run_client(self, client_fn, task_type, task_id, num_gpus, *args, diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy.py b/tensorflow/contrib/distribute/python/parameter_server_strategy.py index 96b6519bc4..361c8be590 100644 --- a/tensorflow/contrib/distribute/python/parameter_server_strategy.py +++ b/tensorflow/contrib/distribute/python/parameter_server_strategy.py @@ -22,6 +22,7 @@ from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ from tensorflow.contrib.distribute.python import mirrored_strategy from tensorflow.contrib.distribute.python import values from tensorflow.python.distribute import multi_worker_util +from tensorflow.python.eager import context from tensorflow.python.framework import device as tf_device from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -94,11 +95,18 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): cluster configurations. task_type: the current task type. task_id: the current task id. + + Raises: + ValueError: if `cluster_spec` is given but `task_type` or `task_id` is + not. """ super(ParameterServerStrategy, self).__init__() self._num_gpus_per_worker = num_gpus_per_worker if cluster_spec: cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec) + if task_type is None or task_id is None: + raise ValueError("When `cluster_spec` is given, must also specify " + "`task_type` and `task_id`.") self._cluster_spec = cluster_spec # We typically don't need to do all-reduce in this strategy. @@ -233,8 +241,35 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): " for variable: " + kwargs["name"]) def var_creator(*args, **kwargs): + # Record what collections this variable should be added to. 
+ collections = kwargs.pop("collections", None) + if collections is None: + collections = [ops.GraphKeys.GLOBAL_VARIABLES] + kwargs["collections"] = [] + + # Create and wrap the variable. v = next_creator(*args, **kwargs) - return values.AggregatingVariable(v, aggregation) + wrapped = values.AggregatingVariable(v, aggregation) + + # Add the wrapped variable to the requested collections. + # The handling of eager mode and the global step matches + # ResourceVariable._init_from_args(). + if not context.executing_eagerly(): + g = ops.get_default_graph() + # If "trainable" is True, next_creator() will add the contained + # variable to the TRAINABLE_VARIABLES collection, so we manually + # remove it and replace with the wrapper. We can't set "trainable" + # to False for next_creator() since that causes functions like + # implicit_gradients to skip those variables. + if kwargs.get("trainable", True): + collections.append(ops.GraphKeys.TRAINABLE_VARIABLES) + l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES) + l.remove(v) + g.add_to_collections(collections, wrapped) + elif ops.GraphKeys.GLOBAL_STEP in collections: + ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, wrapped) + + return wrapped else: var_creator = next_creator @@ -345,6 +380,10 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): cluster configurations. task_type: the current task type. task_id: the current task id. + + Raises: + ValueError: if `cluster_spec` is given but `task_type` or `task_id` is + not. 
""" del session_config @@ -353,6 +392,9 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): if not self._cluster_spec and cluster_spec: self._cluster_spec = multi_worker_util.normalize_cluster_spec( cluster_spec) + if task_type is None or task_id is None: + raise ValueError("When `cluster_spec` is given, must also specify " + "`task_type` and `task_id`.") self._initialize_devices(self._num_gpus_per_worker, self._cluster_spec, task_type, task_id) diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py index adfe3e8b02..0e2bfcec5f 100644 --- a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py +++ b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py @@ -24,6 +24,8 @@ from absl.testing import parameterized from tensorflow.contrib.distribute.python import combinations from tensorflow.contrib.distribute.python import multi_worker_test_base from tensorflow.contrib.distribute.python import parameter_server_strategy +from tensorflow.contrib.distribute.python import values +from tensorflow.python.distribute import multi_worker_util from tensorflow.python.eager import context from tensorflow.python.estimator import run_config from tensorflow.python.framework import constant_op @@ -37,21 +39,15 @@ from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.training import device_util from tensorflow.python.training import distribution_strategy_context +from tensorflow.python.training import training_util +CHIEF = run_config.TaskType.CHIEF +WORKER = run_config.TaskType.WORKER +PS = run_config.TaskType.PS -class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase, - parameterized.TestCase): - @classmethod - def setUpClass(cls): - cls._workers, cls._ps = multi_worker_test_base.create_in_process_cluster( - num_workers=3, num_ps=2) - cls._cluster_spec = { 
- run_config.TaskType.WORKER: [ - 'fake_worker_0', 'fake_worker_1', 'fake_worker_2' - ], - run_config.TaskType.PS: ['fake_ps_0', 'fake_ps_1'] - } +class ParameterServerStrategyTestBase( + multi_worker_test_base.MultiWorkerTestBase): def setUp(self): self._result = 0 @@ -60,7 +56,7 @@ class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase, self._init_reached = 0 self._finish_condition = threading.Condition() self._finish_reached = 0 - super(ParameterServerStrategyTest, self).setUp() + super(ParameterServerStrategyTestBase, self).setUp() def _get_test_objects(self, task_type, task_id, num_gpus): distribution = parameter_server_strategy.ParameterServerStrategy( @@ -70,13 +66,13 @@ class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase, distribution.configure( cluster_spec=self._cluster_spec, task_type=task_type, task_id=task_id) - return distribution, self._workers[task_id].target + return distribution, 'grpc://' + self._cluster_spec[WORKER][task_id] def _test_device_assignment_distributed(self, task_type, task_id, num_gpus): worker_device = '/job:%s/replica:0/task:%d' % (task_type, task_id) d, _ = self._get_test_objects(task_type, task_id, num_gpus) with ops.Graph().as_default(), \ - self.test_session(target=self._workers[0].target) as sess, \ + self.test_session(target=self._default_target) as sess, \ d.scope(): # Define a variable outside the call_for_each_tower scope. 
This is not @@ -172,18 +168,13 @@ class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase, self.assertEqual(z_val, 43.0) self.assertEqual(f_val, 46.0) - @combinations.generate( - combinations.combine(mode=['graph'], num_gpus=[0, 1, 2])) - def testDeviceAssignmentDistributed(self, num_gpus): - self._test_device_assignment_distributed('worker', 1, num_gpus) - def _test_device_assignment_local(self, d, compute_device='CPU', variable_device='CPU', num_gpus=0): with ops.Graph().as_default(), \ - self.test_session(target=self._workers[0].target) as sess, \ + self.test_session(target=self._default_target) as sess, \ d.scope(): def model_fn(): @@ -276,29 +267,12 @@ class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase, self.assertEqual(z_val, 43.0) self.assertEqual(f_val, 46.0) - def testDeviceAssignmentLocalCPU(self): - distribution = parameter_server_strategy.ParameterServerStrategy( - num_gpus_per_worker=0) - self._test_device_assignment_local( - distribution, compute_device='CPU', variable_device='CPU', num_gpus=0) - - def testDeviceAssignmentLocalOneGPU(self): - distribution = parameter_server_strategy.ParameterServerStrategy( - num_gpus_per_worker=1) - self._test_device_assignment_local( - distribution, compute_device='GPU', variable_device='GPU', num_gpus=1) - - def testDeviceAssignmentLocalTwoGPUs(self): - distribution = parameter_server_strategy.ParameterServerStrategy( - num_gpus_per_worker=2) - self._test_device_assignment_local( - distribution, compute_device='GPU', variable_device='CPU', num_gpus=2) - def _test_simple_increment(self, task_type, task_id, num_gpus): d, master_target = self._get_test_objects(task_type, task_id, num_gpus) if hasattr(d, '_cluster_spec') and d._cluster_spec: - num_workers = len(d._cluster_spec.as_dict().get('worker', - ['dummy_worker'])) + num_workers = len(d._cluster_spec.as_dict().get(WORKER)) + if 'chief' in d._cluster_spec.as_dict(): + num_workers += 1 else: num_workers = 1 with 
ops.Graph().as_default(), \ @@ -357,6 +331,11 @@ class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase, def _test_minimize_loss_graph(self, task_type, task_id, num_gpus): d, master_target = self._get_test_objects(task_type, task_id, num_gpus) + assert hasattr(d, '_cluster_spec') and d._cluster_spec + num_workers = len(d._cluster_spec.as_dict().get(WORKER)) + if CHIEF in d._cluster_spec.as_dict(): + num_workers += 1 + with ops.Graph().as_default(), \ self.test_session(target=master_target) as sess, \ d.scope(): @@ -405,13 +384,13 @@ class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase, if context.num_gpus() < d._num_gpus_per_worker: return True - if task_id == 0: + if multi_worker_util.is_chief(d._cluster_spec, task_type, task_id): variables.global_variables_initializer().run() # Workers waiting for chief worker's initializing variables. self._init_condition.acquire() self._init_reached += 1 - while self._init_reached != 3: + while self._init_reached != num_workers: self._init_condition.wait() self._init_condition.notify_all() self._init_condition.release() @@ -428,9 +407,42 @@ class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase, self.assertLess(error_after, error_before) return error_after < error_before + +class ParameterServerStrategyTest(ParameterServerStrategyTestBase, + parameterized.TestCase): + + @classmethod + def setUpClass(cls): + cls._cluster_spec = multi_worker_test_base.create_in_process_cluster( + num_workers=3, num_ps=2) + cls._default_target = 'grpc://' + cls._cluster_spec[WORKER][0] + + def testDeviceAssignmentLocalCPU(self): + distribution = parameter_server_strategy.ParameterServerStrategy( + num_gpus_per_worker=0) + self._test_device_assignment_local( + distribution, compute_device='CPU', variable_device='CPU', num_gpus=0) + + def testDeviceAssignmentLocalOneGPU(self): + distribution = parameter_server_strategy.ParameterServerStrategy( + num_gpus_per_worker=1) + 
self._test_device_assignment_local( + distribution, compute_device='GPU', variable_device='GPU', num_gpus=1) + + def testDeviceAssignmentLocalTwoGPUs(self): + distribution = parameter_server_strategy.ParameterServerStrategy( + num_gpus_per_worker=2) + self._test_device_assignment_local( + distribution, compute_device='GPU', variable_device='CPU', num_gpus=2) + + @combinations.generate( + combinations.combine(mode=['graph'], num_gpus=[0, 1, 2])) + def testDeviceAssignmentDistributed(self, num_gpus): + self._test_device_assignment_distributed('worker', 1, num_gpus) + def testSimpleBetweenGraph(self): self._run_between_graph_clients(self._test_simple_increment, - self._cluster_spec, 0) + self._cluster_spec, context.num_gpus()) @combinations.generate( combinations.combine(mode=['graph'], num_gpus=[0, 1, 2])) @@ -444,5 +456,38 @@ class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase, self._cluster_spec, num_gpus) +class ParameterServerStrategyWithChiefTest(ParameterServerStrategyTestBase, + parameterized.TestCase): + + @classmethod + def setUpClass(cls): + cls._cluster_spec = multi_worker_test_base.create_in_process_cluster( + num_workers=3, num_ps=2, has_chief=True) + cls._default_target = 'grpc://' + cls._cluster_spec[CHIEF][0] + + def testSimpleBetweenGraph(self): + self._run_between_graph_clients(self._test_simple_increment, + self._cluster_spec, context.num_gpus()) + + @combinations.generate( + combinations.combine(mode=['graph'], num_gpus=[0, 1, 2])) + def testMinimizeLossGraph(self, num_gpus): + self._run_between_graph_clients(self._test_minimize_loss_graph, + self._cluster_spec, num_gpus) + + def testGlobalStepIsWrapped(self): + distribution = parameter_server_strategy.ParameterServerStrategy( + num_gpus_per_worker=2) + with ops.Graph().as_default(), distribution.scope(): + created_step = training_util.create_global_step() + get_step = training_util.get_global_step() + self.assertEqual(created_step, get_step, + msg=('created_step %s type 
%s vs. get_step %s type %s' % + (id(created_step), created_step.__class__.__name__, + id(get_step), get_step.__class__.__name__))) + self.assertIs(values.AggregatingVariable, type(created_step)) + self.assertIs(values.AggregatingVariable, type(get_step)) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py index a486003076..6202a0750a 100644 --- a/tensorflow/contrib/distribute/python/tpu_strategy.py +++ b/tensorflow/contrib/distribute/python/tpu_strategy.py @@ -59,7 +59,7 @@ def get_tpu_system_metadata(tpu_cluster_resolver): class TPUStrategy(one_device_strategy.OneDeviceStrategy): """Experimental TPU distribution strategy implementation.""" - def __init__(self, tpu_cluster_resolver, steps_per_run): + def __init__(self, tpu_cluster_resolver, steps_per_run, num_cores=None): """Initializes the TPUStrategy object. Args: @@ -70,6 +70,8 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy): metrics, summaries etc. This parameter is only used when Distribution Strategy is used with estimator or keras. + num_cores: Number of cores to use on the TPU. If None specified, then + auto-detect the cores and topology of the TPU system. """ # TODO(isaprykin): Generalize the defaults. They are currently tailored for # the unit test. @@ -77,13 +79,15 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy): self._tpu_cluster_resolver = tpu_cluster_resolver self._tpu_metadata = get_tpu_system_metadata(self._tpu_cluster_resolver) + self._num_cores_override = num_cores - # TODO(priyag): This should not be hardcoded here. - self._host = '/device:CPU:0' # TODO(sourabhbajaj): Remove this once performance of running one step # at a time is comparable to multiple steps. self.steps_per_run = steps_per_run + # TODO(frankchn): This should not be hardcoded here for pod purposes. 
+ self._host = self.tpu_host_cpu_device(0) + def distribute_dataset(self, dataset_fn): # TODO(priyag): Perhaps distribute across cores here. return self._call_dataset_fn(dataset_fn) @@ -106,6 +110,7 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy): """Enqueue ops for one iteration.""" control_deps = [] sharded_inputs = [] + # TODO(sourabhbajaj): Add support for TPU pods with ops.device(self._host): for _ in range(self.num_towers): # Use control dependencies to ensure a deterministic ordering. @@ -258,4 +263,10 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy): @property def num_towers(self): - return self._tpu_metadata.num_of_cores_per_host + return self._num_cores_override or self._tpu_metadata.num_cores + + def tpu_host_cpu_device(self, host_id): + if self._tpu_cluster_resolver.get_master() in ('', 'local'): + return '/replica:0/task:0/device:CPU:0' + return '/job:%s/task:%d/device:CPU:0' % ('tpu_worker', host_id) + diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py index a58bb3a849..e73d9c193e 100644 --- a/tensorflow/contrib/distribute/python/values.py +++ b/tensorflow/contrib/distribute/python/values.py @@ -183,6 +183,14 @@ class Mirrored(DistributedDelegate): return self._index[device] return list(self._index.values())[0] + def _as_graph_element(self): + obj = self.get() + # pylint: disable=protected-access + conv_fn = getattr(obj, "_as_graph_element", None) + if conv_fn and callable(conv_fn): + return conv_fn() + return obj + def _assign_on_device(device, variable, tensor): with ops.device(device): @@ -354,8 +362,19 @@ class MirroredVariable(DistributedVariable, Mirrored, # We are calling assign on the mirrored variable in cross tower context, # use update to update the variable. 
- return distribution_strategy_context.get_distribution_strategy().update( - self, f, *args, **kwargs) + strategy = distribution_strategy_context.get_distribution_strategy() + updates = strategy.update(self, f, *args, **kwargs) + grouped = strategy.group(updates) + if isinstance(updates, DistributedValues) and updates.is_tensor_like: + # Make sure we run all updates. Without this, something like + # session.run(mirrored_var.assign*(...)) may only update one tower. + index = {} + for d in updates.devices: + with ops.device(d), ops.control_dependencies([grouped]): + index[d] = array_ops.identity(updates.get(d)) + return Mirrored(index) + else: + return grouped else: _assert_tower_context() # We are calling an assign function on the mirrored variable in tower @@ -1180,6 +1199,10 @@ class AggregatingVariable(checkpointable.CheckpointableBase): def __repr__(self): return repr(self._v) + def _should_act_as_resource_variable(self): + """Pass resource_variable_ops.is_resource_variable check.""" + pass + # Register a conversion function which reads the value of the variable, # allowing instances of the class to be used as tensors. 
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/matrix_inverse_tril_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/matrix_inverse_tril_test.py index 49a9afe3f6..31ee36f024 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/matrix_inverse_tril_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/matrix_inverse_tril_test.py @@ -26,6 +26,7 @@ from tensorflow.python.framework import test_util from tensorflow.python.platform import test +@test_util.run_all_in_graph_and_eager_modes class MatrixInverseTriLBijectorTest(test.TestCase): """Tests the correctness of the Y = inv(tril) transformation.""" @@ -40,7 +41,6 @@ class MatrixInverseTriLBijectorTest(test.TestCase): y[idx][np.triu_indices(y[idx].shape[-1], 1)] = 0 return y - @test_util.run_in_graph_and_eager_modes def testComputesCorrectValues(self): inv = bijectors.MatrixInverseTriL(validate_args=True) self.assertEqual("matrix_inverse_tril", inv.name) @@ -62,7 +62,6 @@ class MatrixInverseTriLBijectorTest(test.TestCase): self.assertNear(expected_fldj_, fldj_, err=1e-3) self.assertNear(-expected_fldj_, ildj_, err=1e-3) - @test_util.run_in_graph_and_eager_modes def testOneByOneMatrix(self): inv = bijectors.MatrixInverseTriL(validate_args=True) x_ = np.array([[5.]], dtype=np.float32) @@ -81,7 +80,6 @@ class MatrixInverseTriLBijectorTest(test.TestCase): self.assertNear(expected_fldj_, fldj_, err=1e-3) self.assertNear(-expected_fldj_, ildj_, err=1e-3) - @test_util.run_in_graph_and_eager_modes def testZeroByZeroMatrix(self): inv = bijectors.MatrixInverseTriL(validate_args=True) x_ = np.eye(0, dtype=np.float32) @@ -100,7 +98,6 @@ class MatrixInverseTriLBijectorTest(test.TestCase): self.assertNear(expected_fldj_, fldj_, err=1e-3) self.assertNear(-expected_fldj_, ildj_, err=1e-3) - @test_util.run_in_graph_and_eager_modes def testBatch(self): # Test batch computation with input shape (2, 1, 2, 2), i.e. batch shape # (2, 1). 
@@ -125,20 +122,18 @@ class MatrixInverseTriLBijectorTest(test.TestCase): self.assertAllClose(expected_fldj_, fldj_, atol=0., rtol=1e-3) self.assertAllClose(-expected_fldj_, ildj_, atol=0., rtol=1e-3) - @test_util.run_in_graph_and_eager_modes def testErrorOnInputRankTooLow(self): inv = bijectors.MatrixInverseTriL(validate_args=True) x_ = np.array([0.1], dtype=np.float32) rank_error_msg = "must have rank at least 2" - with self.test_session(): - with self.assertRaisesWithPredicateMatch(ValueError, rank_error_msg): - inv.forward(x_).eval() - with self.assertRaisesWithPredicateMatch(ValueError, rank_error_msg): - inv.inverse(x_).eval() - with self.assertRaisesWithPredicateMatch(ValueError, rank_error_msg): - inv.forward_log_det_jacobian(x_, event_ndims=2).eval() - with self.assertRaisesWithPredicateMatch(ValueError, rank_error_msg): - inv.inverse_log_det_jacobian(x_, event_ndims=2).eval() + with self.assertRaisesWithPredicateMatch(ValueError, rank_error_msg): + self.evaluate(inv.forward(x_)) + with self.assertRaisesWithPredicateMatch(ValueError, rank_error_msg): + self.evaluate(inv.inverse(x_)) + with self.assertRaisesWithPredicateMatch(ValueError, rank_error_msg): + self.evaluate(inv.forward_log_det_jacobian(x_, event_ndims=2)) + with self.assertRaisesWithPredicateMatch(ValueError, rank_error_msg): + self.evaluate(inv.inverse_log_det_jacobian(x_, event_ndims=2)) # TODO(b/80481923): Figure out why these assertions fail, and fix them. 
## def testErrorOnInputNonSquare(self): @@ -146,55 +141,50 @@ class MatrixInverseTriLBijectorTest(test.TestCase): ## x_ = np.array([[1., 2., 3.], ## [4., 5., 6.]], dtype=np.float32) ## square_error_msg = "must be a square matrix" - ## with self.test_session(): - ## with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, - ## square_error_msg): - ## inv.forward(x_).eval() - ## with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, - ## square_error_msg): - ## inv.inverse(x_).eval() - ## with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, - ## square_error_msg): - ## inv.forward_log_det_jacobian(x_, event_ndims=2).eval() - ## with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, - ## square_error_msg): - ## inv.inverse_log_det_jacobian(x_, event_ndims=2).eval() - - @test_util.run_in_graph_and_eager_modes + ## with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + ## square_error_msg): + ## self.evaluate(inv.forward(x_)) + ## with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + ## square_error_msg): + ## self.evaluate(inv.inverse(x_)) + ## with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + ## square_error_msg): + ## self.evaluate(inv.forward_log_det_jacobian(x_, event_ndims=2)) + ## with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + ## square_error_msg): + ## self.evaluate(inv.inverse_log_det_jacobian(x_, event_ndims=2)) + def testErrorOnInputNotLowerTriangular(self): inv = bijectors.MatrixInverseTriL(validate_args=True) x_ = np.array([[1., 2.], [3., 4.]], dtype=np.float32) triangular_error_msg = "must be lower triangular" - with self.test_session(): - with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, - triangular_error_msg): - inv.forward(x_).eval() - with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, - triangular_error_msg): - inv.inverse(x_).eval() - with 
self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, - triangular_error_msg): - inv.forward_log_det_jacobian(x_, event_ndims=2).eval() - with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, - triangular_error_msg): - inv.inverse_log_det_jacobian(x_, event_ndims=2).eval() - - @test_util.run_in_graph_and_eager_modes + with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + triangular_error_msg): + self.evaluate(inv.forward(x_)) + with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + triangular_error_msg): + self.evaluate(inv.inverse(x_)) + with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + triangular_error_msg): + self.evaluate(inv.forward_log_det_jacobian(x_, event_ndims=2)) + with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + triangular_error_msg): + self.evaluate(inv.inverse_log_det_jacobian(x_, event_ndims=2)) + def testErrorOnInputSingular(self): inv = bijectors.MatrixInverseTriL(validate_args=True) x_ = np.array([[1., 0.], [0., 0.]], dtype=np.float32) nonsingular_error_msg = "must have all diagonal entries nonzero" - with self.test_session(): - with self.assertRaisesOpError(nonsingular_error_msg): - inv.forward(x_).eval() - with self.assertRaisesOpError(nonsingular_error_msg): - inv.inverse(x_).eval() - with self.assertRaisesOpError(nonsingular_error_msg): - inv.forward_log_det_jacobian(x_, event_ndims=2).eval() - with self.assertRaisesOpError(nonsingular_error_msg): - inv.inverse_log_det_jacobian(x_, event_ndims=2).eval() + with self.assertRaisesOpError(nonsingular_error_msg): + self.evaluate(inv.forward(x_)) + with self.assertRaisesOpError(nonsingular_error_msg): + self.evaluate(inv.inverse(x_)) + with self.assertRaisesOpError(nonsingular_error_msg): + self.evaluate(inv.forward_log_det_jacobian(x_, event_ndims=2)) + with self.assertRaisesOpError(nonsingular_error_msg): + self.evaluate(inv.inverse_log_det_jacobian(x_, event_ndims=2)) if __name__ == 
"__main__": diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/ordered_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/ordered_test.py index a188843952..9a88f8f1bc 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/ordered_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/ordered_test.py @@ -38,23 +38,22 @@ class OrderedBijectorTest(test.TestCase): @test_util.run_in_graph_and_eager_modes def testBijectorVector(self): - with self.cached_session(): - ordered = Ordered() - self.assertEqual("ordered", ordered.name) - x = np.asarray([[2., 3, 4], [4., 8, 13]]) - y = [[2., 0, 0], [4., np.log(4.), np.log(5.)]] - self.assertAllClose(y, self.evaluate(ordered.forward(x))) - self.assertAllClose(x, self.evaluate(ordered.inverse(y))) - self.assertAllClose( - np.sum(np.asarray(y)[..., 1:], axis=-1), - self.evaluate(ordered.inverse_log_det_jacobian(y, event_ndims=1)), - atol=0., - rtol=1e-7) - self.assertAllClose( - self.evaluate(-ordered.inverse_log_det_jacobian(y, event_ndims=1)), - self.evaluate(ordered.forward_log_det_jacobian(x, event_ndims=1)), - atol=0., - rtol=1e-7) + ordered = Ordered() + self.assertEqual("ordered", ordered.name) + x = np.asarray([[2., 3, 4], [4., 8, 13]]) + y = [[2., 0, 0], [4., np.log(4.), np.log(5.)]] + self.assertAllClose(y, self.evaluate(ordered.forward(x))) + self.assertAllClose(x, self.evaluate(ordered.inverse(y))) + self.assertAllClose( + np.sum(np.asarray(y)[..., 1:], axis=-1), + self.evaluate(ordered.inverse_log_det_jacobian(y, event_ndims=1)), + atol=0., + rtol=1e-7) + self.assertAllClose( + self.evaluate(-ordered.inverse_log_det_jacobian(y, event_ndims=1)), + self.evaluate(ordered.forward_log_det_jacobian(x, event_ndims=1)), + atol=0., + rtol=1e-7) def testBijectorUnknownShape(self): with self.cached_session(): @@ -84,18 +83,17 @@ class OrderedBijectorTest(test.TestCase): @test_util.run_in_graph_and_eager_modes def testShapeGetters(self): - 
with self.cached_session(): - x = tensor_shape.TensorShape([4]) - y = tensor_shape.TensorShape([4]) - bijector = Ordered(validate_args=True) - self.assertAllEqual(y, bijector.forward_event_shape(x)) - self.assertAllEqual(y.as_list(), - self.evaluate(bijector.forward_event_shape_tensor( - x.as_list()))) - self.assertAllEqual(x, bijector.inverse_event_shape(y)) - self.assertAllEqual(x.as_list(), - self.evaluate(bijector.inverse_event_shape_tensor( - y.as_list()))) + x = tensor_shape.TensorShape([4]) + y = tensor_shape.TensorShape([4]) + bijector = Ordered(validate_args=True) + self.assertAllEqual(y, bijector.forward_event_shape(x)) + self.assertAllEqual(y.as_list(), + self.evaluate(bijector.forward_event_shape_tensor( + x.as_list()))) + self.assertAllEqual(x, bijector.inverse_event_shape(y)) + self.assertAllEqual(x.as_list(), + self.evaluate(bijector.inverse_event_shape_tensor( + y.as_list()))) def testBijectiveAndFinite(self): with self.cached_session(): diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softsign_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softsign_test.py index d0098c3c10..8dad80aa64 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softsign_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softsign_test.py @@ -43,16 +43,15 @@ class SoftsignBijectorTest(test.TestCase): @test_util.run_in_graph_and_eager_modes def testBijectorBounds(self): bijector = Softsign(validate_args=True) - with self.test_session(): - with self.assertRaisesOpError("greater than -1"): - bijector.inverse(-3.).eval() - with self.assertRaisesOpError("greater than -1"): - bijector.inverse_log_det_jacobian(-3., event_ndims=0).eval() - - with self.assertRaisesOpError("less than 1"): - bijector.inverse(3.).eval() - with self.assertRaisesOpError("less than 1"): - bijector.inverse_log_det_jacobian(3., event_ndims=0).eval() + with self.assertRaisesOpError("greater than -1"): + 
self.evaluate(bijector.inverse(-3.)) + with self.assertRaisesOpError("greater than -1"): + self.evaluate(bijector.inverse_log_det_jacobian(-3., event_ndims=0)) + + with self.assertRaisesOpError("less than 1"): + self.evaluate(bijector.inverse(3.)) + with self.assertRaisesOpError("less than 1"): + self.evaluate(bijector.inverse_log_det_jacobian(3., event_ndims=0)) @test_util.run_in_graph_and_eager_modes def testBijectorForwardInverse(self): diff --git a/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py b/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py index f7b2efa7bc..05f5d30666 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py @@ -542,9 +542,9 @@ class PadDynamicTest(_PadTest, test.TestCase): return False +@test_util.run_all_in_graph_and_eager_modes class TestMoveDimension(test.TestCase): - @test_util.run_in_graph_and_eager_modes def test_move_dimension_static_shape(self): x = random_ops.random_normal(shape=[200, 30, 4, 1, 6]) @@ -561,7 +561,6 @@ class TestMoveDimension(test.TestCase): x_perm = distribution_util.move_dimension(x, 4, 2) self.assertAllEqual(x_perm.shape.as_list(), [200, 30, 6, 4, 1]) - @test_util.run_in_graph_and_eager_modes def test_move_dimension_dynamic_shape(self): x_ = random_ops.random_normal(shape=[200, 30, 4, 1, 6]) diff --git a/tensorflow/contrib/eager/python/BUILD b/tensorflow/contrib/eager/python/BUILD index fa3f1bb7ad..84517b57c7 100644 --- a/tensorflow/contrib/eager/python/BUILD +++ b/tensorflow/contrib/eager/python/BUILD @@ -14,6 +14,7 @@ py_library( ":datasets", ":metrics", ":network", + ":remote", ":saver", "//tensorflow/python:framework_ops", "//tensorflow/python:framework_test_lib", @@ -223,11 +224,24 @@ py_test( ], ) +py_library( + name = "remote", + srcs = ["remote.py"], + srcs_version = "PY2AND3", + visibility = ["//tensorflow:internal"], + deps 
= [ + "//tensorflow/core:protos_all_py", + "//tensorflow/python:platform", + "//tensorflow/python/eager:context", + ], +) + py_test( name = "remote_test", srcs = ["remote_test.py"], srcs_version = "PY2AND3", deps = [ + ":remote", "//tensorflow/contrib/eager/python:tfe", "//tensorflow/python:array_ops", "//tensorflow/python:client", diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50.py index a28bc8a43d..3f70f573b1 100644 --- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50.py +++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50.py @@ -272,8 +272,8 @@ class ResNet50(tf.keras.Model): else: self.global_pooling = None - def call(self, input_tensor, training): - x = self.conv1(input_tensor) + def call(self, inputs, training=True): + x = self.conv1(inputs) x = self.bn_conv1(x, training=training) x = tf.nn.relu(x) x = self.max_pool(x) diff --git a/tensorflow/contrib/eager/python/remote.py b/tensorflow/contrib/eager/python/remote.py new file mode 100644 index 0000000000..b74cf394f6 --- /dev/null +++ b/tensorflow/contrib/eager/python/remote.py @@ -0,0 +1,73 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Helpers to connect to remote servers.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.core.protobuf.cluster_pb2 import ClusterDef +from tensorflow.core.protobuf.tensorflow_server_pb2 import ServerDef +from tensorflow.python.eager import context + + +def connect_to_remote_host(remote_host=None, job_name="worker"): + """Connects to a single machine to enable remote execution on it. + + Will make devices on the remote host available to use. Note that calling this + more than once will work, but will invalidate any tensor handles on the old + remote devices. + + Using the default job_name of worker, you can schedule ops to run remotely as + follows: + ```python + # Enable eager execution, and connect to the remote host. + tf.enable_eager_execution() + tf.contrib.eager.connect_to_remote_host("exampleaddr.com:9876") + + with ops.device("job:worker/replica:0/task:1/device:CPU:0"): + # The following tensors should be resident on the remote device, and the op + # will also execute remotely. + x1 = array_ops.ones([2, 2]) + x2 = array_ops.ones([2, 2]) + y = math_ops.matmul(x1, x2) + ``` + + Args: + remote_host: The addr of the remote server in host-port format. + job_name: The job name under which the new server will be accessible. + + Raises: + ValueError: if remote_host is None. + """ + if remote_host is None: + raise ValueError("Must provide an remote_host") + cluster_def = ClusterDef() + job_def = cluster_def.job.add() + job_def.name = job_name + job_def.tasks[0] = "127.0.0.1:0" + job_def.tasks[1] = remote_host + + server_def = ServerDef( + cluster=cluster_def, + job_name=job_name, + task_index=0, + protocol="grpc") + + # TODO(nareshmodi): Make this default since it works in more situations. 
+ os.environ["TF_EAGER_REMOTE_USE_SEND_TENSOR_RPC"] = "1" + context.set_server_def(server_def) diff --git a/tensorflow/contrib/eager/python/remote_test.py b/tensorflow/contrib/eager/python/remote_test.py index 76f48eeb1c..13029db975 100644 --- a/tensorflow/contrib/eager/python/remote_test.py +++ b/tensorflow/contrib/eager/python/remote_test.py @@ -23,6 +23,7 @@ import os import numpy as np +from tensorflow.contrib.eager.python import remote from tensorflow.core.protobuf import cluster_pb2 from tensorflow.core.protobuf import tensorflow_server_pb2 from tensorflow.python.eager import backprop @@ -85,6 +86,7 @@ class RemoteExecutionTest(test.TestCase): self._cached_server1_target = self._cached_server1.target[len("grpc://"):] self._cached_server2_target = self._cached_server2.target[len("grpc://"):] + def setUp(self): # Start the local server. context.set_server_def( server_def=get_server_def( @@ -172,6 +174,17 @@ class RemoteExecutionTest(test.TestCase): y = math_ops.matmul(x1, x1) np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy()) + @run_sync_and_async + def testConnectToRemoteServer(self): + """Basic server connection.""" + remote.connect_to_remote_host(self._cached_server1_target) + + with ops.device("job:worker/replica:0/task:1/device:CPU:0"): + x1 = array_ops.ones([2, 2]) + x2 = array_ops.ones([2, 2]) + y = math_ops.matmul(x1, x2) + np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy()) + if __name__ == "__main__": ops.enable_eager_execution() diff --git a/tensorflow/contrib/eager/python/tfe.py b/tensorflow/contrib/eager/python/tfe.py index 4dfd083443..fe7f1b72fc 100644 --- a/tensorflow/contrib/eager/python/tfe.py +++ b/tensorflow/contrib/eager/python/tfe.py @@ -74,6 +74,8 @@ To use, at program startup, call `tf.enable_eager_execution()`. 
@@TensorSpec +@@connect_to_cloud_tpu + @@DEVICE_PLACEMENT_EXPLICIT @@DEVICE_PLACEMENT_WARN @@DEVICE_PLACEMENT_SILENT @@ -94,6 +96,7 @@ from tensorflow.contrib.eager.python.network import Network from tensorflow.contrib.eager.python.network import Sequential from tensorflow.contrib.eager.python.network import save_network_checkpoint from tensorflow.contrib.eager.python.network import restore_network_checkpoint +from tensorflow.contrib.eager.python.remote import connect_to_remote_host from tensorflow.contrib.eager.python.saver import get_optimizer_variables from tensorflow.contrib.eager.python.saver import restore_variables_on_create from tensorflow.contrib.eager.python.saver import Saver diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl index 45a0ded7eb..458a50f25c 100644 --- a/tensorflow/contrib/lite/build_def.bzl +++ b/tensorflow/contrib/lite/build_def.bzl @@ -293,6 +293,7 @@ def generated_test_models(): "topk", "transpose", #"transpose_conv", # disabled due to b/111213074 + "unpack", "where", ] diff --git a/tensorflow/contrib/lite/builtin_op_data.h b/tensorflow/contrib/lite/builtin_op_data.h index 70178b2faa..e81f9e4f51 100644 --- a/tensorflow/contrib/lite/builtin_op_data.h +++ b/tensorflow/contrib/lite/builtin_op_data.h @@ -286,6 +286,11 @@ typedef struct { int axis; } TfLiteOneHotParams; +typedef struct { + int num; + int axis; +} TfLiteUnpackParams; + #ifdef __cplusplus } // extern "C" #endif // __cplusplus diff --git a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc index 32b1cfd2d8..c39013bb42 100644 --- a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc +++ b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc @@ -2434,7 +2434,8 @@ class NoCifgNoPeepholeNoProjectionNoClippingLstmTest : public BaseLstmTest { } }; -TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) { 
+TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, + DISABLED_LstmBlackBoxTest) { const int n_batch = 1; const int n_input = 2; // n_cell and n_output have the same size when there is no projection. @@ -2541,7 +2542,8 @@ class CifgNoPeepholeNoProjectionNoClippingLstmTest : public BaseLstmTest { } }; -TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) { +TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest, + DISABLED_LstmBlackBoxTest) { const int n_batch = 1; const int n_input = 2; // n_cell and n_output have the same size when there is no projection. @@ -3200,7 +3202,7 @@ class NoCifgPeepholeProjectionClippingLstmTest : public BaseLstmTest { } }; -TEST_F(NoCifgPeepholeProjectionClippingLstmTest, LstmBlackBoxTest) { +TEST_F(NoCifgPeepholeProjectionClippingLstmTest, DISABLED_LstmBlackBoxTest) { const int n_batch = 2; const int n_input = 5; const int n_cell = 20; diff --git a/tensorflow/contrib/lite/g3doc/apis.md b/tensorflow/contrib/lite/g3doc/apis.md index 776803da8c..f255017ad9 100644 --- a/tensorflow/contrib/lite/g3doc/apis.md +++ b/tensorflow/contrib/lite/g3doc/apis.md @@ -1,5 +1,3 @@ -book_path: /mobile/_book.yaml -project_path: /mobile/_project.yaml # TensorFlow Lite APIs diff --git a/tensorflow/contrib/lite/g3doc/custom_operators.md b/tensorflow/contrib/lite/g3doc/custom_operators.md index d979353bb3..ee6150b60e 100644 --- a/tensorflow/contrib/lite/g3doc/custom_operators.md +++ b/tensorflow/contrib/lite/g3doc/custom_operators.md @@ -1,5 +1,3 @@ -book_path: /mobile/_book.yaml -project_path: /mobile/_project.yaml # How to use custom operators diff --git a/tensorflow/contrib/lite/g3doc/demo_android.md b/tensorflow/contrib/lite/g3doc/demo_android.md index d79a2696b4..c38b928684 100644 --- a/tensorflow/contrib/lite/g3doc/demo_android.md +++ b/tensorflow/contrib/lite/g3doc/demo_android.md @@ -1,5 +1,3 @@ -book_path: /mobile/_book.yaml -project_path: /mobile/_project.yaml # Android Demo App diff --git 
a/tensorflow/contrib/lite/g3doc/demo_ios.md b/tensorflow/contrib/lite/g3doc/demo_ios.md index a554898899..7579ad84a0 100644 --- a/tensorflow/contrib/lite/g3doc/demo_ios.md +++ b/tensorflow/contrib/lite/g3doc/demo_ios.md @@ -1,5 +1,3 @@ -book_path: /mobile/_book.yaml -project_path: /mobile/_project.yaml # iOS Demo App diff --git a/tensorflow/contrib/lite/g3doc/devguide.md b/tensorflow/contrib/lite/g3doc/devguide.md index dc9cc98c08..90e7915c52 100644 --- a/tensorflow/contrib/lite/g3doc/devguide.md +++ b/tensorflow/contrib/lite/g3doc/devguide.md @@ -1,5 +1,3 @@ -book_path: /mobile/_book.yaml -project_path: /mobile/_project.yaml # Developer Guide diff --git a/tensorflow/contrib/lite/g3doc/ios.md b/tensorflow/contrib/lite/g3doc/ios.md index d78d373ccf..5ff0412209 100644 --- a/tensorflow/contrib/lite/g3doc/ios.md +++ b/tensorflow/contrib/lite/g3doc/ios.md @@ -1,5 +1,3 @@ -book_path: /mobile/_book.yaml -project_path: /mobile/_project.yaml # TensorFlow Lite for iOS diff --git a/tensorflow/contrib/lite/g3doc/models.md b/tensorflow/contrib/lite/g3doc/models.md index 4ceb9a53dc..b984671e89 100644 --- a/tensorflow/contrib/lite/g3doc/models.md +++ b/tensorflow/contrib/lite/g3doc/models.md @@ -1,5 +1,3 @@ -book_path: /mobile/_book.yaml -project_path: /mobile/_project.yaml # List of Hosted Models diff --git a/tensorflow/contrib/lite/g3doc/ops_versioning.md b/tensorflow/contrib/lite/g3doc/ops_versioning.md index b06f4fd3b8..0d571ce547 100644 --- a/tensorflow/contrib/lite/g3doc/ops_versioning.md +++ b/tensorflow/contrib/lite/g3doc/ops_versioning.md @@ -1,5 +1,3 @@ -book_path: /mobile/_book.yaml -project_path: /mobile/_project.yaml # TensorFlow Lite Ops Versioning diff --git a/tensorflow/contrib/lite/g3doc/overview.md b/tensorflow/contrib/lite/g3doc/overview.md index be60d7941a..8cf43496df 100644 --- a/tensorflow/contrib/lite/g3doc/overview.md +++ b/tensorflow/contrib/lite/g3doc/overview.md @@ -1,5 +1,3 @@ -book_path: /mobile/_book.yaml -project_path: /mobile/_project.yaml # 
Introduction to TensorFlow Lite diff --git a/tensorflow/contrib/lite/g3doc/performance.md b/tensorflow/contrib/lite/g3doc/performance.md index 5cd0aab44f..28cb6aba6e 100644 --- a/tensorflow/contrib/lite/g3doc/performance.md +++ b/tensorflow/contrib/lite/g3doc/performance.md @@ -1,5 +1,3 @@ -book_path: /mobile/_book.yaml -project_path: /mobile/_project.yaml # Performance diff --git a/tensorflow/contrib/lite/g3doc/rpi.md b/tensorflow/contrib/lite/g3doc/rpi.md index 9fcf79ba00..8ed8640582 100644 --- a/tensorflow/contrib/lite/g3doc/rpi.md +++ b/tensorflow/contrib/lite/g3doc/rpi.md @@ -1,5 +1,3 @@ -book_path: /mobile/_book.yaml -project_path: /mobile/_project.yaml # TensorFlow Lite for Raspberry Pi diff --git a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md index aa65ec9988..fb9d5f6787 100644 --- a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md +++ b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md @@ -1,5 +1,3 @@ -book_path: /mobile/_book.yaml -project_path: /mobile/_project.yaml # TensorFlow Lite & TensorFlow Compatibility Guide @@ -843,6 +841,19 @@ Outputs { } ``` +**UNPACK** + +``` +Inputs { + 0: a tensor. + 1: an integer. + 2: an integer. +} +Outputs { + 0-N: tensors of unpacked tensor. 
+} +``` + And these are TensorFlow Lite operations that are present but not ready for custom models yet: diff --git a/tensorflow/contrib/lite/g3doc/tfmobile/android_build.md b/tensorflow/contrib/lite/g3doc/tfmobile/android_build.md index 76e16fc9db..c7cdee07de 100644 --- a/tensorflow/contrib/lite/g3doc/tfmobile/android_build.md +++ b/tensorflow/contrib/lite/g3doc/tfmobile/android_build.md @@ -1,5 +1,3 @@ -book_path: /mobile/_book.yaml -project_path: /mobile/_project.yaml # Building TensorFlow on Android diff --git a/tensorflow/contrib/lite/g3doc/tfmobile/index.md b/tensorflow/contrib/lite/g3doc/tfmobile/index.md index bd047bfcec..d003bb2f38 100644 --- a/tensorflow/contrib/lite/g3doc/tfmobile/index.md +++ b/tensorflow/contrib/lite/g3doc/tfmobile/index.md @@ -1,5 +1,3 @@ -book_path: /mobile/_book.yaml -project_path: /mobile/_project.yaml # Overview diff --git a/tensorflow/contrib/lite/g3doc/tfmobile/ios_build.md b/tensorflow/contrib/lite/g3doc/tfmobile/ios_build.md index 6223707892..be8b4100c8 100644 --- a/tensorflow/contrib/lite/g3doc/tfmobile/ios_build.md +++ b/tensorflow/contrib/lite/g3doc/tfmobile/ios_build.md @@ -1,5 +1,3 @@ -book_path: /mobile/_book.yaml -project_path: /mobile/_project.yaml # Building TensorFlow on iOS diff --git a/tensorflow/contrib/lite/g3doc/tfmobile/linking_libs.md b/tensorflow/contrib/lite/g3doc/tfmobile/linking_libs.md index 4c2071ed05..4d4bb3bc08 100644 --- a/tensorflow/contrib/lite/g3doc/tfmobile/linking_libs.md +++ b/tensorflow/contrib/lite/g3doc/tfmobile/linking_libs.md @@ -1,5 +1,3 @@ -book_path: /mobile/_book.yaml -project_path: /mobile/_project.yaml # Integrating TensorFlow libraries diff --git a/tensorflow/contrib/lite/g3doc/tfmobile/optimizing.md b/tensorflow/contrib/lite/g3doc/tfmobile/optimizing.md index a0192c3541..7436594fd8 100644 --- a/tensorflow/contrib/lite/g3doc/tfmobile/optimizing.md +++ b/tensorflow/contrib/lite/g3doc/tfmobile/optimizing.md @@ -1,5 +1,3 @@ -book_path: /mobile/_book.yaml -project_path: 
/mobile/_project.yaml # Optimizing for mobile diff --git a/tensorflow/contrib/lite/g3doc/tfmobile/prepare_models.md b/tensorflow/contrib/lite/g3doc/tfmobile/prepare_models.md index 6b4e4a92bd..d1c67d4c61 100644 --- a/tensorflow/contrib/lite/g3doc/tfmobile/prepare_models.md +++ b/tensorflow/contrib/lite/g3doc/tfmobile/prepare_models.md @@ -1,5 +1,3 @@ -book_path: /mobile/_book.yaml -project_path: /mobile/_project.yaml # Preparing models for mobile deployment diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD index 1f528fdab9..407d52f0e8 100644 --- a/tensorflow/contrib/lite/kernels/BUILD +++ b/tensorflow/contrib/lite/kernels/BUILD @@ -211,6 +211,7 @@ cc_library( "transpose_conv.cc", "unidirectional_sequence_lstm.cc", "unidirectional_sequence_rnn.cc", + "unpack.cc", ], hdrs = [ "padding.h", @@ -1201,6 +1202,20 @@ tf_cc_test( ], ) +tf_cc_test( + name = "unpack_test", + size = "small", + srcs = ["unpack_test.cc"], + tags = ["tflite_not_portable_ios"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h index 40160289c8..7319636bf5 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h @@ -2143,38 +2143,6 @@ void Conv(const uint8* input_data, const Dims<4>& input_dims, im2col_data, im2col_dims, gemm_context); } -template <typename T> -inline void DepthToSpace(const T* input_data, const Dims<4>& input_dims, - int block_size, T* output_data, - const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("DepthToSpace"); - - const int input_depth = ArraySize(input_dims, 0); - const 
int input_width = ArraySize(input_dims, 1); - const int input_height = ArraySize(input_dims, 2); - - const int output_depth = ArraySize(output_dims, 0); - const int batch_size = ArraySize(output_dims, 3); - - // Number of continuous values that we can copy in one interation. - const int stride = block_size * output_depth; - - for (int batch = 0; batch < batch_size; ++batch) { - for (int in_h = 0; in_h < input_height; ++in_h) { - const T* input_ptr = input_data + Offset(input_dims, 0, 0, in_h, batch); - for (int offset_h = 0; offset_h < block_size; ++offset_h) { - const T* src = input_ptr; - for (int in_w = 0; in_w < input_width; ++in_w) { - memcpy(output_data, src, stride * sizeof(T)); - output_data += stride; - src += input_depth; - } - input_ptr += stride; - } - } - } -} - // legacy, for compatibility with old checked-in code template <FusedActivationFunctionType Ac, typename T> void Im2col(const T* input_data, const Dims<4>& input_dims, int stride, @@ -2250,25 +2218,87 @@ void ConvAsGemm(const uint8* input_data, const Dims<4>& input_dims, } template <typename T> -inline void SpaceToDepth(const T* input_data, const Dims<4>& input_dims, +inline void DepthToSpace(const tflite::DepthToSpaceParams& op_params, + const RuntimeShape& unextended_input_shape, + const T* input_data, + const RuntimeShape& unextended_output_shape, + T* output_data) { + gemmlowp::ScopedProfilingLabel label("DepthToSpace"); + + TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4); + RuntimeShape input_shape = + RuntimeShape::ExtendedShape(4, unextended_input_shape); + RuntimeShape output_shape = + RuntimeShape::ExtendedShape(4, unextended_output_shape); + + const int input_depth = input_shape.Dims(3); + const int input_width = input_shape.Dims(2); + const int input_height = input_shape.Dims(1); + + const int output_depth = output_shape.Dims(3); + const int batch_size = output_shape.Dims(0); + + // Number of continuous 
values that we can copy in one interation. + const int stride = op_params.block_size * output_depth; + + for (int batch = 0; batch < batch_size; ++batch) { + for (int in_h = 0; in_h < input_height; ++in_h) { + const T* input_ptr = input_data + Offset(input_shape, batch, in_h, 0, 0); + for (int offset_h = 0; offset_h < op_params.block_size; ++offset_h) { + const T* src = input_ptr; + for (int in_w = 0; in_w < input_width; ++in_w) { + memcpy(output_data, src, stride * sizeof(T)); + output_data += stride; + src += input_depth; + } + input_ptr += stride; + } + } + } +} + +// Legacy Dims<4>. +template <typename T> +inline void DepthToSpace(const T* input_data, const Dims<4>& input_dims, int block_size, T* output_data, const Dims<4>& output_dims) { + tflite::DepthToSpaceParams op_params; + op_params.block_size = block_size; + + DepthToSpace(op_params, DimsToShape(input_dims), input_data, + DimsToShape(output_dims), output_data); +} + +template <typename T> +inline void SpaceToDepth(const tflite::SpaceToDepthParams& op_params, + const RuntimeShape& unextended_input_shape, + const T* input_data, + const RuntimeShape& unextended_output_shape, + T* output_data) { gemmlowp::ScopedProfilingLabel label("SpaceToDepth"); - const int output_depth = ArraySize(output_dims, 0); - const int output_width = ArraySize(output_dims, 1); - const int output_height = ArraySize(output_dims, 2); + TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4); + RuntimeShape input_shape = + RuntimeShape::ExtendedShape(4, unextended_input_shape); + RuntimeShape output_shape = + RuntimeShape::ExtendedShape(4, unextended_output_shape); - const int input_depth = ArraySize(input_dims, 0); - const int batch_size = ArraySize(input_dims, 3); + const int output_depth = output_shape.Dims(3); + const int output_width = output_shape.Dims(2); + const int output_height = output_shape.Dims(1); + + const int input_depth = input_shape.Dims(3); + 
const int batch_size = input_shape.Dims(0); // Number of continuous values that we can copy in one interation. - const int stride = block_size * input_depth; + const int stride = op_params.block_size * input_depth; for (int batch = 0; batch < batch_size; ++batch) { for (int out_h = 0; out_h < output_height; ++out_h) { - T* output_ptr = output_data + Offset(output_dims, 0, 0, out_h, batch); - for (int offset_h = 0; offset_h < block_size; ++offset_h) { + T* output_ptr = output_data + Offset(output_shape, batch, out_h, 0, 0); + for (int offset_h = 0; offset_h < op_params.block_size; ++offset_h) { T* dst = output_ptr; for (int out_w = 0; out_w < output_width; ++out_w) { memcpy(dst, input_data, stride * sizeof(T)); @@ -2281,6 +2311,18 @@ inline void SpaceToDepth(const T* input_data, const Dims<4>& input_dims, } } +// Legacy Dims<4>. +template <typename T> +inline void SpaceToDepth(const T* input_data, const Dims<4>& input_dims, + int block_size, T* output_data, + const Dims<4>& output_dims) { + tflite::SpaceToDepthParams op_params; + op_params.block_size = block_size; + + SpaceToDepth(op_params, DimsToShape(input_dims), input_data, + DimsToShape(output_dims), output_data); +} + template <FusedActivationFunctionType Ac> void NonGlobalBatchNormalization( const float* input_data, const Dims<4>& input_dims, const float* mean_data, @@ -5565,20 +5607,29 @@ inline void GetIndexRange(int spatial_index_dim, int block_shape_dim, } template <typename T> -inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims, - const int32* block_shape_data, - const Dims<4>& block_shape_dims, - const int32* crops_data, const Dims<4>& crops_dims, - T* output_data, const Dims<4>& output_dims) { +inline void BatchToSpaceND( + const RuntimeShape& unextended_input1_shape, const T* input1_data, + const RuntimeShape& unextended_input2_shape, const int32* block_shape_data, + const RuntimeShape& unextended_input3_shape, const int32* crops_data, + const RuntimeShape& 
unextended_output_shape, T* output_data) { gemmlowp::ScopedProfilingLabel label("BatchToSpaceND"); - const int output_batch_size = ArraySize(output_dims, 3); - const int output_height = ArraySize(output_dims, 2); - const int output_width = ArraySize(output_dims, 1); - const int input_batch_size = ArraySize(input_dims, 3); - const int input_height = ArraySize(input_dims, 2); - const int input_width = ArraySize(input_dims, 1); - const int depth = ArraySize(input_dims, 0); + TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4); + TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4); + RuntimeShape input1_shape = + RuntimeShape::ExtendedShape(4, unextended_input1_shape); + RuntimeShape output_shape = + RuntimeShape::ExtendedShape(4, unextended_output_shape); + + const int output_width = output_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_batch_size = output_shape.Dims(0); + + const int depth = input1_shape.Dims(3); + const int input_width = input1_shape.Dims(2); + const int input_height = input1_shape.Dims(1); + const int input_batch_size = input1_shape.Dims(0); + const int block_shape_width = block_shape_data[1]; const int block_shape_height = block_shape_data[0]; const int crops_top = crops_data[0]; @@ -5613,14 +5664,28 @@ inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims, spatial_offset % block_shape_width - crops_left; TFLITE_DCHECK_GE(out_w, 0); TFLITE_DCHECK_LT(out_w, output_width); - T* out = output_data + Offset(output_dims, 0, out_w, out_h, out_batch); - const T* in = input_data + Offset(input_dims, 0, in_w, in_h, in_batch); + T* out = output_data + Offset(output_shape, out_batch, out_h, out_w, 0); + const T* in = + input1_data + Offset(input1_shape, in_batch, in_h, in_w, 0); memcpy(out, in, depth * sizeof(T)); } } } } +// Legacy Dims<4>. 
+template <typename T> +inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims, + const int32* block_shape_data, + const Dims<4>& block_shape_dims, + const int32* crops_data, const Dims<4>& crops_dims, + T* output_data, const Dims<4>& output_dims) { + BatchToSpaceND(DimsToShape(input_dims), input_data, + DimsToShape(block_shape_dims), block_shape_data, + DimsToShape(crops_dims), crops_data, DimsToShape(output_dims), + output_data); +} + template <typename T> void TypedMemset(void* ptr, T value, size_t num) { // Optimization for common cases where memset() will suffice. diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h index a6aef4fa29..3492a6c2f9 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h @@ -407,18 +407,29 @@ void Conv(const uint8* input_data, const Dims<4>& input_dims, } template <typename T> -inline void DepthToSpace(const T* input_data, const Dims<4>& input_dims, - int block_size, T* output_data, - const Dims<4>& output_dims) { - const int input_depth = ArraySize(input_dims, 0); - const int input_width = ArraySize(input_dims, 1); - const int input_height = ArraySize(input_dims, 2); - const int input_batch = ArraySize(input_dims, 3); +inline void DepthToSpace(const tflite::DepthToSpaceParams& op_params, + const RuntimeShape& unextended_input_shape, + const T* input_data, + const RuntimeShape& unextended_output_shape, + T* output_data) { + TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4); + RuntimeShape input_shape = + RuntimeShape::ExtendedShape(4, unextended_input_shape); + RuntimeShape output_shape = + RuntimeShape::ExtendedShape(4, unextended_output_shape); - const int output_depth = ArraySize(output_dims, 0); - const int output_width = 
ArraySize(output_dims, 1); - const int output_height = ArraySize(output_dims, 2); - const int output_batch = ArraySize(output_dims, 3); + const int input_depth = input_shape.Dims(3); + const int input_width = input_shape.Dims(2); + const int input_height = input_shape.Dims(1); + const int input_batch = input_shape.Dims(0); + + const int output_depth = output_shape.Dims(3); + const int output_width = output_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_batch = output_shape.Dims(0); + + const int32 block_size = op_params.block_size; TFLITE_DCHECK_EQ(input_width * block_size, output_width); TFLITE_DCHECK_EQ(input_height * block_size, output_height); @@ -437,9 +448,9 @@ inline void DepthToSpace(const T* input_data, const Dims<4>& input_dims, const int in_h = out_h / block_size; const int in_b = out_b; + const int input_index = Offset(input_shape, in_b, in_h, in_w, in_d); const int output_index = - Offset(output_dims, out_d, out_w, out_h, out_b); - const int input_index = Offset(input_dims, in_d, in_w, in_h, in_b); + Offset(output_shape, out_b, out_h, out_w, out_d); output_data[output_index] = input_data[input_index]; } @@ -448,19 +459,42 @@ inline void DepthToSpace(const T* input_data, const Dims<4>& input_dims, } } +// Legacy Dims<4>. 
template <typename T> -inline void SpaceToDepth(const T* input_data, const Dims<4>& input_dims, +inline void DepthToSpace(const T* input_data, const Dims<4>& input_dims, int block_size, T* output_data, const Dims<4>& output_dims) { - const int input_depth = ArraySize(input_dims, 0); - const int input_width = ArraySize(input_dims, 1); - const int input_height = ArraySize(input_dims, 2); - const int input_batch = ArraySize(input_dims, 3); + tflite::DepthToSpaceParams op_params; + op_params.block_size = block_size; - const int output_depth = ArraySize(output_dims, 0); - const int output_width = ArraySize(output_dims, 1); - const int output_height = ArraySize(output_dims, 2); - const int output_batch = ArraySize(output_dims, 3); + DepthToSpace(op_params, DimsToShape(input_dims), input_data, + DimsToShape(output_dims), output_data); +} + +template <typename T> +inline void SpaceToDepth(const tflite::SpaceToDepthParams& op_params, + const RuntimeShape& unextended_input_shape, + const T* input_data, + const RuntimeShape& unextended_output_shape, + T* output_data) { + TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4); + RuntimeShape input_shape = + RuntimeShape::ExtendedShape(4, unextended_input_shape); + RuntimeShape output_shape = + RuntimeShape::ExtendedShape(4, unextended_output_shape); + + const int input_depth = input_shape.Dims(3); + const int input_width = input_shape.Dims(2); + const int input_height = input_shape.Dims(1); + const int input_batch = input_shape.Dims(0); + + const int output_depth = output_shape.Dims(3); + const int output_width = output_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_batch = output_shape.Dims(0); + + const int32 block_size = op_params.block_size; TFLITE_DCHECK_EQ(input_width, output_width * block_size); TFLITE_DCHECK_EQ(input_height, output_height * block_size); @@ -478,9 +512,9 @@ inline void SpaceToDepth(const T* 
input_data, const Dims<4>& input_dims, const int out_h = in_h / block_size; const int out_b = in_b; + const int input_index = Offset(input_shape, in_b, in_h, in_w, in_d); const int output_index = - Offset(output_dims, out_d, out_w, out_h, out_b); - const int input_index = Offset(input_dims, in_d, in_w, in_h, in_b); + Offset(output_shape, out_b, out_h, out_w, out_d); output_data[output_index] = input_data[input_index]; } @@ -489,6 +523,18 @@ inline void SpaceToDepth(const T* input_data, const Dims<4>& input_dims, } } +// Legacy Dims<4>. +template <typename T> +inline void SpaceToDepth(const T* input_data, const Dims<4>& input_dims, + int block_size, T* output_data, + const Dims<4>& output_dims) { + tflite::SpaceToDepthParams op_params; + op_params.block_size = block_size; + + SpaceToDepth(op_params, DimsToShape(input_dims), input_data, + DimsToShape(output_dims), output_data); +} + inline void FullyConnected(const float* input_data, const Dims<4>& input_dims, const float* weights_data, const Dims<4>& weights_dims, const float* bias_data, @@ -2034,6 +2080,25 @@ void Pack(int dim, const Scalar* const* input_data, } } +template <typename Scalar> +void Unpack(int axis, const Scalar* input_data, const Dims<4>& input_dims, + int dimensions, int outputs_count, Scalar* const* output_datas, + const Dims<4>& output_dims) { + int outer_size = 1; + for (int i = dimensions - axis; i < 4; i++) { + outer_size *= input_dims.sizes[i]; + } + + const int copy_size = FlatSize(input_dims) / outer_size / outputs_count; + for (int k = 0; k < outer_size; k++) { + for (int i = 0; i < outputs_count; ++i) { + Scalar* output_ptr = output_datas[i] + copy_size * k; + int loc = k * outputs_count * copy_size + i * copy_size; + memcpy(output_ptr, input_data + loc, copy_size * sizeof(Scalar)); + } + } +} + // TODO(prabhumk): This is the same as the optimized implementation. 
// TODO(prabhumk): The quantized implementation of concatentation isn't fully // quantized as it takes scale as a floating point value. This should be fixed @@ -3467,45 +3532,56 @@ inline void ResizeBilinear(const uint8* input_data, const Dims<4>& input_dims, } template <typename T> -inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims, - const int32* block_shape_data, - const Dims<4>& block_shape_dims, - const int32* paddings_data, - const Dims<4>& paddings_dims, T* output_data, - const Dims<4>& output_dims, - const int32_t pad_value) { - const int output_batch_size = ArraySize(output_dims, 3); - const int output_height = ArraySize(output_dims, 2); - const int output_width = ArraySize(output_dims, 1); - const int input_batch_size = ArraySize(input_dims, 3); - const int input_height = ArraySize(input_dims, 2); - const int input_width = ArraySize(input_dims, 1); - const int depth = ArraySize(input_dims, 0); +inline void SpaceToBatchND( + const SpaceToBatchParams& params, + const RuntimeShape& unextended_input1_shape, const T* input1_data, + const RuntimeShape& unextended_input2_shape, const int32* block_shape_data, + const RuntimeShape& unextended_input3_shape, const int32* paddings_data, + const RuntimeShape& unextended_output_shape, T* output_data) { + TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4); + TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4); + RuntimeShape input1_shape = + RuntimeShape::ExtendedShape(4, unextended_input1_shape); + RuntimeShape output_shape = + RuntimeShape::ExtendedShape(4, unextended_output_shape); + + const int depth = input1_shape.Dims(3); + const int input_width = input1_shape.Dims(2); + const int input_height = input1_shape.Dims(1); + const int input_batch_size = input1_shape.Dims(0); + + const int output_width = output_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_batch_size = output_shape.Dims(0); + const int block_shape_height = 
block_shape_data[0]; const int block_shape_width = block_shape_data[1]; const int padding_top = paddings_data[0]; const int padding_left = paddings_data[2]; + // For uint8 quantized, the correct padding "zero value" is the output offset. + const int32_t pad_value = params.output_offset; + for (int out_b = 0; out_b < output_batch_size; ++out_b) { int input_batch = out_b % input_batch_size; int shift_w = (out_b / input_batch_size) % block_shape_width; int shift_h = (out_b / input_batch_size) / block_shape_width; for (int out_h = 0; out_h < output_height; ++out_h) { for (int out_w = 0; out_w < output_width; ++out_w) { - T* out = output_data + Offset(output_dims, 0, out_w, out_h, out_b); + T* out = output_data + Offset(output_shape, out_b, out_h, out_w, 0); if (out_h * block_shape_height + shift_h < padding_top || out_h * block_shape_height + shift_h >= padding_top + input_height || out_w * block_shape_width + shift_w < padding_left || out_w * block_shape_width + shift_w >= padding_left + input_width) { + // This may not execute correctly when pad_value != 0 and T != uint8. memset(out, pad_value, depth * sizeof(T)); } else { const T* in = - input_data + - Offset(input_dims, 0, - (out_w * block_shape_width + shift_w) - padding_left, + input1_data + + Offset(input1_shape, input_batch, (out_h * block_shape_height + shift_h) - padding_top, - input_batch); + (out_w * block_shape_width + shift_w) - padding_left, 0); memcpy(out, in, depth * sizeof(T)); } } @@ -3513,30 +3589,63 @@ inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims, } } +// Legacy Dims<4>. 
template <typename T> inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims, const int32* block_shape_data, const Dims<4>& block_shape_dims, const int32* paddings_data, const Dims<4>& paddings_dims, T* output_data, - const Dims<4>& output_dims) { - SpaceToBatchND(input_data, input_dims, block_shape_data, block_shape_dims, - paddings_data, paddings_dims, output_data, output_dims, 0); + const Dims<4>& output_dims, + const int32_t pad_value) { + tflite::SpaceToBatchParams op_params; + op_params.output_offset = pad_value; + + SpaceToBatchND(op_params, DimsToShape(input_dims), input_data, + DimsToShape(block_shape_dims), block_shape_data, + DimsToShape(paddings_dims), paddings_data, + DimsToShape(output_dims), output_data); } +// Legacy if no good reason to have signature with pad_value=0. template <typename T> -inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims, +inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims, const int32* block_shape_data, const Dims<4>& block_shape_dims, - const int32* crops_data, const Dims<4>& crops_dims, - T* output_data, const Dims<4>& output_dims) { - const int output_batch_size = ArraySize(output_dims, 3); - const int output_height = ArraySize(output_dims, 2); - const int output_width = ArraySize(output_dims, 1); - const int input_batch_size = ArraySize(input_dims, 3); - const int input_height = ArraySize(input_dims, 2); - const int input_width = ArraySize(input_dims, 1); - const int depth = ArraySize(input_dims, 0); + const int32* paddings_data, + const Dims<4>& paddings_dims, T* output_data, + const Dims<4>& output_dims) { + tflite::SpaceToBatchParams op_params; + op_params.output_offset = 0; + + SpaceToBatchND(op_params, DimsToShape(input_dims), input_data, + DimsToShape(block_shape_dims), block_shape_data, + DimsToShape(paddings_dims), paddings_data, + DimsToShape(output_dims), output_data); +} + +template <typename T> +inline void BatchToSpaceND( + const RuntimeShape& 
unextended_input1_shape, const T* input1_data, + const RuntimeShape& unextended_input2_shape, const int32* block_shape_data, + const RuntimeShape& unextended_input3_shape, const int32* crops_data, + const RuntimeShape& unextended_output_shape, T* output_data) { + TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4); + TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4); + RuntimeShape input1_shape = + RuntimeShape::ExtendedShape(4, unextended_input1_shape); + RuntimeShape output_shape = + RuntimeShape::ExtendedShape(4, unextended_output_shape); + + const int output_width = output_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_batch_size = output_shape.Dims(0); + + const int depth = input1_shape.Dims(3); + const int input_width = input1_shape.Dims(2); + const int input_height = input1_shape.Dims(1); + const int input_batch_size = input1_shape.Dims(0); + const int block_shape_width = block_shape_data[1]; const int block_shape_height = block_shape_data[0]; const int crops_top = crops_data[0]; @@ -3558,14 +3667,28 @@ inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims, if (out_w < 0 || out_w >= output_width) { continue; } - T* out = output_data + Offset(output_dims, 0, out_w, out_h, out_batch); - const T* in = input_data + Offset(input_dims, 0, in_w, in_h, in_batch); + T* out = output_data + Offset(output_shape, out_batch, out_h, out_w, 0); + const T* in = + input1_data + Offset(input1_shape, in_batch, in_h, in_w, 0); memcpy(out, in, depth * sizeof(T)); } } } } +// Legacy Dims<4>. 
+template <typename T> +inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims, + const int32* block_shape_data, + const Dims<4>& block_shape_dims, + const int32* crops_data, const Dims<4>& crops_dims, + T* output_data, const Dims<4>& output_dims) { + BatchToSpaceND(DimsToShape(input_dims), input_data, + DimsToShape(block_shape_dims), block_shape_data, + DimsToShape(crops_dims), crops_data, DimsToShape(output_dims), + output_data); +} + // There are two versions of pad: Pad and PadV2. In PadV2 there is a second // scalar input that provides the padding value. Therefore pad_value_ptr can be // equivalent to a simple input1_data. For Pad, it should point to a zero diff --git a/tensorflow/contrib/lite/kernels/internal/types.h b/tensorflow/contrib/lite/kernels/internal/types.h index 27b78aa225..2603ed2eb7 100644 --- a/tensorflow/contrib/lite/kernels/internal/types.h +++ b/tensorflow/contrib/lite/kernels/internal/types.h @@ -745,7 +745,7 @@ struct ConvParams { }; struct DepthToSpaceParams { - int16 block_size; + int32 block_size; }; struct DepthwiseParams { @@ -871,8 +871,13 @@ struct SoftmaxParams { int diff_min; }; +struct SpaceToBatchParams { + // "Zero" padding for uint8 means padding with the output offset. + int32 output_offset; +}; + struct SpaceToDepthParams { - int16 block_size; + int32 block_size; }; struct SplitParams { diff --git a/tensorflow/contrib/lite/kernels/lstm.cc b/tensorflow/contrib/lite/kernels/lstm.cc index ba251c451e..74dc3f25f9 100644 --- a/tensorflow/contrib/lite/kernels/lstm.cc +++ b/tensorflow/contrib/lite/kernels/lstm.cc @@ -37,7 +37,7 @@ namespace builtin { namespace lstm { struct OpData { - // Which kernel type to use. Full kernel (18 or 20 inputs) or basic kernel + // Which kernel type to use. Full kernel (20 inputs) or basic kernel // (5 inputs). TfLiteLSTMKernelType kernel_type; @@ -47,7 +47,7 @@ struct OpData { int scratch_tensor_index; }; -// For full inputs kernel (18 or 20 inputs). 
+// For full inputs kernel (20-inputs). namespace full { // Input Tensors of size {n_batch, n_input} @@ -81,19 +81,13 @@ constexpr int kProjectionWeightsTensor = 16; // Optional // Projection bias tensor of size {n_output} constexpr int kProjectionBiasTensor = 17; // Optional -// If the node has 20 inputs, the following 2 tensors are used as state tensors. -// These are defined as variable tensors, and will be modified by this op. +// These state tensors are defined as variable tensors, and will be modified by +// this op. constexpr int kInputActivationStateTensor = 18; constexpr int kInputCellStateTensor = 19; // Output tensors. -// * If the node has 18 inputs, these 2 tensors are used as state tensors. -// * If the node has 20 inputs, these 2 tensors are ignored. -// TODO(ycling): Make the 2 output state tensors optional, and propagate the -// state to output tensors when the 2 tensors present. -constexpr int kOutputStateTensor = 0; -constexpr int kCellStateTensor = 1; -constexpr int kOutputTensor = 2; +constexpr int kOutputTensor = 0; void* Init(TfLiteContext* context, const char* buffer, size_t length) { auto* op_data = new OpData(); @@ -258,30 +252,12 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { OpData* op_data = reinterpret_cast<OpData*>(node->user_data); - TF_LITE_ENSURE_EQ(context, node->outputs->size, 3); - - // True if the node is using input variable state tensors. It means: - // * The state tensors are defined as inputs. In this case it would be the - // 19th and 20th input tensors. - // * Otherwise, the output tensors are used to store states. 
- bool use_input_variable_states; - if (node->inputs->size == 20) { - use_input_variable_states = true; - op_data->activation_state_tensor_index = - node->inputs->data[kInputActivationStateTensor]; - op_data->cell_state_tensor_index = - node->inputs->data[kInputCellStateTensor]; - } else if (node->inputs->size == 18) { - use_input_variable_states = false; - op_data->activation_state_tensor_index = - node->outputs->data[kOutputStateTensor]; - op_data->cell_state_tensor_index = node->outputs->data[kCellStateTensor]; - } else { - context->ReportError( - context, "The LSTM Full kernel expects 18 or 20 inputs. Got %d inputs", - node->inputs->size); - return kTfLiteError; - } + TF_LITE_ENSURE_EQ(context, node->outputs->size, 1); + TF_LITE_ENSURE_EQ(context, node->inputs->size, 20); + + op_data->activation_state_tensor_index = + node->inputs->data[kInputActivationStateTensor]; + op_data->cell_state_tensor_index = node->inputs->data[kInputCellStateTensor]; // Inferring batch size, number of outputs and number of cells from the // input tensors. @@ -316,31 +292,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* cell_state = &context->tensors[op_data->cell_state_tensor_index]; - if (use_input_variable_states) { - // Check the shape of input state tensors. - // These tensor may be 1D or 2D. It's fine as long as the total size is - // correct. - TF_LITE_ENSURE_EQ(context, NumElements(activation_state), - n_batch * n_output); - TF_LITE_ENSURE_EQ(context, NumElements(cell_state), n_batch * n_cell); - } else { - // If the state tensors are outputs, this function takes the - // responsibility to resize the state tensors. 
- TfLiteIntArray* activation_state_size = TfLiteIntArrayCreate(2); - activation_state_size->data[0] = n_batch; - activation_state_size->data[1] = n_output; - TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, activation_state, - activation_state_size)); - - TfLiteIntArray* cell_size = TfLiteIntArrayCreate(2); - cell_size->data[0] = n_batch; - cell_size->data[1] = n_cell; - TF_LITE_ENSURE_OK(context, - context->ResizeTensor(context, cell_state, cell_size)); - // Mark state tensors as persistent tensors. - activation_state->allocation_type = kTfLiteArenaRwPersistent; - cell_state->allocation_type = kTfLiteArenaRwPersistent; - } + // Check the shape of input state tensors. + // These tensor may be 1D or 2D. It's fine as long as the total size is + // correct. + TF_LITE_ENSURE_EQ(context, NumElements(activation_state), n_batch * n_output); + TF_LITE_ENSURE_EQ(context, NumElements(cell_state), n_batch * n_cell); // Resize the output tensors. TfLiteIntArray* output_size = TfLiteIntArrayCreate(2); diff --git a/tensorflow/contrib/lite/kernels/lstm_test.cc b/tensorflow/contrib/lite/kernels/lstm_test.cc index 0266f5fe57..e7ddfceb45 100644 --- a/tensorflow/contrib/lite/kernels/lstm_test.cc +++ b/tensorflow/contrib/lite/kernels/lstm_test.cc @@ -106,14 +106,13 @@ class LSTMOpModel : public SingleOpModel { input_cell_state_ = AddInput(TensorData{TensorType_FLOAT32, {n_cell_ * n_batch_}}, true); - output_state_ = AddOutput(TensorType_FLOAT32); - cell_state_ = AddOutput(TensorType_FLOAT32); output_ = AddOutput(TensorType_FLOAT32); SetBuiltinOp(BuiltinOperator_LSTM, BuiltinOptions_LSTMOptions, CreateLSTMOptions(builder_, ActivationFunctionType_TANH, cell_clip, proj_clip) .Union()); + BuildInterpreter(input_shapes); } @@ -185,22 +184,6 @@ class LSTMOpModel : public SingleOpModel { PopulateTensor(projection_bias_, f); } - void ResetOutputState() { - const int zero_buffer_size = n_cell_ * n_batch_; - std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]); - 
memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float)); - PopulateTensor(output_state_, 0, zero_buffer.get(), - zero_buffer.get() + zero_buffer_size); - } - - void ResetCellState() { - const int zero_buffer_size = n_cell_ * n_batch_; - std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]); - memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float)); - PopulateTensor(cell_state_, 0, zero_buffer.get(), - zero_buffer.get() + zero_buffer_size); - } - void SetInput(int offset, const float* begin, const float* end) { PopulateTensor(input_, offset, const_cast<float*>(begin), const_cast<float*>(end)); @@ -469,10 +452,6 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) { lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); - // Resetting cell_state and output_state - lstm.ResetCellState(); - lstm.ResetOutputState(); - VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm); } @@ -529,10 +508,6 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) { lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); - // Resetting cell_state and output_state - lstm.ResetCellState(); - lstm.ResetOutputState(); - VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.0157651); } @@ -637,10 +612,6 @@ TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) { lstm.SetCellToForgetWeights(cell_to_forget_weights_); lstm.SetCellToOutputWeights(cell_to_output_weights_); - // Resetting cell_state and output_state - lstm.ResetCellState(); - lstm.ResetOutputState(); - VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm); } @@ -698,14 +669,10 @@ TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) { lstm.SetCellToForgetWeights(cell_to_forget_weights_); lstm.SetCellToOutputWeights(cell_to_output_weights_); - // Resetting 
cell_state and output_state - lstm.ResetCellState(); - lstm.ResetOutputState(); - VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.03573); } -class NoCifgPeepholeProjectionClippingLstmTest : public BaseLstmTest { +class NoCifgPeepholeProjectionNoClippingLstmTest : public BaseLstmTest { void SetUp() override { input_to_input_weights_ = { 0.021393683, 0.06124551, 0.046905167, -0.014657677, -0.03149463, @@ -1304,7 +1271,7 @@ class NoCifgPeepholeProjectionClippingLstmTest : public BaseLstmTest { } }; -TEST_F(NoCifgPeepholeProjectionClippingLstmTest, LstmBlackBoxTest) { +TEST_F(NoCifgPeepholeProjectionNoClippingLstmTest, LstmBlackBoxTest) { const int n_batch = 2; const int n_input = 5; const int n_cell = 20; @@ -1362,14 +1329,10 @@ TEST_F(NoCifgPeepholeProjectionClippingLstmTest, LstmBlackBoxTest) { lstm.SetProjectionWeights(projection_weights_); - // Resetting cell_state and output_state - lstm.ResetCellState(); - lstm.ResetOutputState(); - VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm); } -TEST_F(NoCifgPeepholeProjectionClippingLstmTest, HybridLstmBlackBoxTest) { +TEST_F(NoCifgPeepholeProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) { const int n_batch = 2; const int n_input = 5; const int n_cell = 20; @@ -1428,10 +1391,6 @@ TEST_F(NoCifgPeepholeProjectionClippingLstmTest, HybridLstmBlackBoxTest) { lstm.SetProjectionWeights(projection_weights_); - // Resetting cell_state and output_state - lstm.ResetCellState(); - lstm.ResetOutputState(); - VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.00467); } diff --git a/tensorflow/contrib/lite/kernels/optional_tensor_test.cc b/tensorflow/contrib/lite/kernels/optional_tensor_test.cc index 1c728a4733..90a915bb02 100644 --- a/tensorflow/contrib/lite/kernels/optional_tensor_test.cc +++ b/tensorflow/contrib/lite/kernels/optional_tensor_test.cc @@ -101,8 +101,6 @@ class LSTMOpModel : public SingleOpModel { input_cell_state_ = AddInput(TensorData{TensorType_FLOAT32, {n_cell_ * 
n_batch_}}, true); - output_state_ = AddOutput(TensorType_FLOAT32); - cell_state_ = AddOutput(TensorType_FLOAT32); output_ = AddOutput(TensorType_FLOAT32); SetBuiltinOp(BuiltinOperator_LSTM, BuiltinOptions_LSTMOptions, @@ -180,22 +178,6 @@ class LSTMOpModel : public SingleOpModel { PopulateTensor(projection_bias_, f); } - void ResetOutputState() { - const int zero_buffer_size = n_cell_ * n_batch_; - std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]); - memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float)); - PopulateTensor(output_state_, 0, zero_buffer.get(), - zero_buffer.get() + zero_buffer_size); - } - - void ResetCellState() { - const int zero_buffer_size = n_cell_ * n_batch_; - std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]); - memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float)); - PopulateTensor(cell_state_, 0, zero_buffer.get(), - zero_buffer.get() + zero_buffer_size); - } - void SetInput(int offset, float* begin, float* end) { PopulateTensor(input_, offset, begin, end); } @@ -238,8 +220,6 @@ class LSTMOpModel : public SingleOpModel { int input_cell_state_; int output_; - int output_state_; - int cell_state_; int n_batch_; int n_input_; @@ -324,10 +304,6 @@ TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) { lstm.SetCellToOutputWeights( {-0.17135078, 0.82760304, 0.85573703, -0.77109635}); - // Resetting cell_state and output_state - lstm.ResetCellState(); - lstm.ResetOutputState(); - // Verify the model by unpacking it. 
lstm.Verify(); } diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc index 10d1fcc5bc..341fd14127 100644 --- a/tensorflow/contrib/lite/kernels/register.cc +++ b/tensorflow/contrib/lite/kernels/register.cc @@ -113,6 +113,7 @@ TfLiteRegistration* Register_ONE_HOT(); TfLiteRegistration* Register_LOGICAL_OR(); TfLiteRegistration* Register_LOGICAL_AND(); TfLiteRegistration* Register_LOGICAL_NOT(); +TfLiteRegistration* Register_UNPACK(); TfLiteStatus UnsupportedTensorFlowOp(TfLiteContext* context, TfLiteNode* node) { context->ReportError( @@ -235,6 +236,7 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_LOGICAL_OR, Register_LOGICAL_OR()); AddBuiltin(BuiltinOperator_LOGICAL_AND, Register_LOGICAL_AND()); AddBuiltin(BuiltinOperator_LOGICAL_NOT, Register_LOGICAL_NOT()); + AddBuiltin(BuiltinOperator_UNPACK, Register_UNPACK()); // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that // custom ops aren't always included by default. diff --git a/tensorflow/contrib/lite/kernels/unpack.cc b/tensorflow/contrib/lite/kernels/unpack.cc new file mode 100644 index 0000000000..4998f88b41 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/unpack.cc @@ -0,0 +1,130 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace unpack { +namespace { + +constexpr int kInputTensor = 0; + +// Op data for unpack op. +struct OpData { + int num; + int axis; +}; + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + auto* data = new OpData; + data->axis = 0; + return data; +} + +void Free(TfLiteContext* context, void* buffer) { + delete reinterpret_cast<OpData*>(buffer); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + const OpData* data = reinterpret_cast<OpData*>(node->builtin_data); + + TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), data->num); + + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + TF_LITE_ENSURE(context, NumDimensions(input) <= 4); + TF_LITE_ENSURE(context, NumDimensions(input) > 1); + TF_LITE_ENSURE(context, NumDimensions(input) > data->axis); + // TODO(renjieliu): Support negative axis. + TF_LITE_ENSURE(context, data->axis >= 0); + if (input->type != kTfLiteInt32 && input->type != kTfLiteFloat32) { + context->ReportError(context, + "Currently pack only supports int32 and float32."); + return kTfLiteError; + } + + const TfLiteIntArray* input_shape = input->dims; + // Num should be equal to the shape[axis]. + // Resize outputs. rank will be R - 1. 
+ TfLiteIntArray* output_shape = TfLiteIntArrayCreate(NumDimensions(input) - 1); + int o = 0; + for (int index = 0; index < NumDimensions(input); ++index) { + if (index != data->axis) { + output_shape->data[o++] = input_shape->data[index]; + } + } + + TF_LITE_ENSURE_EQ(context, data->num, input_shape->data[data->axis]); + for (int i = 0; i < data->num; ++i) { + TfLiteIntArray* copied_output_shape = TfLiteIntArrayCopy(output_shape); + TfLiteTensor* output = GetOutput(context, node, i); + TF_LITE_ENSURE_EQ(context, output->type, input->type); + TF_LITE_ENSURE_OK( + context, context->ResizeTensor(context, output, copied_output_shape)); + } + + TfLiteIntArrayFree(output_shape); + return kTfLiteOk; +} + +template <typename T> +void UnpackImpl(TfLiteContext* context, TfLiteNode* node, + const TfLiteTensor* input, int output_count, int axis) { + VectorOfTensors<T> all_outputs(*context, *node->outputs); + reference_ops::Unpack<T>(axis, GetTensorData<T>(input), GetTensorDims(input), + NumDimensions(input), output_count, + all_outputs.data(), **all_outputs.dims()); +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const OpData* data = reinterpret_cast<OpData*>(node->builtin_data); + + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + switch (input->type) { + case kTfLiteFloat32: { + UnpackImpl<float>(context, node, input, data->num, data->axis); + break; + } + case kTfLiteInt32: { + UnpackImpl<int32_t>(context, node, input, data->num, data->axis); + break; + } + default: { + context->ReportError(context, + "Currently pack only supports int32 and float32."); + return kTfLiteError; + } + } + + return kTfLiteOk; +} +} // namespace +} // namespace unpack + +TfLiteRegistration* Register_UNPACK() { + static TfLiteRegistration r = {unpack::Init, unpack::Free, unpack::Prepare, + unpack::Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/unpack_test.cc 
b/tensorflow/contrib/lite/kernels/unpack_test.cc new file mode 100644 index 0000000000..4efc92a0fd --- /dev/null +++ b/tensorflow/contrib/lite/kernels/unpack_test.cc @@ -0,0 +1,225 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include <vector> +#include <gtest/gtest.h> +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAre; + +template <typename T> +class UnpackOpModel : public SingleOpModel { + public: + UnpackOpModel(const TensorData& input, int axis) { + CHECK_LE(axis, input.shape.size()); + const int num_outputs = input.shape[axis]; + input_ = AddInput(input); + for (int i = 0; i < num_outputs; ++i) { + outputs_.push_back(AddOutput(input.type)); + } + SetBuiltinOp(BuiltinOperator_UNPACK, BuiltinOptions_UnpackOptions, + CreatePackOptions(builder_, num_outputs, axis).Union()); + BuildInterpreter({GetShape(input_)}); + } + + void SetInput(std::initializer_list<T> data) { + PopulateTensor<T>(input_, data); + } + + std::vector<std::vector<T>> GetOutputDatas() { + std::vector<std::vector<T>> output_datas; + for (const int output : outputs_) { + std::cerr << "the output is " << output << std::endl; + 
output_datas.push_back(ExtractVector<T>(output)); + } + return output_datas; + } + + std::vector<std::vector<int>> GetOutputShapes() { + std::vector<std::vector<int>> output_shapes; + for (const int output : outputs_) { + output_shapes.push_back(GetTensorShape(output)); + } + return output_shapes; + } + + private: + int input_; + std::vector<int> outputs_; +}; + +// float32 tests. +TEST(UnpackOpTest, FloatThreeOutputs) { + UnpackOpModel<float> model({TensorType_FLOAT32, {3, 2}}, 0); + model.SetInput({1, 2, 3, 4, 5, 6}); + model.Invoke(); + + // Check outputs shapes. + const std::vector<std::vector<int>>& output_shapes = model.GetOutputShapes(); + EXPECT_EQ(output_shapes.size(), 3); + EXPECT_THAT(output_shapes[0], ElementsAre(2)); + EXPECT_THAT(output_shapes[1], ElementsAre(2)); + EXPECT_THAT(output_shapes[2], ElementsAre(2)); + + // Check outputs values. + const std::vector<std::vector<float>>& output_datas = model.GetOutputDatas(); + EXPECT_EQ(output_datas.size(), 3); + EXPECT_THAT(output_datas[0], ElementsAre(1, 2)); + EXPECT_THAT(output_datas[1], ElementsAre(3, 4)); + EXPECT_THAT(output_datas[2], ElementsAre(5, 6)); +} + +TEST(UnpackOpTest, FloatThreeOutputsAxisOne) { + UnpackOpModel<float> model({TensorType_FLOAT32, {3, 2}}, 1); + model.SetInput({1, 2, 3, 4, 5, 6}); + model.Invoke(); + + // Check outputs shapes. + const std::vector<std::vector<int>>& output_shapes = model.GetOutputShapes(); + EXPECT_EQ(output_shapes.size(), 2); + EXPECT_THAT(output_shapes[0], ElementsAre(3)); + EXPECT_THAT(output_shapes[1], ElementsAre(3)); + + // Check outputs values. + const std::vector<std::vector<float>>& output_datas = model.GetOutputDatas(); + EXPECT_EQ(output_datas.size(), 2); + EXPECT_THAT(output_datas[0], ElementsAre(1, 3, 5)); + EXPECT_THAT(output_datas[1], ElementsAre(2, 4, 6)); +} + +TEST(UnpackOpTest, FloatOneOutput) { + UnpackOpModel<float> model({TensorType_FLOAT32, {1, 6}}, 0); + model.SetInput({1, 2, 3, 4, 5, 6}); + model.Invoke(); + + // Check outputs shapes. 
+ const std::vector<std::vector<int>>& output_shapes = model.GetOutputShapes(); + EXPECT_EQ(output_shapes.size(), 1); + EXPECT_THAT(output_shapes[0], ElementsAre(6)); + + // Check outputs values. + const std::vector<std::vector<float>>& output_datas = model.GetOutputDatas(); + EXPECT_EQ(output_datas.size(), 1); + EXPECT_THAT(output_datas[0], ElementsAre(1, 2, 3, 4, 5, 6)); +} + +TEST(UnpackOpTest, FloatThreeDimensionsOutputs) { + UnpackOpModel<float> model({TensorType_FLOAT32, {2, 2, 2}}, 2); + model.SetInput({1, 2, 3, 4, 5, 6, 7, 8}); + model.Invoke(); + + // Check outputs shapes. + const std::vector<std::vector<int>>& output_shapes = model.GetOutputShapes(); + EXPECT_EQ(output_shapes.size(), 2); + EXPECT_THAT(output_shapes[0], ElementsAre(2, 2)); + EXPECT_THAT(output_shapes[1], ElementsAre(2, 2)); + + // Check outputs values. + const std::vector<std::vector<float>>& output_datas = model.GetOutputDatas(); + EXPECT_EQ(output_datas.size(), 2); + EXPECT_THAT(output_datas[0], ElementsAre(1, 3, 5, 7)); + EXPECT_THAT(output_datas[1], ElementsAre(2, 4, 6, 8)); +} + +// int32 tests. +TEST(UnpackOpTest, IntThreeOutputs) { + UnpackOpModel<int32_t> model({TensorType_INT32, {3, 2}}, 0); + model.SetInput({1, 2, 3, 4, 5, 6}); + model.Invoke(); + + // Check outputs shapes. + const std::vector<std::vector<int>>& output_shapes = model.GetOutputShapes(); + EXPECT_EQ(output_shapes.size(), 3); + EXPECT_THAT(output_shapes[0], ElementsAre(2)); + EXPECT_THAT(output_shapes[1], ElementsAre(2)); + EXPECT_THAT(output_shapes[2], ElementsAre(2)); + + // Check outputs values. 
+ const std::vector<std::vector<int32_t>>& output_datas = + model.GetOutputDatas(); + EXPECT_EQ(output_datas.size(), 3); + EXPECT_THAT(output_datas[0], ElementsAre(1, 2)); + EXPECT_THAT(output_datas[1], ElementsAre(3, 4)); + EXPECT_THAT(output_datas[2], ElementsAre(5, 6)); +} + +TEST(UnpackOpTest, IntThreeOutputsAxisOne) { + UnpackOpModel<int32_t> model({TensorType_INT32, {3, 2}}, 1); + model.SetInput({1, 2, 3, 4, 5, 6}); + model.Invoke(); + + // Check outputs shapes. + const std::vector<std::vector<int>>& output_shapes = model.GetOutputShapes(); + EXPECT_EQ(output_shapes.size(), 2); + EXPECT_THAT(output_shapes[0], ElementsAre(3)); + EXPECT_THAT(output_shapes[1], ElementsAre(3)); + + // Check outputs values. + const std::vector<std::vector<int32_t>>& output_datas = + model.GetOutputDatas(); + EXPECT_EQ(output_datas.size(), 2); + EXPECT_THAT(output_datas[0], ElementsAre(1, 3, 5)); + EXPECT_THAT(output_datas[1], ElementsAre(2, 4, 6)); +} + +TEST(UnpackOpTest, IntOneOutput) { + UnpackOpModel<int32_t> model({TensorType_INT32, {1, 6}}, 0); + model.SetInput({1, 2, 3, 4, 5, 6}); + model.Invoke(); + + // Check outputs shapes. + const std::vector<std::vector<int>>& output_shapes = model.GetOutputShapes(); + EXPECT_EQ(output_shapes.size(), 1); + EXPECT_THAT(output_shapes[0], ElementsAre(6)); + + // Check outputs values. + const std::vector<std::vector<int32_t>>& output_datas = + model.GetOutputDatas(); + EXPECT_EQ(output_datas.size(), 1); + EXPECT_THAT(output_datas[0], ElementsAre(1, 2, 3, 4, 5, 6)); +} + +TEST(UnpackOpTest, IntThreeDimensionsOutputs) { + UnpackOpModel<int32_t> model({TensorType_INT32, {2, 2, 2}}, 2); + model.SetInput({1, 2, 3, 4, 5, 6, 7, 8}); + model.Invoke(); + + // Check outputs shapes. + const std::vector<std::vector<int>>& output_shapes = model.GetOutputShapes(); + EXPECT_EQ(output_shapes.size(), 2); + EXPECT_THAT(output_shapes[0], ElementsAre(2, 2)); + EXPECT_THAT(output_shapes[1], ElementsAre(2, 2)); + + // Check outputs values. 
+ const std::vector<std::vector<int32_t>>& output_datas = + model.GetOutputDatas(); + EXPECT_EQ(output_datas.size(), 2); + EXPECT_THAT(output_datas[0], ElementsAre(1, 3, 5, 7)); + EXPECT_THAT(output_datas[1], ElementsAre(2, 4, 6, 8)); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/lib_package/create_ios_frameworks.sh b/tensorflow/contrib/lite/lib_package/create_ios_frameworks.sh index b58ae26601..6195426d6d 100755 --- a/tensorflow/contrib/lite/lib_package/create_ios_frameworks.sh +++ b/tensorflow/contrib/lite/lib_package/create_ios_frameworks.sh @@ -14,6 +14,7 @@ # limitations under the License. # ============================================================================== +# TODO(ycling): Refactoring - Move this script into `tools/make`. set -e echo "Starting" @@ -32,7 +33,7 @@ echo "Headers, populating: TensorFlow Lite" cd $TFLITE_DIR/../../.. find tensorflow/contrib/lite -name '*.h' \ - -not -path 'tensorflow/contrib/lite/downloads/*' \ + -not -path 'tensorflow/contrib/lite/tools/*' \ -not -path 'tensorflow/contrib/lite/examples/*' \ -not -path 'tensorflow/contrib/lite/gen/*' \ -not -path 'tensorflow/contrib/lite/toco/*' \ @@ -44,7 +45,7 @@ tar xf tmp.tar rm -f tmp.tar echo "Headers, populating: Flatbuffer" -cd $TFLITE_DIR/downloads/flatbuffers/include/ +cd $TFLITE_DIR/tools/make/downloads/flatbuffers/include/ find . 
-name '*.h' | tar -cf $FW_DIR_TFLITE_HDRS/tmp.tar -T - cd $FW_DIR_TFLITE_HDRS tar xf tmp.tar @@ -57,7 +58,7 @@ cp $TFLITE_DIR/../../../bazel-genfiles/tensorflow/tools/lib_package/include/tens $FW_DIR_TFLITE echo "Copying static libraries" -cp $TFLITE_DIR/gen/lib/libtensorflow-lite.a \ +cp $TFLITE_DIR/tools/make/gen/lib/libtensorflow-lite.a \ $FW_DIR_TFLITE/tensorflow_lite # This is required, otherwise they interfere with the documentation of the diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc index 7ca12cb841..da3ed42e20 100644 --- a/tensorflow/contrib/lite/model.cc +++ b/tensorflow/contrib/lite/model.cc @@ -745,6 +745,15 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = static_cast<void*>(params); break; } + case BuiltinOperator_UNPACK: { + TfLiteUnpackParams* params = MallocPOD<TfLiteUnpackParams>(); + if (auto* unpack_params = op->builtin_options_as_UnpackOptions()) { + params->num = unpack_params->num(); + params->axis = unpack_params->axis(); + } + *builtin_data = reinterpret_cast<void*>(params); + break; + } // Below are the ops with no builtin_data strcture. 
case BuiltinOperator_BATCH_TO_SPACE_ND: @@ -790,7 +799,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, case BuiltinOperator_LOGICAL_OR: case BuiltinOperator_LOGICAL_AND: case BuiltinOperator_LOGICAL_NOT: - case BuiltinOperator_UNPACK: case BuiltinOperator_FLOOR_DIV: case BuiltinOperator_REDUCE_ANY: break; diff --git a/tensorflow/contrib/lite/models/speech_test.cc b/tensorflow/contrib/lite/models/speech_test.cc index fad39bee9e..8ecf0b6154 100644 --- a/tensorflow/contrib/lite/models/speech_test.cc +++ b/tensorflow/contrib/lite/models/speech_test.cc @@ -126,7 +126,7 @@ TEST_P(SpeechTest, DISABLED_HotwordOkGoogleRank2Test) { << test_driver.GetErrorMessage(); } -TEST_P(SpeechTest, SpeakerIdOkGoogleTest) { +TEST_P(SpeechTest, DISABLED_SpeakerIdOkGoogleTest) { std::stringstream os; ASSERT_TRUE(ConvertCsvData( "speech_speakerid_model.tflite", "speech_speakerid_model_in.csv", @@ -139,7 +139,7 @@ TEST_P(SpeechTest, SpeakerIdOkGoogleTest) { << test_driver.GetErrorMessage(); } -TEST_P(SpeechTest, AsrAmTest) { +TEST_P(SpeechTest, DISABLED_AsrAmTest) { std::stringstream os; ASSERT_TRUE( ConvertCsvData("speech_asr_am_model.tflite", "speech_asr_am_model_in.csv", @@ -156,7 +156,7 @@ TEST_P(SpeechTest, AsrAmTest) { // through the interpreter and stored the sum of all the output, which was them // compared for correctness. In this test we are comparing all the intermediate // results. 
-TEST_P(SpeechTest, AsrLmTest) { +TEST_P(SpeechTest, DISABLED_AsrLmTest) { std::ifstream in_file; testing::TfLiteDriver test_driver(/*use_nnapi=*/false); ASSERT_TRUE(Init("speech_asr_lm_model.test_spec", &test_driver, &in_file)); @@ -165,7 +165,7 @@ TEST_P(SpeechTest, AsrLmTest) { << test_driver.GetErrorMessage(); } -TEST_P(SpeechTest, EndpointerTest) { +TEST_P(SpeechTest, DISABLED_EndpointerTest) { std::stringstream os; ASSERT_TRUE(ConvertCsvData( "speech_endpointer_model.tflite", "speech_endpointer_model_in.csv", @@ -178,7 +178,7 @@ TEST_P(SpeechTest, EndpointerTest) { << test_driver.GetErrorMessage(); } -TEST_P(SpeechTest, TtsTest) { +TEST_P(SpeechTest, DISABLED_TtsTest) { std::stringstream os; ASSERT_TRUE(ConvertCsvData("speech_tts_model.tflite", "speech_tts_model_in.csv", diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py index 599c82940e..a329bb3a25 100644 --- a/tensorflow/contrib/lite/testing/generate_examples.py +++ b/tensorflow/contrib/lite/testing/generate_examples.py @@ -2378,7 +2378,7 @@ def make_lstm_tests(zip_path): "time_step_size": [1], "input_vec_size": [3], "num_cells": [4], - "split_tflite_lstm_inputs": [True, False], + "split_tflite_lstm_inputs": [False], }, ] @@ -3149,6 +3149,36 @@ def make_pack_tests(zip_path): make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) +def make_unpack_tests(zip_path): + """Make a set of tests to do unstack.""" + + test_parameters = [{ + "base_shape": [[3, 4, 3], [3, 4], [5, 6, 7, 8]], + "axis": [0, 1, 2, 3], + }] + + def get_valid_axis(parameters): + """Return a tweaked version of 'axis'.""" + axis = parameters["axis"] + shape = parameters["base_shape"][:] + while axis > len(shape) - 1: + axis -= 1 + return axis + + def build_graph(parameters): + input_tensor = tf.placeholder( + dtype=tf.float32, name=("input"), shape=parameters["base_shape"]) + outs = tf.unstack(input_tensor, axis=get_valid_axis(parameters)) + return 
[input_tensor], outs + + def build_inputs(parameters, sess, inputs, outputs): + input_value = create_tensor_data(np.float32, shape=parameters["base_shape"]) + return [input_value], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_value]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + def _make_logical_tests(op): """Make a set of tests to do logical operations.""" diff --git a/tensorflow/contrib/lite/testing/tflite_driver.cc b/tensorflow/contrib/lite/testing/tflite_driver.cc index 4dacf9c84b..1836eb53b9 100644 --- a/tensorflow/contrib/lite/testing/tflite_driver.cc +++ b/tensorflow/contrib/lite/testing/tflite_driver.cc @@ -302,28 +302,6 @@ bool TfLiteDriver::CheckResults() { void TfLiteDriver::ResetLSTMStateTensors() { interpreter_->ResetVariableTensorsToZero(); - - // Below is a workaround for initializing state tensors for LSTM. - // TODO(ycling): Remove the code below after nobody is using the 18-inputs - // definition. - for (auto node_index : interpreter_->execution_plan()) { - const auto& node_and_reg = interpreter_->node_and_registration(node_index); - const auto& node = node_and_reg->first; - const auto& registration = node_and_reg->second; - - if (registration.builtin_code == tflite::BuiltinOperator_LSTM) { - const auto* params = - reinterpret_cast<const TfLiteLSTMParams*>(node.builtin_data); - if (params->kernel_type == kTfLiteLSTMFullKernel && - node.inputs->size == 18 && node.outputs->size >= 2) { - // The first 2 outputs of LSTM are state tensors. 
- for (int i = 0; i < 2; ++i) { - int node_index = node.outputs->data[i]; - ResetTensor(node_index); - } - } - } - } } } // namespace testing diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc index f489c5ac65..94602445c2 100644 --- a/tensorflow/contrib/lite/toco/export_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc @@ -1967,6 +1967,20 @@ void ConvertCTCBeamSearchDecoderOperator( (*op->mutable_attr())["merge_repeated"].set_b(src_op.merge_repeated); } +void ConvertUnpackOperator(const Model& model, const UnpackOperator& src_op, + const char* op_name, GraphDef* tensorflow_graph) { + tensorflow::NodeDef* unpack_op = tensorflow_graph->add_node(); + unpack_op->set_op(op_name); + unpack_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 2); + *unpack_op->add_input() = src_op.inputs[0]; + const tensorflow::DataType data_type = + GetTensorFlowDataType(model, src_op.inputs[0]); + (*unpack_op->mutable_attr())["T"].set_type(data_type); + (*unpack_op->mutable_attr())["num"].set_i(src_op.num); + (*unpack_op->mutable_attr())["axis"].set_i(src_op.axis); +} + void ConvertOperator(const Model& model, const Operator& src_op, GraphDef* tensorflow_graph) { if (src_op.fused_activation_function != FusedActivationFunctionType::kNone) { @@ -2228,6 +2242,9 @@ void ConvertOperator(const Model& model, const Operator& src_op, ConvertCTCBeamSearchDecoderOperator( model, static_cast<const CTCBeamSearchDecoderOperator&>(src_op), "CTCBeamSearchDecoder", tensorflow_graph); + } else if (src_op.type == OperatorType::kUnpack) { + ConvertUnpackOperator(model, static_cast<const UnpackOperator&>(src_op), + "Unpack", tensorflow_graph); } else { LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(src_op.type); } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc index 
c8310161cb..323eefcd3a 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc @@ -227,6 +227,15 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { ArrayDataType::kFloat; break; } + case OperatorType::kUnpack: { + CHECK_EQ(op->inputs.size(), 1); + const int output_size = op->outputs.size(); + for (int i = 0; i < output_size; ++i) { + model->GetArray(op->outputs[i]).data_type = + model->GetArray(op->inputs[0]).data_type; + } + break; + } default: { // These operators produce outputs with the same type as their 1st input CHECK_GT(op->inputs.size(), 0); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc index 91e290439a..fa2be961f5 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc @@ -1629,6 +1629,32 @@ void ProcessOneHotOperator(Model* model, OneHotOperator* op) { } } +void ProcessUnpackOperator(Model* model, UnpackOperator* op) { + CHECK_EQ(op->inputs.size(), 1); + const auto& input_array = model->GetArray(op->inputs[0]); + // Yield until input dims have been resolved. 
+ if (!input_array.has_shape()) { + return; + } + + const std::vector<int>& input_dims = input_array.shape().dims(); + std::vector<int> output_dims; + + output_dims.reserve(input_dims.size() - 1); + for (int i = 0; i < input_dims.size(); ++i) { + if (i != op->axis) { + output_dims.push_back(input_dims[i]); + } + } + for (const string& output_name : op->outputs) { + auto& output_array = model->GetArray(output_name); + if (output_array.has_shape()) { + return; + } + *output_array.mutable_shape()->mutable_dims() = output_dims; + } +} + } // namespace bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { @@ -1880,6 +1906,9 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { case OperatorType::kOneHot: ProcessOneHotOperator(model, static_cast<OneHotOperator*>(op)); break; + case OperatorType::kUnpack: + ProcessUnpackOperator(model, static_cast<UnpackOperator*>(op)); + break; default: // Unimplemented, another graph transformation should drop it. LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(op->type); diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc index b7fffbce22..0e04ee4ccb 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc @@ -1576,6 +1576,26 @@ tensorflow::Status ConvertPackOperator( return tensorflow::Status::OK(); } +tensorflow::Status ConvertUnpackOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { + CHECK_EQ(node.op(), "Unpack"); + auto op = absl::make_unique<UnpackOperator>(); + const int num_inputs = GetInputsCount(node, tf_import_flags); + QCHECK_EQ(num_inputs, 1); + op->inputs.push_back(node.input(0)); + op->num = GetIntAttr(node, "num"); + op->axis = HasAttr(node, "axis") ? GetIntAttr(node, "axis") : 0; + op->dtype = ConvertDataType(toco::GetDataTypeAttr(node, "T")); + + op->outputs.push_back(node.name()); // Implicit :0. 
+ for (int i = 1; i < op->num; ++i) { + op->outputs.push_back(node.name() + ":" + std::to_string(i)); + } + model->operators.emplace_back(std::move(op)); + return tensorflow::Status::OK(); +} + // Some TensorFlow ops only occur in graph cycles, representing // control flow. We do not currently support control flow, so we wouldn't // be able to fully support such graphs, including performing inference, @@ -2020,6 +2040,7 @@ ConverterMapType GetTensorFlowNodeConverterMap() { {"TopK", ConvertTopKV2Operator}, {"TopKV2", ConvertTopKV2Operator}, {"Transpose", ConvertSimpleOperator<TransposeOperator, 2>}, + {"Unpack", ConvertUnpackOperator}, }); } diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h index 412e14c4ad..3a909c3d8e 100644 --- a/tensorflow/contrib/lite/toco/model.h +++ b/tensorflow/contrib/lite/toco/model.h @@ -149,6 +149,7 @@ enum class OperatorType : uint8 { kLogicalNot, kLogicalOr, kCTCBeamSearchDecoder, + kUnpack, }; // Helper to deal with TensorFlow arrays using a different ordering of @@ -1828,6 +1829,20 @@ struct LogicalOrOperator : Operator { LogicalOrOperator() : Operator(OperatorType::kLogicalOr) {} }; +// Unpack operator: +// +// Inputs: +// Inputs[0]: required: A boolean input tensor. +// Inputs[1]: required: reduction_indices. +// +// TensorFlow equivalent: tf.unstack. +struct UnpackOperator : Operator { + UnpackOperator() : Operator(OperatorType::kUnpack) {} + int num; + int axis; + ArrayDataType dtype = ArrayDataType::kNone; +}; + // Alloc's are used for transient arrays only. An Alloc specifies which interval // of the "transient_data" workspace buffer passed to inference functions, is to // be used for the transient array at hand. 
The 'start' and 'end' values are diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc index dcb5fff39f..e9383098cc 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator.cc @@ -1110,6 +1110,24 @@ class CTCBeamSearchDecoder int GetVersion(const Operator& op) const override { return 1; } }; +class Unpack : public BuiltinOperator<UnpackOperator, ::tflite::UnpackOptions, + ::tflite::BuiltinOptions_UnpackOptions> { + public: + using BuiltinOperator::BuiltinOperator; + flatbuffers::Offset<TfLiteOptions> WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + return ::tflite::CreateUnpackOptions(*builder, op.num, op.axis); + } + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->num = options.num(); + op->axis = options.axis(); + } + + int GetVersion(const Operator& op) const override { return 1; } +}; + class TensorFlowUnsupported : public BaseOperator { public: using BaseOperator::BaseOperator; @@ -1353,6 +1371,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() { MakeUnique<Pack>(::tflite::BuiltinOperator_PACK, OperatorType::kPack)); ops.push_back(MakeUnique<OneHot>(::tflite::BuiltinOperator_ONE_HOT, OperatorType::kOneHot)); + ops.push_back(MakeUnique<Unpack>(::tflite::BuiltinOperator_UNPACK, + OperatorType::kUnpack)); // Custom Operators. 
ops.push_back( diff --git a/tensorflow/contrib/lite/toco/tflite/operator_test.cc b/tensorflow/contrib/lite/toco/tflite/operator_test.cc index fc854461b4..bb0b457483 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator_test.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator_test.cc @@ -476,6 +476,16 @@ TEST_F(OperatorTest, BuiltinOneHot) { EXPECT_EQ(op.axis, output_toco_op->axis); } +TEST_F(OperatorTest, BuiltinUnpack) { + UnpackOperator op; + op.num = 5; + op.axis = 2; + auto output_toco_op = + SerializeAndDeserialize(GetOperator("UNPACK", OperatorType::kUnpack), op); + EXPECT_EQ(op.num, output_toco_op->num); + EXPECT_EQ(op.axis, output_toco_op->axis); +} + TEST_F(OperatorTest, CustomCTCBeamSearchDecoder) { CTCBeamSearchDecoderOperator op; op.beam_width = 3; diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc index 3a4542f522..6ab93d9316 100644 --- a/tensorflow/contrib/lite/toco/tooling_util.cc +++ b/tensorflow/contrib/lite/toco/tooling_util.cc @@ -405,6 +405,7 @@ const char* OperatorTypeName(OperatorType type) { HANDLE_OPERATORTYPENAME_CASE(LogicalNot) HANDLE_OPERATORTYPENAME_CASE(LogicalOr) HANDLE_OPERATORTYPENAME_CASE(CTCBeamSearchDecoder) + HANDLE_OPERATORTYPENAME_CASE(Unpack) default: LOG(FATAL) << "Unhandled op type"; #undef HANDLE_OPERATORTYPENAME_CASE diff --git a/tensorflow/contrib/lite/tools/optimize/BUILD b/tensorflow/contrib/lite/tools/optimize/BUILD new file mode 100644 index 0000000000..01fbce0ac7 --- /dev/null +++ b/tensorflow/contrib/lite/tools/optimize/BUILD @@ -0,0 +1,11 @@ +# TODO(suharshs): Write quantize_weights tests that use small exportable files. +# Then we can remove this file. 
+package( + default_visibility = ["//visibility:public"], +) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts") diff --git a/tensorflow/contrib/lite/tools/optimize/quantize_weights.cc b/tensorflow/contrib/lite/tools/optimize/quantize_weights.cc new file mode 100644 index 0000000000..0758514e39 --- /dev/null +++ b/tensorflow/contrib/lite/tools/optimize/quantize_weights.cc @@ -0,0 +1,280 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/tools/optimize/quantize_weights.h" + +#include <algorithm> +#include <memory> +#include <string> +#include <vector> + +#include "flatbuffers/flexbuffers.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/schema/schema_generated.h" +#include "tensorflow/core/platform/logging.h" + +namespace tflite { +namespace optimize { + +namespace { + +// The minimum number of elements a weights array must have to be quantized +// by this transformation. +// TODO(suharshs): Make this configurable. +const int kWeightsMinSize = 1024; + +// Nudge min and max so that floating point 0 falls exactly on a quantized +// value, returning the nudges scale and zero_point. 
+// +// Although this code originates from FakeQuantization in quantized training, +// we may deviate from that implementation as we please since we do not fine +// tune the weights with quantized training. +void GetQuantizationParams(const float min, const float max, + const int quant_min, const int quant_max, + QuantizationParametersT* quantization_params) { + // Adjust the boundaries to guarantee 0 is included. + const float quant_min_float = std::min(static_cast<float>(quant_min), 0.0f); + const float quant_max_float = std::max(static_cast<float>(quant_max), 0.0f); + const float scale = (max - min) / (quant_max_float - quant_min_float); + const float zero_point_from_min = quant_min_float - min / scale; + int64_t zero_point; + if (zero_point_from_min < quant_min_float) { + zero_point = static_cast<int64_t>(quant_min); + } else if (zero_point_from_min > quant_max_float) { + zero_point = static_cast<int64_t>(quant_max); + } else { + zero_point = static_cast<int64_t>(std::round(zero_point_from_min)); + } + quantization_params->scale = {scale}; + quantization_params->zero_point = {zero_point}; +} + +// Returns the number of elements in tensor. +uint64 NumElements(const TensorT* tensor) { + if (tensor->shape.empty()) { + LOG(FATAL) << "Tensor has no shape information."; + } + uint64 num_elements = 1; + for (const uint64 dim : tensor->shape) { + num_elements *= dim; + } + return num_elements; +} + +uint64 CountTensorConsumers(const ModelT* model, const SubGraphT* subgraph, + int32_t tensor_idx) { + uint64 count = 0; + for (int op_idx = 0; op_idx < subgraph->operators.size(); ++op_idx) { + const OperatorT* op = subgraph->operators[op_idx].get(); + if (op == nullptr) { + continue; + } + for (int i = 0; i < op->inputs.size(); ++i) { + if (op->inputs[i] == tensor_idx) { + count++; + } + } + } + return count; +} + +// Returns true if the Operator's weight tensor should be quantized. 
+bool GetQuantizableTensorFromOperator(const ModelT* model, const OperatorT* op, + TensorT** tensor, int32_t* tensor_idx, + int32_t* op_input_index) { + SubGraphT* subgraph = model->subgraphs.at(0).get(); + const BuiltinOperator op_code = + model->operator_codes[op->opcode_index]->builtin_code; + + if (op_code == BuiltinOperator_CONV_2D || + op_code == BuiltinOperator_DEPTHWISE_CONV_2D || + op_code == BuiltinOperator_FULLY_CONNECTED || + op_code == BuiltinOperator_SVDF) { + *op_input_index = 1; + } else if (op_code == BuiltinOperator_LSTM) { + // TODO(suharshs): Add RNN, and sequential/bidi versions. + *op_input_index = 2; + } else { + return false; + } + *tensor_idx = op->inputs[*op_input_index]; + + // TODO(suharshs): Support shared weights, i.e. If two tensors share the + // same weight array, things may break. (i.e. SSD object detection) + if (CountTensorConsumers(model, subgraph, *tensor_idx) != 1) { + LOG(INFO) << "Skipping quantization of tensor that is shared between " + "multiple multiple operations."; + return false; + } + + *tensor = subgraph->tensors[*tensor_idx].get(); + + if ((*tensor)->type != TensorType_FLOAT32) { + LOG(INFO) << "Skipping quantization of tensor that is not type float."; + return false; + } + const uint64 num_elements = NumElements(*tensor); + if (num_elements < kWeightsMinSize) { + LOG(INFO) << "Skipping quantization of tensor because it has fewer than " + << kWeightsMinSize << " elements (" << num_elements << ")."; + return false; + } + + return true; +} + +// Quantizes tensor using asymmetric quantization with the min and max elements +// of the tensor. This is needed to pass to Dequantize operations. 
+TfLiteStatus AsymmetricQuantizeTensor(ModelT* model, TensorT* tensor) { + BufferT* buffer = model->buffers[tensor->buffer].get(); + float* float_data = reinterpret_cast<float*>(buffer->data.data()); + const uint64 num_elements = NumElements(tensor); + LOG(INFO) << "Quantizing tensor with " << num_elements << " elements."; + + // Compute the quantization params. + float min_value = *std::min_element(float_data, float_data + num_elements); + float max_value = *std::max_element(float_data, float_data + num_elements); + GetQuantizationParams(min_value, max_value, 0, 255, + tensor->quantization.get()); + + // Quantize the buffer. + std::vector<uint8_t> quantized_buffer; + quantized_buffer.resize(num_elements); + const double inverse_scale = 1. / tensor->quantization->scale[0]; + for (std::size_t i = 0; i < num_elements; i++) { + const float src_val = float_data[i]; + double scaled_val; + if (tensor->quantization->scale[0] == 0) { + scaled_val = tensor->quantization->zero_point[0]; + } else { + scaled_val = + tensor->quantization->zero_point[0] + inverse_scale * src_val; + } + uint8_t integer_val = static_cast<uint8_t>(std::round(scaled_val)); + quantized_buffer[i] = integer_val; + } + model->buffers[tensor->buffer]->data = quantized_buffer; + + // Update the tensor type. + tensor->type = TensorType_UINT8; + + return kTfLiteOk; +} + +// Returns the index of the Dequantize op_code. +// If a Dequantize op_code doesn't exist, adds it and returns its index. +int32_t GetOrInsertDequantizeOpCodeIndex(ModelT* model) { + for (int i = 0; i < model->operator_codes.size(); ++i) { + if (model->operator_codes[i]->builtin_code == BuiltinOperator_DEQUANTIZE) { + return i; + } + } + model->operator_codes.push_back(std::make_unique<OperatorCodeT>()); + int op_code_idx = model->operator_codes.size() - 1; + model->operator_codes[op_code_idx]->builtin_code = BuiltinOperator_DEQUANTIZE; + // TODO(suharshs): How should the version be set in this op_code? 
+ + // Return the index of the newly placed OperatorCodeT. + return op_code_idx; +} + +// Creates a Dequantize OperatorT object. +void MakeDequantizeOperator(ModelT* model, std::unique_ptr<OperatorT>* op, + int32_t input, int32_t output) { + OperatorT* op_raw = new OperatorT; + op_raw->opcode_index = GetOrInsertDequantizeOpCodeIndex(model); + op_raw->inputs = {input}; + op_raw->outputs = {output}; + + op->reset(op_raw); +} + +// Create a new TensorT object. +void MakeTensor(const string& name, const std::vector<int32_t>& shape, + std::unique_ptr<TensorT>* tensor) { + TensorT* tensor_raw = new TensorT; + tensor_raw->name = name; + tensor_raw->shape = shape; + + tensor->reset(tensor_raw); +} + +} // namespace + +TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, + const Model* input_model) { + std::unique_ptr<ModelT> model; + model.reset(input_model->UnPack()); + + // TODO(suharshs): When models support multiple subgraphs, add support. + if (model->subgraphs.size() != 1) { + LOG(ERROR) << "Quantize weights tool only supports tflite models with one " + "subgraph."; + return kTfLiteError; + } + + SubGraphT* subgraph = model->subgraphs.at(0).get(); + + std::vector<std::unique_ptr<OperatorT>> new_operators; + for (int i = 0; i < subgraph->operators.size(); ++i) { + OperatorT* op = subgraph->operators[i].get(); + + TensorT* tensor; + // The index of the weight tensor in subgraph->tensors. + int32_t tensor_idx; + int32_t op_input_idx; // The index of tensor_idx in the op->inputs. + // TODO(suharshs): Support hybrid ops that require symmetric quantization. + if (GetQuantizableTensorFromOperator(model.get(), op, &tensor, &tensor_idx, + &op_input_idx)) { + // Quantize the tensors. + TF_LITE_ENSURE_STATUS(AsymmetricQuantizeTensor(model.get(), tensor)); + + // Create a new tensor to be the output of the dequantize op. 
+ std::unique_ptr<TensorT> dequantize_output; + MakeTensor(tensor->name + "_dequantize", tensor->shape, + &dequantize_output); + int32_t dequantize_output_idx = subgraph->tensors.size(); + subgraph->tensors.push_back(std::move(dequantize_output)); + + // Create the Dequantize operation. + std::unique_ptr<OperatorT> dequantize_op; + MakeDequantizeOperator(model.get(), &dequantize_op, tensor_idx, + dequantize_output_idx); + + // Update the op_input of tensor_idx to dequantize_output_idx. + op->inputs[op_input_idx] = dequantize_output_idx; + // Insert the updated op. + new_operators.push_back(std::move(subgraph->operators[i])); + + // Insert the newly created Dequantize operation. + new_operators.push_back(std::move(dequantize_op)); + } else { + // If this tensor wasn't quantizable, just copy the op over as-is. + new_operators.push_back(std::move(subgraph->operators[i])); + } + } + // At this point all unique_ptrs in the original operators are invalid, and + // we need to replace it with the new_operators vector. + subgraph->operators = std::move(new_operators); + + flatbuffers::Offset<Model> output_model_location = + Model::Pack(*builder, model.get()); + FinishModelBuffer(*builder, output_model_location); + + return kTfLiteOk; +} + +} // namespace optimize +} // namespace tflite diff --git a/tensorflow/contrib/lite/tools/optimize/quantize_weights.h b/tensorflow/contrib/lite/tools/optimize/quantize_weights.h new file mode 100644 index 0000000000..a408c1662d --- /dev/null +++ b/tensorflow/contrib/lite/tools/optimize/quantize_weights.h @@ -0,0 +1,38 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_OPTIMIZE_QUANTIZE_WEIGHTS_H_ +#define TENSORFLOW_CONTRIB_LITE_TOOLS_OPTIMIZE_QUANTIZE_WEIGHTS_H_ + +#include <memory> +#include "flatbuffers/flexbuffers.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/schema/schema_generated.h" + +namespace tflite { +namespace optimize { + +// Quantizes input_model and populates the provided builder with the new model. +// +// A tflite::Model can be obtained from the builder with: +// const uint8_t* buffer = builder->GetBufferPointer(); +// tflite::Model* model = GetModel(buffer); +TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, + const Model* input_model); + +} // namespace optimize +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_OPTIMIZE_QUANTIZE_WEIGHTS_H_ diff --git a/tensorflow/contrib/lite/tools/optimize/quantize_weights_test.cc b/tensorflow/contrib/lite/tools/optimize/quantize_weights_test.cc new file mode 100644 index 0000000000..0e0676e5ff --- /dev/null +++ b/tensorflow/contrib/lite/tools/optimize/quantize_weights_test.cc @@ -0,0 +1,130 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/tools/optimize/quantize_weights.h" + +#include <memory> + +#include "flatbuffers/flexbuffers.h" +#include <gmock/gmock.h> +#include <gtest/gtest.h> +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/schema/schema_generated.h" + +namespace tflite { +namespace optimize { +namespace { + +class QuantizeWeightsTest : public ::testing::Test { + protected: + int GetElementsNum(const TensorT* tensor) { + int tensor_size = 1; + for (const int dim : tensor->shape) { + tensor_size *= dim; + } + return tensor_size; + } + + const OperatorT* GetOpWithOutput(const SubGraphT* subgraph, + int32_t output_tensor_idx) { + for (int i = 0; i < subgraph->operators.size(); ++i) { + OperatorT* op = subgraph->operators[i].get(); + if (std::find(op->outputs.begin(), op->outputs.end(), + output_tensor_idx) != op->outputs.end()) { + return op; + } + } + return nullptr; + } + + void CheckWeights(const Model* model_packed) { + std::unique_ptr<ModelT> model; + model.reset(model_packed->UnPack()); + + SubGraphT* subgraph = model->subgraphs.at(0).get(); + + for (int i = 0; i < subgraph->operators.size(); ++i) { + OperatorT* op = subgraph->operators[i].get(); + const BuiltinOperator op_code = + model->operator_codes[op->opcode_index]->builtin_code; + + // These are the operations that should be quantized. 
+ int32_t tensor_idx; + if (op_code == BuiltinOperator_CONV_2D || + op_code == BuiltinOperator_DEPTHWISE_CONV_2D || + op_code == BuiltinOperator_FULLY_CONNECTED) { + tensor_idx = op->inputs[1]; + } else if (op_code == BuiltinOperator_LSTM) { + // TODO(suharshs): Add tests for LSTMs. + tensor_idx = op->inputs[1]; + } else { + continue; + } + const TensorT* tensor = subgraph->tensors[tensor_idx].get(); + int tensor_size = GetElementsNum(tensor); + // If the tensor_size is less than 1024 we expect the tensor to remain + // unquantized. + if (tensor_size < 1024) { + ASSERT_TRUE(tensor->type == TensorType_FLOAT32) << tensor->name; + const OperatorT* preceding_op = GetOpWithOutput(subgraph, tensor_idx); + // The weight tensor should not come from a dequantize op. + ASSERT_TRUE(preceding_op == nullptr); + } else { + // The input to the op should still be float. + ASSERT_TRUE(tensor->type == TensorType_FLOAT32) << tensor->name; + const OperatorT* preceding_op = GetOpWithOutput(subgraph, tensor_idx); + ASSERT_TRUE(preceding_op != nullptr); + // The float input should be the dequantize output. + ASSERT_TRUE( + model->operator_codes[preceding_op->opcode_index]->builtin_code == + BuiltinOperator_DEQUANTIZE); + // Finally, ensure that the input to the dequantize operation is + // quantized. + ASSERT_TRUE(subgraph->tensors[preceding_op->inputs[0]]->type == + TensorType_UINT8); + // TODO(suharshs): Add more rigorous testing for the numerical values in + // the tensors. 
+ } + } + } +}; + +TEST_F(QuantizeWeightsTest, SimpleTest) { + string model_path = + "third_party/tensorflow/contrib/lite/tools/optimize/testdata/" + "mobilenet_v1_0.25_128.tflite"; + std::unique_ptr<FlatBufferModel> input_fb = + FlatBufferModel::BuildFromFile(model_path.data()); + const Model* input_model = input_fb->GetModel(); + + flatbuffers::FlatBufferBuilder builder; + EXPECT_EQ(QuantizeWeights(&builder, input_model), kTfLiteOk); + + const uint8_t* buffer = builder.GetBufferPointer(); + const Model* output_model = GetModel(buffer); + + CheckWeights(output_model); +} + +// TODO(suharshs): Add tests that run the resulting model. + +} // namespace +} // namespace optimize +} // namespace tflite + +int main(int argc, char** argv) { + // On Linux, add: FLAGS_logtostderr = true; + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py index a328670526..bbf5d3f30c 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py @@ -2532,7 +2532,8 @@ def sparse_recall_at_top_k(labels, name=name_scope) -def _compute_recall_at_precision(tp, fp, fn, precision, name): +def _compute_recall_at_precision(tp, fp, fn, precision, name, + strict_mode=False): """Helper function to compute recall at a given `precision`. Args: @@ -2541,17 +2542,42 @@ def _compute_recall_at_precision(tp, fp, fn, precision, name): fn: The number of false negatives. precision: The precision for which the recall will be calculated. name: An optional variable_scope name. + strict_mode: If true and there exists a threshold where the precision is + no smaller than the target precision, return the corresponding recall at + the threshold. Otherwise, return 0. If false, find the threshold where the + precision is closest to the target precision and return the recall at the + threshold. 
Returns: The recall at a given `precision`. """ precisions = math_ops.div(tp, tp + fp + _EPSILON) - tf_index = math_ops.argmin( - math_ops.abs(precisions - precision), 0, output_type=dtypes.int32) + if not strict_mode: + tf_index = math_ops.argmin( + math_ops.abs(precisions - precision), 0, output_type=dtypes.int32) + # Now, we have the implicit threshold, so compute the recall: + return math_ops.div(tp[tf_index], tp[tf_index] + fn[tf_index] + _EPSILON, + name) + else: + # We aim to find the threshold where the precision is minimum but no smaller + # than the target precision. + # The rationale: + # 1. Compute the difference between precisions (by different thresholds) and + # the target precision. + # 2. Take the reciprocal of the values by the above step. The intention is + # to make the positive values rank before negative values and also the + # smaller positives rank before larger positives. + tf_index = math_ops.argmax( + math_ops.div(1.0, precisions - precision + _EPSILON), + 0, + output_type=dtypes.int32) + + def _return_good_recall(): + return math_ops.div(tp[tf_index], tp[tf_index] + fn[tf_index] + _EPSILON, + name) - # Now, we have the implicit threshold, so compute the recall: - return math_ops.div(tp[tf_index], tp[tf_index] + fn[tf_index] + _EPSILON, - name) + return control_flow_ops.cond(precisions[tf_index] >= precision, + _return_good_recall, lambda: .0) def recall_at_precision(labels, @@ -2561,7 +2587,8 @@ def recall_at_precision(labels, num_thresholds=200, metrics_collections=None, updates_collections=None, - name=None): + name=None, + strict_mode=False): """Computes `recall` at `precision`. The `recall_at_precision` function creates four local variables, @@ -2593,6 +2620,11 @@ def recall_at_precision(labels, updates_collections: An optional list of collections that `update_op` should be added to. name: An optional variable_scope name. 
+ strict_mode: If true and there exists a threshold where the precision is + no smaller than the target precision, return the corresponding recall at + the threshold. Otherwise, return 0. If false, find the threshold where the + precision is closest to the target precision and return the recall at the + threshold. Returns: recall: A scalar `Tensor` representing the recall at the given @@ -2621,10 +2653,11 @@ def recall_at_precision(labels, predictions, labels, thresholds, weights) recall = _compute_recall_at_precision(values['tp'], values['fp'], - values['fn'], precision, 'value') + values['fn'], precision, 'value', + strict_mode) update_op = _compute_recall_at_precision(update_ops['tp'], update_ops['fp'], update_ops['fn'], precision, - 'update_op') + 'update_op', strict_mode) if metrics_collections: ops.add_to_collections(metrics_collections, recall) diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py index 1c2c17960a..024bd54912 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py @@ -3467,6 +3467,60 @@ class RecallAtPrecisionTest(test.TestCase): self.assertAlmostEqual(target_recall, sess.run(update_op)) self.assertAlmostEqual(target_recall, recall.eval()) + def _test_strict_mode(self, strict_mode, target_precision, expected_recall): + num_thresholds = 11 + predictions_values = [.2, .3, .5, .6, .7, .8, .9, .9, .9, .1] + labels_values = [1, 1, 0, 0, 0, 0, 0, 0, 0, 1] + # Resulting thresholds and the corresponding precision and recall values at + # each threshold: + # Thresholds [0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9] + # precisions: [0.3 0.2 0.1 0 0 0 0 0 0] + # recalls: [1.0 0.7 0.3 0 0 0 0 0 0] + predictions = constant_op.constant( + predictions_values, dtype=dtypes_lib.float32) + labels = constant_op.constant(labels_values) + recall, update_op = metrics.recall_at_precision( + labels, + predictions, + 
num_thresholds=num_thresholds, + precision=target_precision, + strict_mode=strict_mode) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + self.assertAlmostEqual(expected_recall, sess.run(update_op)) + self.assertAlmostEqual(expected_recall, recall.eval()) + + def testStrictMode_Off(self): + # strict_mode is turned off and we return the recall at the threshold where + # the precision (0.3) is closest to target precision (0.9). The recall + # corresponding to the threshold is 1.0. + self._test_strict_mode( + strict_mode=False, target_precision=0.9, expected_recall=1.0) + + def testStrictMode_OnAndFail(self): + # strict_mode is turned on and we fail to reach the target precision at any + # threshold. + # Target precision: 0.9 + # Diff: [-0.6 -0.7 -0.8 -0.9 -0.9 -0.9 -0.9 -0.9 -0.9] + # Reciprocal: [-1.6 -1.4 -1.3 -1.1 -1.1 -1.1 -1.1 -1.1 -1.1] + # Max index: 3 and corresponding precision is: 0 which is smaller than + # target precision 0.9. As a result, the expected recall is 0. + self._test_strict_mode( + strict_mode=True, target_precision=0.9, expected_recall=.0) + + def testStrictMode_OnAndSucceed(self): + # strict_mode is on and we can reach the target precision at certain + # threshold. + # Target precision: 0.2 + # Diff: [0.1 0 -0.1 -0.2 -0.2 -0.2 -0.2 -0.2 -0.2] + # Reciprocal: [10 infty -10.0 -5.0 -5.0 -5.0 -5.0 -5.0 -5.0] + # Max index: 1 and corresponding precision is: 0.2 which is no smaller than + # target precision 0.2. In this case, we return the recall at index 1, which + # is 2.0/3 (0.7). 
+ self._test_strict_mode( + strict_mode=True, target_precision=0.2, expected_recall=2.0 / 3) + class PrecisionAtRecallTest(test.TestCase): diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/input_data.cc b/tensorflow/contrib/tensor_forest/kernels/v4/input_data.cc index d43884481a..99c5800391 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/input_data.cc +++ b/tensorflow/contrib/tensor_forest/kernels/v4/input_data.cc @@ -130,7 +130,11 @@ void TensorDataSet::RandomSample(int example, num_total_features += num_sparse; } } - int rand_feature = rng_->Uniform(num_total_features); + int rand_feature = 0; + { + mutex_lock lock(mu_); + rand_feature = rng_->Uniform(num_total_features); + } if (rand_feature < available_features_.size()) { // it's dense. *feature_id = available_features_[rand_feature]; *type = input_spec_.GetDenseFeatureType(rand_feature); diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/input_data.h b/tensorflow/contrib/tensor_forest/kernels/v4/input_data.h index 95f75b4d7e..4945b53007 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/input_data.h +++ b/tensorflow/contrib/tensor_forest/kernels/v4/input_data.h @@ -25,6 +25,7 @@ #include "tensorflow/core/lib/random/philox_random.h" #include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/lib/random/simple_philox.h" +#include "tensorflow/core/platform/mutex.h" namespace tensorflow { namespace tensorforest { @@ -120,6 +121,8 @@ class TensorDataSet { int32 split_sampling_random_seed_; std::unique_ptr<random::PhiloxRandom> single_rand_; std::unique_ptr<random::SimplePhilox> rng_; + // Mutex for using random number generator. 
+ mutable mutex mu_; }; } // namespace tensorforest } // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD index a0fc3e43a9..122a67a407 100644 --- a/tensorflow/contrib/tensorrt/BUILD +++ b/tensorflow/contrib/tensorrt/BUILD @@ -279,6 +279,7 @@ tf_cuda_library( "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler:utils", + "//tensorflow/core:framework", "//tensorflow/core:framework_lite", "//tensorflow/core:gpu_runtime", "//tensorflow/core:graph", diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc index 0f5abe6898..c98b07ad8b 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc @@ -33,6 +33,7 @@ limitations under the License. #include "tensorflow/contrib/tensorrt/resources/trt_resources.h" #include "tensorflow/core/framework/node_def.pb.h" // NOLINT #include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/tensor.pb.h" // NOLINT #include "tensorflow/core/framework/tensor_shape.pb.h" // NOLINT #include "tensorflow/core/framework/types.h" #include "tensorflow/core/graph/algorithm.h" diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py index a5e8277ba5..1d1cb48e8e 100644 --- a/tensorflow/contrib/tpu/python/tpu/keras_support.py +++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py @@ -111,17 +111,24 @@ def reset_tpu_sessions(): # Work-around dependency cycle between DistributionStrategy and TPU lib. 
-def TPUDistributionStrategy(tpu_cluster_resolver=None): # pylint: disable=invalid-name +def TPUDistributionStrategy(tpu_cluster_resolver=None, num_cores=None): # pylint: disable=invalid-name """Construct a TPUDistributionStrategy.""" from tensorflow.contrib.distribute.python import tpu_strategy # pylint: disable=g-import-not-at-top - # TODO -- remove this when TPUStrategy API is consistent (b/112705069) + # TODO(b/112705069): Remove this when TPUStrategy API is consistent. + # We are including this for (a) backwards compatibility for open sourced + # releases of TensorFlow and (b) to work around a circular dependency + # where keras_support and tpu_strategy depends on each other. Once we release + # a final version and remove support for the old API, this will be deleted. + # (See bug above for more details) if tpu_cluster_resolver is None: tpu_cluster_resolver = tpu_cluster_resolver_lib.TPUClusterResolver('') args, _, _, _ = tf_inspect.getargspec(tpu_strategy.TPUStrategy.__init__) - if len(args) == 3: + if len(args) == 4: logging.info('Detected new TPUStrategy API.') - return tpu_strategy.TPUStrategy(tpu_cluster_resolver, steps_per_run=1) + return tpu_strategy.TPUStrategy(tpu_cluster_resolver, + steps_per_run=1, + num_cores=num_cores) else: logging.info('Detected old TPUStrategy API.') strategy = tpu_strategy.TPUStrategy(num_cores_per_host=8) diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 60830b7d60..836c3ce34e 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -375,6 +375,7 @@ cc_library( ":lib_platform", ":platform_base", "//tensorflow/core/platform/default/build_config:port", + "@com_google_absl//absl/base", "@snappy", ], ) diff --git a/tensorflow/core/common_runtime/constant_folding.cc b/tensorflow/core/common_runtime/constant_folding.cc index b5a51d2526..97b6971c5b 100644 --- a/tensorflow/core/common_runtime/constant_folding.cc +++ b/tensorflow/core/common_runtime/constant_folding.cc @@ -37,6 +37,8 @@ limitations under the 
License. #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/denormal.h" +#include "tensorflow/core/platform/setround.h" #include "tensorflow/core/public/session_options.h" namespace tensorflow { @@ -553,6 +555,11 @@ bool ReplaceTensorWithConstant( Status ConstantFold(const ConstantFoldingOptions& opts, FunctionLibraryRuntime* function_library, Env* env, Device* partition_device, Graph* graph, bool* was_mutated) { + // TensorFlow flushes denormals to zero and rounds to nearest, so we do + // the same here. + port::ScopedFlushDenormal flush; + port::ScopedSetRound round(FE_TONEAREST); + DumpGraph("Before", graph); ConstantFoldNameGenerator generate_new_name = opts.generate_new_name; if (generate_new_name == nullptr) { diff --git a/tensorflow/core/common_runtime/gpu/gpu_util_platform_specific.cc b/tensorflow/core/common_runtime/gpu/gpu_util_platform_specific.cc index ea1b04feeb..4bc88ffc8c 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_util_platform_specific.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_util_platform_specific.cc @@ -14,6 +14,7 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h" #include "tensorflow/core/common_runtime/gpu/gpu_util.h" #include "tensorflow/core/common_runtime/gpu_device_context.h" #include "tensorflow/core/framework/tensor.h" @@ -36,4 +37,12 @@ void GPUDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor, GPUUtil::CopyGPUTensorToCPU(device, this, device_tensor, cpu_tensor, done); } +Status GPUDeviceContext::ThenExecute(Device* device, se::Stream* stream, + std::function<void()> func) { + const DeviceBase::GpuDeviceInfo* gpu_info = + device->tensorflow_gpu_device_info(); + gpu_info->event_mgr->ThenExecute(stream, func); + return Status::OK(); +} + } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/gpu_device_context.h b/tensorflow/core/common_runtime/gpu_device_context.h index 8370c63842..3603808152 100644 --- a/tensorflow/core/common_runtime/gpu_device_context.h +++ b/tensorflow/core/common_runtime/gpu_device_context.h @@ -60,6 +60,9 @@ class GPUDeviceContext : public DeviceContext { void MaintainLifetimeOnStream(const Tensor* t, se::Stream* stream) const override {} + Status ThenExecute(Device* device, se::Stream* stream, + std::function<void()> func) override; + private: int stream_id_; // The default primary stream to use for this context. diff --git a/tensorflow/core/framework/device_base.h b/tensorflow/core/framework/device_base.h index b184fd91e1..794250a2c1 100644 --- a/tensorflow/core/framework/device_base.h +++ b/tensorflow/core/framework/device_base.h @@ -89,6 +89,15 @@ class DeviceContext : public core::RefCounted { Tensor* cpu_tensor, StatusCallback done) { done(errors::Internal("Unrecognized device type in device-to-CPU Copy")); } + + // If possible, wait for all events on *stream to complete then execute func. + // A non-OK Status is returned otherwise. 
The stream argument should be the + // one provided by GpuDeviceInfo. This function is not applicable to devices + // that don't provide such a value. + virtual Status ThenExecute(Device* device, stream_executor::Stream* stream, + std::function<void()> func) { + return errors::Internal("ThenExecute not supported by device"); + } }; // map[i] is the DeviceContext* for the node with id i, if i < map.size(). diff --git a/tensorflow/core/grappler/clusters/virtual_cluster.cc b/tensorflow/core/grappler/clusters/virtual_cluster.cc index 12e3e46f65..f543dca49e 100644 --- a/tensorflow/core/grappler/clusters/virtual_cluster.cc +++ b/tensorflow/core/grappler/clusters/virtual_cluster.cc @@ -45,6 +45,8 @@ VirtualCluster::VirtualCluster(const DeviceSet* device_set) for (const auto& device : device_set_->devices()) { DeviceProperties props = GetDeviceInfo(device->parsed_name()); if (props.type() == "UNKNOWN") continue; + auto attrs = device->attributes(); + props.set_memory_size(attrs.memory_limit()); devices_[device->name()] = props; } } diff --git a/tensorflow/core/grappler/costs/analytical_cost_estimator.cc b/tensorflow/core/grappler/costs/analytical_cost_estimator.cc index a60e3c7a9f..0690640ffa 100644 --- a/tensorflow/core/grappler/costs/analytical_cost_estimator.cc +++ b/tensorflow/core/grappler/costs/analytical_cost_estimator.cc @@ -18,6 +18,7 @@ limitations under the License. #include <limits> #include <unordered_map> +#include "tensorflow/core/framework/tensor.pb.h" // NOLINT #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/graph/types.h" #include "tensorflow/core/grappler/costs/graph_properties.h" diff --git a/tensorflow/core/grappler/costs/graph_memory.cc b/tensorflow/core/grappler/costs/graph_memory.cc index a5736d40b1..b01aca610a 100644 --- a/tensorflow/core/grappler/costs/graph_memory.cc +++ b/tensorflow/core/grappler/costs/graph_memory.cc @@ -20,6 +20,7 @@ limitations under the License. 
#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/step_stats.pb.h" +#include "tensorflow/core/framework/tensor.pb.h" // NOLINT #include "tensorflow/core/framework/tensor_description.pb.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_shape.pb.h" diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc index 231c7c63be..6710ff9df3 100644 --- a/tensorflow/core/grappler/costs/graph_properties.cc +++ b/tensorflow/core/grappler/costs/graph_properties.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/graph/graph_constructor.h" @@ -804,8 +805,9 @@ class SymbolicShapeRefiner { CHECK_NOTNULL(function_library_.Find(function_node->op())); GrapplerFunctionItem grappler_function_item; - TF_RETURN_IF_ERROR(MakeGrapplerFunctionItem( - *function_def, function_library_, &grappler_function_item)); + TF_RETURN_IF_ERROR( + MakeGrapplerFunctionItem(*function_def, function_library_, + graph_def_version_, &grappler_function_item)); if (grappler_function_item.inputs().size() > function_node->input_size()) { return errors::FailedPrecondition( diff --git a/tensorflow/core/grappler/costs/graph_properties_test.cc b/tensorflow/core/grappler/costs/graph_properties_test.cc index 5acfb56b05..8938b7c32e 100644 --- a/tensorflow/core/grappler/costs/graph_properties_test.cc +++ b/tensorflow/core/grappler/costs/graph_properties_test.cc @@ -18,8 +18,10 @@ limitations under the License. 
#include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/core/framework/graph_def_util.h" #include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/tensor.pb.h" // NOLINT #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/grappler/clusters/single_machine.h" #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h" @@ -783,6 +785,46 @@ TEST_F(GraphPropertiesTest, InferRestoreOpShape_WithTwoNodesShareSameOutput) { EXPECT_EQ("float: [128,256]", PropToString(prop)); } +TEST_F(GraphPropertiesTest, FunctionWithScalarInputTest) { + // Create graph with a function that takes a scalar value so that we use + // Placeholder with scalar as for input to the function shape inference. + // Placeholder -> Identity -> MyFunc, where MyFunc simply takes Identity of + // the input; all tensors are scalars. 
+ FunctionDefLibrary library; + *library.add_function() = FunctionDefHelper::Create( + "MyFunc", // Name + {"x: float"}, // Inputs + {"out: float"}, // Outputs + {}, // Attrs + {{{"a"}, "Identity", {"x"}, {{"T", DataType::DT_FLOAT}}}}, // Nodes + {{"out", "a:output:0"}}); // Returns + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + TF_CHECK_OK(s.graph()->AddFunctionLibrary(library)); + Output placeholder = + ops::Placeholder(s.WithOpName("Placeholder"), DataType::DT_FLOAT, + ops::Placeholder::Shape(TensorShape({}))); + Output identity = ops::Identity(s.WithOpName("Identity"), placeholder); + auto _identity = tensorflow::ops::AsNodeOut(s, identity); + auto builder = + tensorflow::NodeBuilder("MyFunc", "MyFunc", s.graph()->op_registry()); + tensorflow::Node* func_op; + TF_CHECK_OK(builder.Input(_identity).Finalize(s.graph(), &func_op)); + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + + // Tensorflow version < 21 infers output shape of Placeholder with empty shape + // as unknown, instead of scalar. + EXPECT_GT(item.graph.versions().producer(), 21); + + // MyFunc output shouldn't be unknown rank. + GraphProperties properties(item); + TF_CHECK_OK(properties.InferStatically(false)); + const auto out_props = properties.GetOutputProperties("MyFunc"); + const OpInfo::TensorProperties out_prop0 = out_props[0]; + EXPECT_EQ(DT_FLOAT, out_prop0.dtype()); + EXPECT_FALSE(out_prop0.shape().unknown_rank()); +} + TEST_F(GraphPropertiesTest, SimpleFunctionStaticShapeInference) { // Test graph produced in python using: /* diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc index 0341d7f8e1..71f4d9fd05 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include "third_party/eigen3/Eigen/Core" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/grappler/clusters/utils.h" diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc index 9e579098ef..998bd59dce 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/attr_value_util.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.h" diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc index 6e3ebdee12..037a823096 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler.cc +++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc @@ -880,10 +880,15 @@ Costs VirtualScheduler::Summary() const { // Print per device summary VLOG(1) << "Devices:"; Costs critical_path_costs = Costs::ZeroCosts(); + std::vector<string> device_names; + device_names.reserve(device_.size()); + for (auto& it : device_) { + device_names.push_back(it.first); + } + std::sort(device_names.begin(), device_names.end()); - for (const auto& device : device_) { - const auto& name = device.first; - const auto& state = device.second; + for (const auto& name : device_names) { + const auto& state = device_.at(name); std::map<string, int64> op_to_memory; // First profile only persistent memory usage. 
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc index b1373d8317..02a379fca8 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc +++ b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/grappler/costs/virtual_scheduler.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/framework/tensor.pb.h" // NOLINT #include "tensorflow/core/framework/tensor_description.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/grappler/clusters/virtual_cluster.h" diff --git a/tensorflow/core/grappler/grappler_item_builder.cc b/tensorflow/core/grappler/grappler_item_builder.cc index 288587ce9b..029515ad3c 100644 --- a/tensorflow/core/grappler/grappler_item_builder.cc +++ b/tensorflow/core/grappler/grappler_item_builder.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/core/framework/graph_def_util.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/framework/variable.pb.h" diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD index caaa5ac8db..a8af169e28 100644 --- a/tensorflow/core/grappler/optimizers/BUILD +++ b/tensorflow/core/grappler/optimizers/BUILD @@ -827,11 +827,6 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", - "//tensorflow/core/grappler:grappler_item", - "//tensorflow/core/grappler:op_types", - "//tensorflow/core/grappler:utils", - "//tensorflow/core/grappler/clusters:cluster", - "//tensorflow/core/grappler/costs:graph_properties", ], ) diff --git 
a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index 889445bbd6..4fb2fe6883 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/grappler/costs/graph_properties.h" diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc index b9765b9292..5bf45af6b3 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc @@ -3047,6 +3047,39 @@ TEST_F(ConstantFoldingTest, TensorArraySize) { test::ExpectTensorEqual<int32>(tensors_expected[1], tensors_actual[1]); } +TEST_F(ConstantFoldingTest, FoldingPreservesDenormalFlushing) { + // Multiplying min() with 0.1 gives a denormal without FTZ and zero with FTZ. + // Make sure constant folding behaves the same way as TensorFlow. 
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + + Output a = + ops::Const(s.WithOpName("a"), std::numeric_limits<float>::min(), {1}); + Output b = ops::Const(s.WithOpName("b"), 0.1f, {1}); + Output c = ops::Mul(s.WithOpName("c"), a, b); + + GrapplerItem item; + item.fetch.push_back("c"); + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + + ConstantFolding optimizer(nullptr /* cpu_device */); + GraphDef output; + Status status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + + EXPECT_EQ(1, output.node_size()); + + const NodeDef& node_d = output.node(0); + EXPECT_EQ("c", node_d.name()); + EXPECT_EQ("Const", node_d.op()); + + std::vector<string> fetch = {"c"}; + auto tensors_expected = EvaluateNodes(item.graph, fetch); + auto tensors = EvaluateNodes(output, fetch); + EXPECT_EQ(1, tensors_expected.size()); + EXPECT_EQ(1, tensors.size()); + test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]); +} + } // namespace } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc index 23f35050f2..92551a0459 100644 --- a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc +++ b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc @@ -17,6 +17,7 @@ limitations under the License. 
#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/tensor.pb.h" // NOLINT #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/grappler/clusters/cluster.h" #include "tensorflow/core/grappler/grappler_item.h" diff --git a/tensorflow/core/grappler/optimizers/evaluation_utils.cc b/tensorflow/core/grappler/optimizers/evaluation_utils.cc index 00ad7494f4..79d9ea1608 100644 --- a/tensorflow/core/grappler/optimizers/evaluation_utils.cc +++ b/tensorflow/core/grappler/optimizers/evaluation_utils.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/grappler/optimizers/evaluation_utils.h" +#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/platform/cpu_info.h" #include "tensorflow/core/platform/denormal.h" diff --git a/tensorflow/core/grappler/optimizers/evaluation_utils.h b/tensorflow/core/grappler/optimizers/evaluation_utils.h index 8414b5b8ca..c9dfb6dc0b 100644 --- a/tensorflow/core/grappler/optimizers/evaluation_utils.h +++ b/tensorflow/core/grappler/optimizers/evaluation_utils.h @@ -21,6 +21,7 @@ limitations under the License. 
#include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" namespace Eigen { diff --git a/tensorflow/core/grappler/optimizers/function_optimizer.cc b/tensorflow/core/grappler/optimizers/function_optimizer.cc index 645e4c2087..56364f0095 100644 --- a/tensorflow/core/grappler/optimizers/function_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/function_optimizer.cc @@ -453,6 +453,7 @@ Status InitializeFunctionSpecializationSignature( } Status SpecializeFunction(const NodeDef& func_node, const FunctionDef& func, + const int graph_def_version, FunctionOptimizerContext* ctx, GraphDef* optimized_graph) { VLOG(2) << "Specialize function instantiation: " @@ -492,7 +493,8 @@ Status SpecializeFunction(const NodeDef& func_node, const FunctionDef& func, // Make a GrapplerFunctionItem and convert it back to FunctionDef after // pushing all constant inputs into the function body. GrapplerFunctionItem item; - TF_RETURN_IF_ERROR(MakeGrapplerFunctionItem(func, func_attr, flib, &item)); + TF_RETURN_IF_ERROR(MakeGrapplerFunctionItem(func, func_attr, flib, + graph_def_version, &item)); // Push const inputs into the function body, and keep track of their control // dependencies. 
@@ -576,15 +578,15 @@ NodeDef InlinedFunctionOutputsNode(const NodeDef& func_node, Status InlineFunction(const NodeDef& func_node, const FunctionDef& func, const FunctionOptimizerContext& ctx, - GraphDef* optimized_graph) { + const int graph_def_version, GraphDef* optimized_graph) { VLOG(2) << "Inline function instantiation: " << SummarizeNodeDef(func_node); const std::unordered_map<string, AttrValue> func_attr( func_node.attr().begin(), func_node.attr().end()); GrapplerFunctionItem item; - Status item_status = - MakeGrapplerFunctionItem(func, func_attr, ctx.function_library(), &item); + Status item_status = MakeGrapplerFunctionItem( + func, func_attr, ctx.function_library(), graph_def_version, &item); if (!item_status.ok()) { return errors::InvalidArgument("Failed to inline function ", func_node.op(), @@ -645,7 +647,8 @@ Status InlineFunction(const NodeDef& func_node, const FunctionDef& func, if (func_body_node_func != nullptr) { // Recursively inline function calls. TF_RETURN_IF_ERROR(InlineFunction(func_body_node, *func_body_node_func, - ctx, optimized_graph)); + ctx, graph_def_version, + optimized_graph)); } else { // Annotate the node with the function attributes. for (const auto& attr : func.attr()) { @@ -824,7 +827,8 @@ Status FunctionOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, if (inline_func && ctx.IsInlinedFunction(func_name)) { // Inline function body into the optimized graph} TF_SKIP_ERROR_IF_GRAPH_UNMODIFIED( - InlineFunction(node, *func, ctx, optimized_graph)); + InlineFunction(node, *func, ctx, item.graph.versions().producer(), + optimized_graph)); continue; } @@ -837,7 +841,8 @@ Status FunctionOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, // TODO(ezhulenev): Specialize function call if input has a known shape. // Specialize function body for its instantiation attributes and inputs. 
TF_SKIP_ERROR_IF_GRAPH_UNMODIFIED( - SpecializeFunction(node, *func, &ctx, optimized_graph)); + SpecializeFunction(node, *func, item.graph.versions().producer(), + &ctx, optimized_graph)); continue; } } diff --git a/tensorflow/core/grappler/optimizers/memory_optimizer.cc b/tensorflow/core/grappler/optimizers/memory_optimizer.cc index 1be5f8dcc2..91794cefe5 100644 --- a/tensorflow/core/grappler/optimizers/memory_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/memory_optimizer.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/tensor.pb.h" // NOLINT #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/grappler/clusters/virtual_cluster.h" #include "tensorflow/core/grappler/costs/graph_memory.h" diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc index e778b7879d..5fd34efeb1 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc @@ -361,7 +361,8 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, // Make a GrapplerItem from a FunctionDef. GrapplerFunctionItem func_item; - TF_RETURN_IF_ERROR(MakeGrapplerFunctionItem(func, flib, &func_item)); + TF_RETURN_IF_ERROR(MakeGrapplerFunctionItem( + func, flib, item.graph.versions().producer(), &func_item)); // Optimize function body graph. 
GraphDef optimized_func_graph; diff --git a/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer_test.cc b/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer_test.cc index 89847f83d4..b033cff8e6 100644 --- a/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/tensor.pb.h" // NOLINT #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/graph/testlib.h" diff --git a/tensorflow/core/grappler/optimizers/shape_optimizer.cc b/tensorflow/core/grappler/optimizers/shape_optimizer.cc index 26c54df56b..caa0b7b0cb 100644 --- a/tensorflow/core/grappler/optimizers/shape_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/shape_optimizer.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/grappler/optimizers/shape_optimizer.h" +#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/grappler/graph_view.h" diff --git a/tensorflow/core/grappler/utils/functions.cc b/tensorflow/core/grappler/utils/functions.cc index 462b752316..a2c363ea6e 100644 --- a/tensorflow/core/grappler/utils/functions.cc +++ b/tensorflow/core/grappler/utils/functions.cc @@ -24,6 +24,7 @@ limitations under the License. 
#include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/grappler/op_types.h" #include "tensorflow/core/grappler/utils.h" #include "tensorflow/core/lib/gtl/map_util.h" @@ -307,8 +308,8 @@ GrapplerFunctionItem::GrapplerFunctionItem( const AttrValueMap& func_attr, const std::vector<InputArgExpansion>& input_arg_expansions, const std::vector<OutputArgExpansion>& output_arg_expansions, - const std::vector<string>& keep_nodes, bool is_stateful, - GraphDef&& function_body) + const std::vector<string>& keep_nodes, const int graph_def_version, + bool is_stateful, GraphDef&& function_body) : description_(description), func_attr_(func_attr), input_arg_expansions_(input_arg_expansions), @@ -318,6 +319,7 @@ GrapplerFunctionItem::GrapplerFunctionItem( keep_ops = keep_nodes; // Swap the graph body. graph.Swap(&function_body); + graph.mutable_versions()->set_producer(graph_def_version); // Fill the feed nodes with input placeholders. 
for (const InputArgExpansion& input_arg : input_arg_expansions_) { for (const string& placeholder : input_arg.placeholders) { @@ -472,6 +474,7 @@ Status InstantiationBodyParameters( Status MakeGrapplerFunctionItem(const FunctionDef& func, const AttrValueMap& func_instantiation_attr, const FunctionLibraryDefinition& flib, + const int graph_def_version, GrapplerFunctionItem* item) { const OpDef& signature = func.signature(); @@ -595,14 +598,17 @@ Status MakeGrapplerFunctionItem(const FunctionDef& func, *item = GrapplerFunctionItem( /*func_name=*/signature.name(), /*description=*/signature.description(), /*func_attr=*/AttrValueMap(func.attr().begin(), func.attr().end()), - inputs, outputs, keep_nodes, is_stateful, std::move(function_body)); + inputs, outputs, keep_nodes, graph_def_version, is_stateful, + std::move(function_body)); return Status::OK(); } Status MakeGrapplerFunctionItem(const FunctionDef& func, const FunctionLibraryDefinition& flib, + const int graph_def_version, GrapplerFunctionItem* item) { - return MakeGrapplerFunctionItem(func, AttrValueMap(), flib, item); + return MakeGrapplerFunctionItem(func, AttrValueMap(), flib, graph_def_version, + item); } // Register GrapplerFunctionItem input arg expansion and function body outputs diff --git a/tensorflow/core/grappler/utils/functions.h b/tensorflow/core/grappler/utils/functions.h index 9f607dc2ee..61588ceb83 100644 --- a/tensorflow/core/grappler/utils/functions.h +++ b/tensorflow/core/grappler/utils/functions.h @@ -141,8 +141,8 @@ class GrapplerFunctionItem : public GrapplerItem { const AttrValueMap& func_attr, const std::vector<InputArgExpansion>& input_arg_expansions, const std::vector<OutputArgExpansion>& output_arg_expansions, - const std::vector<string>& keep_nodes, bool is_stateful, - GraphDef&& function_body); + const std::vector<string>& keep_nodes, const int versions, + bool is_stateful, GraphDef&& function_body); const string& description() const; @@ -222,6 +222,7 @@ Status 
ReplaceInputWithConst(const NodeDef& input_const, int input_position, Status MakeGrapplerFunctionItem(const FunctionDef& func, const AttrValueMap& func_instantiation_attr, const FunctionLibraryDefinition& flib, + const int graph_def_version, GrapplerFunctionItem* item); // Make a GrapplerFunction item from the function definition. Function must be @@ -231,6 +232,7 @@ Status MakeGrapplerFunctionItem(const FunctionDef& func, // without specializing it to it's instantiation attributes (at least types)? Status MakeGrapplerFunctionItem(const FunctionDef& func, const FunctionLibraryDefinition& flib, + const int graph_def_version, GrapplerFunctionItem* item); // Make a FunctionDef from the GrapplerFunctionItem. Use function library diff --git a/tensorflow/core/grappler/utils/functions_test.cc b/tensorflow/core/grappler/utils/functions_test.cc index b2d059e0ac..b51f2781b8 100644 --- a/tensorflow/core/grappler/utils/functions_test.cc +++ b/tensorflow/core/grappler/utils/functions_test.cc @@ -22,6 +22,7 @@ limitations under the License. 
#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/public/version.h" namespace tensorflow { namespace grappler { @@ -239,7 +240,8 @@ TEST_F(FunctionsTest, FromSimpleFunctionDef) { FunctionLibraryDefinition flib(OpRegistry::Global(), FunctionDefLibrary()); GrapplerFunctionItem item; - TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib, &item)); + TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib, + TF_GRAPH_DEF_VERSION, &item)); EXPECT_EQ("XTimesTwo", item.id); EXPECT_EQ(4, item.function_body().node_size()); @@ -314,7 +316,8 @@ TEST_F(FunctionsTest, FromFunctionDefWithMultiOutputNodes) { FunctionLibraryDefinition flib(OpRegistry::Global(), FunctionDefLibrary()); GrapplerFunctionItem item; - TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib, &item)); + TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib, + TF_GRAPH_DEF_VERSION, &item)); EXPECT_EQ("SubGrad", item.id); EXPECT_EQ(12, item.function_body().node_size()); @@ -395,7 +398,8 @@ TEST_F(FunctionsTest, FromFunctionDefWithNestedFuncs) { func_attr["T"].set_type(DT_FLOAT); GrapplerFunctionItem item; - TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib, &item)); + TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib, + TF_GRAPH_DEF_VERSION, &item)); int count = 0; for (const NodeDef &node : item.function_body().node()) { @@ -456,7 +460,8 @@ TEST_F(FunctionsTest, FromFunctionDefWithOutputMappings) { FunctionLibraryDefinition flib(OpRegistry::Global(), FunctionDefLibrary()); GrapplerFunctionItem item; - TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib, &item)); + TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib, + TF_GRAPH_DEF_VERSION, &item)); EXPECT_EQ(1, item.output_size()); EXPECT_EQ("Exp", item.output(0).output_tensors[0]); @@ -499,7 +504,8 @@ TEST_F(FunctionsTest, FromFunctionDefWithInputForwarding) { 
FunctionLibraryDefinition flib(OpRegistry::Global(), FunctionDefLibrary()); GrapplerFunctionItem item; - TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib, &item)); + TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib, + TF_GRAPH_DEF_VERSION, &item)); EXPECT_EQ("ForwardInputs", item.id); EXPECT_EQ(5, item.function_body().node_size()); @@ -545,7 +551,8 @@ TEST_F(FunctionsTest, FromFunctionDefWithoutInput) { FunctionLibraryDefinition flib(OpRegistry::Global(), FunctionDefLibrary()); GrapplerFunctionItem item; - TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib, &item)); + TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib, + TF_GRAPH_DEF_VERSION, &item)); EXPECT_EQ(0, item.input_size()); EXPECT_EQ(1, item.output_size()); @@ -584,7 +591,8 @@ TEST_F(FunctionsTest, MakeFunctionDef) { FunctionLibraryDefinition flib(OpRegistry::Global(), FunctionDefLibrary()); GrapplerFunctionItem item; - TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib, &item)); + TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib, + TF_GRAPH_DEF_VERSION, &item)); FunctionDef specialized; TF_EXPECT_OK(MakeFunctionDef(item, flib, &specialized)); @@ -622,7 +630,8 @@ TEST_F(FunctionsTest, ReplaceInputWithConst) { FunctionLibraryDefinition flib(OpRegistry::Global(), FunctionDefLibrary()); GrapplerFunctionItem item; - TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib, &item)); + TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib, + TF_GRAPH_DEF_VERSION, &item)); EXPECT_EQ(2, item.input_size()); EXPECT_EQ(1, item.output_size()); @@ -713,7 +722,8 @@ TEST_F(FunctionsTest, SwapFunctionBodyAndMakeFunctionDef) { FunctionLibraryDefinition flib(OpRegistry::Global(), lib_def); GrapplerFunctionItem item; - TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib, &item)); + TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib, + TF_GRAPH_DEF_VERSION, &item)); // Replace function body with identity function 
item.SwapFunctionBody(std::move(id_func_body)); @@ -754,7 +764,8 @@ TEST_F(FunctionsTest, FunctionDefGrapplerFunctionItemRoundTrip) { GrapplerFunctionItem item; std::unordered_map<string, AttrValue> func_attr; func_attr["T"].set_type(DT_INT32); - TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib, &item)); + TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib, + TF_GRAPH_DEF_VERSION, &item)); FunctionDef func2; TF_EXPECT_OK(MakeFunctionDef(item, flib, &func2)); diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD index 82ff2a365d..7716043055 100644 --- a/tensorflow/core/kernels/data/BUILD +++ b/tensorflow/core/kernels/data/BUILD @@ -237,6 +237,7 @@ cc_library( srcs = ["parse_example_dataset_op.cc"], deps = [ ":parallel_map_iterator", + "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", ], ) diff --git a/tensorflow/core/kernels/data/parse_example_dataset_op.cc b/tensorflow/core/kernels/data/parse_example_dataset_op.cc index cc5007ee92..6a0522e4f3 100644 --- a/tensorflow/core/kernels/data/parse_example_dataset_op.cc +++ b/tensorflow/core/kernels/data/parse_example_dataset_op.cc @@ -14,6 +14,7 @@ limitations under the License. 
==============================================================================*/ #include <deque> +#include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/framework/stats_aggregator.h" #include "tensorflow/core/kernels/data/parallel_map_iterator.h" #include "tensorflow/core/util/example_proto_fast_parsing.h" @@ -166,8 +167,6 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel { const std::vector<PartialTensorShape>& output_shapes) : DatasetBase(DatasetContext(ctx)), input_(input), - device_threadpool_( - ctx->device()->tensorflow_cpu_worker_threads()->workers), dense_defaults_(std::move(dense_defaults)), sparse_keys_(std::move(sparse_keys)), dense_keys_(std::move(dense_keys)), @@ -190,6 +189,8 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel { std::vector<Tensor> input_element, std::vector<Tensor>* result, StatusCallback done) { (*ctx->runner())([this, ctx, input_element, result, done]() { + thread::ThreadPool* device_threadpool = + ctx->lib()->device()->tensorflow_cpu_worker_threads()->workers; std::vector<string> slice_vec; for (Tensor t : input_element) { auto serialized_t = t.flat<string>(); @@ -205,7 +206,7 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel { config.collect_feature_stats = true; } example::Result example_result; - Status s = FastParseExample(config, slice_vec, {}, device_threadpool_, + Status s = FastParseExample(config, slice_vec, {}, device_threadpool, &example_result); if (s.ok()) { (*result).resize(key_to_output_index_.size()); @@ -339,7 +340,6 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel { private: const DatasetBase* const input_; - thread::ThreadPool* const device_threadpool_; const std::vector<Tensor> dense_defaults_; const std::vector<string> sparse_keys_; const std::vector<string> dense_keys_; diff --git a/tensorflow/core/kernels/parameterized_truncated_normal_op.cc b/tensorflow/core/kernels/parameterized_truncated_normal_op.cc index 9ec83b867f..aa70ee06f5 100644 --- 
a/tensorflow/core/kernels/parameterized_truncated_normal_op.cc +++ b/tensorflow/core/kernels/parameterized_truncated_normal_op.cc @@ -196,6 +196,9 @@ struct TruncatedNormalFunctor<CPUDevice, T> { } output(sample) = z * stddev + mean; sample++; + if (sample >= limit_sample) { + break; + } numIterations = 0; } else { numIterations++; diff --git a/tensorflow/core/lib/monitoring/collection_registry.cc b/tensorflow/core/lib/monitoring/collection_registry.cc index 8c28620ff9..fface033cb 100644 --- a/tensorflow/core/lib/monitoring/collection_registry.cc +++ b/tensorflow/core/lib/monitoring/collection_registry.cc @@ -38,15 +38,15 @@ void Collector::CollectMetricDescriptor( mutex_lock l(mu_); return collected_metrics_->metric_descriptor_map .insert(std::make_pair( - std::string(metric_def->name()), + string(metric_def->name()), std::unique_ptr<MetricDescriptor>(new MetricDescriptor()))) .first->second.get(); }(); - metric_descriptor->name = std::string(metric_def->name()); - metric_descriptor->description = std::string(metric_def->description()); + metric_descriptor->name = string(metric_def->name()); + metric_descriptor->description = string(metric_def->description()); for (const StringPiece label_name : metric_def->label_descriptions()) { - metric_descriptor->label_names.push_back(std::string(label_name)); + metric_descriptor->label_names.emplace_back(label_name); } metric_descriptor->metric_kind = metric_def->kind(); diff --git a/tensorflow/core/lib/monitoring/collection_registry.h b/tensorflow/core/lib/monitoring/collection_registry.h index 20f0444f8b..c204d52cfe 100644 --- a/tensorflow/core/lib/monitoring/collection_registry.h +++ b/tensorflow/core/lib/monitoring/collection_registry.h @@ -72,7 +72,7 @@ class MetricCollector { registration_time_millis_(registration_time_millis), collector_(collector), point_set_(point_set) { - point_set_->metric_name = std::string(metric_def->name()); + point_set_->metric_name = string(metric_def->name()); } const MetricDef<metric_kind, 
Value, NumLabels>* const metric_def_; @@ -261,7 +261,7 @@ class Collector { auto* const point_set = [&]() { mutex_lock l(mu_); return collected_metrics_->point_set_map - .insert(std::make_pair(std::string(metric_def->name()), + .insert(std::make_pair(string(metric_def->name()), std::unique_ptr<PointSet>(new PointSet()))) .first->second.get(); }(); diff --git a/tensorflow/core/lib/monitoring/metric_def.h b/tensorflow/core/lib/monitoring/metric_def.h index 6f94685665..756e5c2af8 100644 --- a/tensorflow/core/lib/monitoring/metric_def.h +++ b/tensorflow/core/lib/monitoring/metric_def.h @@ -98,8 +98,8 @@ class AbstractMetricDef { const std::vector<string>& label_descriptions) : kind_(kind), value_type_(value_type), - name_(std::string(name)), - description_(std::string(description)), + name_(name), + description_(description), label_descriptions_(std::vector<string>(label_descriptions.begin(), label_descriptions.end())) {} diff --git a/tensorflow/core/lib/strings/numbers.h b/tensorflow/core/lib/strings/numbers.h index e0a5281d68..959290ba8c 100644 --- a/tensorflow/core/lib/strings/numbers.h +++ b/tensorflow/core/lib/strings/numbers.h @@ -140,11 +140,11 @@ inline bool ProtoParseNumeric(StringPiece s, uint64* value) { } inline bool ProtoParseNumeric(StringPiece s, float* value) { - return safe_strtof(std::string(s).c_str(), value); + return safe_strtof(s, value); } inline bool ProtoParseNumeric(StringPiece s, double* value) { - return safe_strtod(std::string(s).c_str(), value); + return safe_strtod(s, value); } // Convert strings to number of type T. 
diff --git a/tensorflow/core/lib/strings/str_util.cc b/tensorflow/core/lib/strings/str_util.cc index cab8f81585..3aba5ec80e 100644 --- a/tensorflow/core/lib/strings/str_util.cc +++ b/tensorflow/core/lib/strings/str_util.cc @@ -332,7 +332,7 @@ string StringReplace(StringPiece s, StringPiece oldsub, StringPiece newsub, bool replace_all) { // TODO(jlebar): We could avoid having to shift data around in the string if // we had a StringPiece::find() overload that searched for a StringPiece. - string res = std::string(s); + string res(s); size_t pos = 0; while ((pos = res.find(oldsub.data(), pos, oldsub.size())) != string::npos) { res.replace(pos, oldsub.size(), newsub.data(), newsub.size()); @@ -448,8 +448,7 @@ bool SplitAndParseAsFloats(StringPiece text, char delim, std::vector<float>* result) { return SplitAndParseAsInts<float>(text, delim, [](StringPiece str, float* value) { - return strings::safe_strtof( - std::string(str).c_str(), value); + return strings::safe_strtof(str, value); }, result); } diff --git a/tensorflow/core/lib/strings/str_util.h b/tensorflow/core/lib/strings/str_util.h index 58e87fcb9e..9f52cf29fc 100644 --- a/tensorflow/core/lib/strings/str_util.h +++ b/tensorflow/core/lib/strings/str_util.h @@ -205,7 +205,7 @@ std::vector<string> Split(StringPiece text, StringPiece delims, Predicate p) { if ((i == text.size()) || (delims.find(text[i]) != StringPiece::npos)) { StringPiece token(text.data() + token_start, i - token_start); if (p(token)) { - result.push_back(std::string(token)); + result.emplace_back(token); } token_start = i + 1; } diff --git a/tensorflow/core/platform/env.cc b/tensorflow/core/platform/env.cc index 47c59d435b..afc4201e53 100644 --- a/tensorflow/core/platform/env.cc +++ b/tensorflow/core/platform/env.cc @@ -92,7 +92,7 @@ Env::Env() : file_system_registry_(new FileSystemRegistryImpl) {} Status Env::GetFileSystemForFile(const string& fname, FileSystem** result) { StringPiece scheme, host, path; io::ParseURI(fname, &scheme, &host, 
&path); - FileSystem* file_system = file_system_registry_->Lookup(std::string(scheme)); + FileSystem* file_system = file_system_registry_->Lookup(string(scheme)); if (!file_system) { if (scheme.empty()) { scheme = "[local]"; @@ -166,7 +166,7 @@ bool Env::FilesExist(const std::vector<string>& files, for (const auto& file : files) { StringPiece scheme, host, path; io::ParseURI(file, &scheme, &host, &path); - files_per_fs[std::string(scheme)].push_back(file); + files_per_fs[string(scheme)].push_back(file); } std::unordered_map<string, Status> per_file_status; diff --git a/tensorflow/core/platform/file_system.cc b/tensorflow/core/platform/file_system.cc index 922773684b..3ab542a5d8 100644 --- a/tensorflow/core/platform/file_system.cc +++ b/tensorflow/core/platform/file_system.cc @@ -158,7 +158,7 @@ Status FileSystem::RecursivelyCreateDir(const string& dirname) { std::reverse(sub_dirs.begin(), sub_dirs.end()); // Now create the directories. - string built_path = std::string(remaining_dir); + string built_path(remaining_dir); for (const StringPiece sub_dir : sub_dirs) { built_path = io::JoinPath(built_path, sub_dir); Status status = CreateDir(io::CreateURI(scheme, host, built_path)); diff --git a/tensorflow/core/platform/file_system_helper.cc b/tensorflow/core/platform/file_system_helper.cc index 0ba0e6304f..342cf28e38 100644 --- a/tensorflow/core/platform/file_system_helper.cc +++ b/tensorflow/core/platform/file_system_helper.cc @@ -59,7 +59,7 @@ Status GetMatchingPaths(FileSystem* fs, Env* env, const string& pattern, string fixed_prefix = pattern.substr(0, pattern.find_first_of("*?[\\")); string eval_pattern = pattern; std::vector<string> all_files; - string dir = std::string(io::Dirname(fixed_prefix)); + string dir(io::Dirname(fixed_prefix)); // If dir is empty then we need to fix up fixed_prefix and eval_pattern to // include . as the top level directory. 
if (dir.empty()) { diff --git a/tensorflow/core/platform/file_system_test.cc b/tensorflow/core/platform/file_system_test.cc index c0a16c95f9..a637d42a92 100644 --- a/tensorflow/core/platform/file_system_test.cc +++ b/tensorflow/core/platform/file_system_test.cc @@ -125,7 +125,7 @@ class InterPlanetaryFileSystem : public NullFileSystem { ASSERT_EQ(scheme, "ipfs"); ASSERT_EQ(host, "solarsystem"); str_util::ConsumePrefix(&path, "/"); - *parsed_path = std::string(path); + *parsed_path = string(path); } std::map<string, std::set<string>> celestial_bodies_ = { diff --git a/tensorflow/core/util/command_line_flags.cc b/tensorflow/core/util/command_line_flags.cc index b281acb2b0..55f1e30880 100644 --- a/tensorflow/core/util/command_line_flags.cc +++ b/tensorflow/core/util/command_line_flags.cc @@ -32,7 +32,7 @@ bool ParseStringFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag, if (str_util::ConsumePrefix(&arg, "--") && str_util::ConsumePrefix(&arg, flag) && str_util::ConsumePrefix(&arg, "=")) { - *value_parsing_ok = hook(std::string(arg)); + *value_parsing_ok = hook(string(arg)); return true; } diff --git a/tensorflow/core/util/ctc/ctc_beam_search.h b/tensorflow/core/util/ctc/ctc_beam_search.h index aee647a1b3..5e2aeb7830 100644 --- a/tensorflow/core/util/ctc/ctc_beam_search.h +++ b/tensorflow/core/util/ctc/ctc_beam_search.h @@ -259,6 +259,16 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step( } else { max_coeff = raw_input.maxCoeff(); } + + // Get normalization term of softmax: log(sum(exp(logit[j]-max_coeff))). + float logsumexp = 0.0; + for (int j = 0; j < raw_input.size(); ++j) { + logsumexp += Eigen::numext::exp(raw_input(j) - max_coeff); + } + logsumexp = Eigen::numext::log(logsumexp); + // Final normalization offset to get correct log probabilities. + float norm_offset = max_coeff + logsumexp; + const float label_selection_input_min = (label_selection_margin_ >= 0) ? 
(max_coeff - label_selection_margin_) : -std::numeric_limits<float>::infinity(); @@ -290,10 +300,10 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step( beam_scorer_->GetStateExpansionScore(b->state, previous)); } // Plabel(l=abc @ t=6) *= P(c @ 6) - b->newp.label += raw_input(b->label) - max_coeff; + b->newp.label += raw_input(b->label) - norm_offset; } // Pblank(l=abc @ t=6) = P(l=abc @ t=5) * P(- @ 6) - b->newp.blank = b->oldp.total + raw_input(blank_index_) - max_coeff; + b->newp.blank = b->oldp.total + raw_input(blank_index_) - norm_offset; // P(l=abc @ t=6) = Plabel(l=abc @ t=6) + Pblank(l=abc @ t=6) b->newp.total = LogSumExp(b->newp.blank, b->newp.label); @@ -328,6 +338,8 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step( const float logit = top_k ? top_k_logits[ind] : raw_input(ind); // Perform label selection: if input for this label looks very // unpromising, never evaluate it with a scorer. + // We may compare logits instead of log probabilities, + // since the difference is the same in both cases. if (logit < label_selection_input_min) { continue; } @@ -341,7 +353,7 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step( // Plabel(l=abcd @ t=6) = P(l=abc @ t=5) * P(d @ 6) beam_scorer_->ExpandState(b->state, b->label, &c.state, c.label); float previous = (c.label == b->label) ? 
b->oldp.blank : b->oldp.total; - c.newp.label = logit - max_coeff + + c.newp.label = logit - norm_offset + beam_scorer_->GetStateExpansionScore(c.state, previous); // P(l=abcd @ t=6) = Plabel(l=abcd @ t=6) c.newp.total = c.newp.label; diff --git a/tensorflow/core/util/env_var.cc b/tensorflow/core/util/env_var.cc index 8d43bcc927..2604a5d66a 100644 --- a/tensorflow/core/util/env_var.cc +++ b/tensorflow/core/util/env_var.cc @@ -28,7 +28,7 @@ namespace tensorflow { Status ReadBoolFromEnvVar(StringPiece env_var_name, bool default_val, bool* value) { *value = default_val; - const char* tf_env_var_val = getenv(std::string(env_var_name).c_str()); + const char* tf_env_var_val = getenv(string(env_var_name).c_str()); if (tf_env_var_val == nullptr) { return Status::OK(); } @@ -48,7 +48,7 @@ Status ReadBoolFromEnvVar(StringPiece env_var_name, bool default_val, Status ReadInt64FromEnvVar(StringPiece env_var_name, int64 default_val, int64* value) { *value = default_val; - const char* tf_env_var_val = getenv(std::string(env_var_name).c_str()); + const char* tf_env_var_val = getenv(string(env_var_name).c_str()); if (tf_env_var_val == nullptr) { return Status::OK(); } @@ -62,11 +62,11 @@ Status ReadInt64FromEnvVar(StringPiece env_var_name, int64 default_val, Status ReadStringFromEnvVar(StringPiece env_var_name, StringPiece default_val, string* value) { - const char* tf_env_var_val = getenv(std::string(env_var_name).c_str()); + const char* tf_env_var_val = getenv(string(env_var_name).c_str()); if (tf_env_var_val != nullptr) { *value = tf_env_var_val; } else { - *value = std::string(default_val); + *value = string(default_val); } return Status::OK(); } diff --git a/tensorflow/core/util/example_proto_fast_parsing.cc b/tensorflow/core/util/example_proto_fast_parsing.cc index 1fec0010a1..a38cd1d09f 100644 --- a/tensorflow/core/util/example_proto_fast_parsing.cc +++ b/tensorflow/core/util/example_proto_fast_parsing.cc @@ -353,7 +353,7 @@ bool TestFastParse(const string& serialized, 
Example* example) { // I.e. last entry in the map overwrites all the previous ones. parsed::FeatureMapEntry& name_and_feature = parsed_example[parsed_example_size - i - 1]; - string name = std::string(name_and_feature.first); + string name(name_and_feature.first); if ((*features.mutable_feature()).count(name) > 0) continue; auto& value = (*features.mutable_feature())[name]; diff --git a/tensorflow/docs_src/guide/premade_estimators.md b/tensorflow/docs_src/guide/premade_estimators.md index a1703058c3..9b64d51b98 100644 --- a/tensorflow/docs_src/guide/premade_estimators.md +++ b/tensorflow/docs_src/guide/premade_estimators.md @@ -366,6 +366,8 @@ Running this code yields the following output (or something similar): Test set accuracy: 0.967 ``` +The `eval_result` dictionary also contains the `average_loss` (mean loss per sample), the `loss` (mean loss per mini-batch) and the value of the estimator's `global_step` (the number of training iterations it underwent). + ### Making predictions (inferring) from the trained model We now have a trained model that produces good evaluation results. diff --git a/tensorflow/docs_src/guide/saved_model.md b/tensorflow/docs_src/guide/saved_model.md index 6c967fd882..33ab891861 100644 --- a/tensorflow/docs_src/guide/saved_model.md +++ b/tensorflow/docs_src/guide/saved_model.md @@ -2,7 +2,7 @@ The `tf.train.Saver` class provides methods to save and restore models. The `tf.saved_model.simple_save` function is an easy way to build a -`tf.saved_model` suitable for serving. [Estimators](./estimators) +`tf.saved_model` suitable for serving. [Estimators](../guide/estimators.md) automatically save and restore variables in the `model_dir`. 
## Save and restore variables diff --git a/tensorflow/js/ops/ts_op_gen.cc b/tensorflow/js/ops/ts_op_gen.cc index babf55cd5f..fb93bb6d8e 100644 --- a/tensorflow/js/ops/ts_op_gen.cc +++ b/tensorflow/js/ops/ts_op_gen.cc @@ -38,6 +38,15 @@ struct ArgDefs { const ApiDef::Arg& api_def_arg; }; +// Struct to hold a combo OpDef::AttrDef and ApiDef::Attr for an Op. +struct OpAttrs { + OpAttrs(const OpDef::AttrDef& op_def_attr, const ApiDef::Attr& api_def_attr) + : op_def_attr(op_def_attr), api_def_attr(api_def_attr) {} + + const OpDef::AttrDef& op_def_attr; + const ApiDef::Attr& api_def_attr; +}; + // Helper class to generate TypeScript code for a given OpDef: class GenTypeScriptOp { public: @@ -49,8 +58,12 @@ class GenTypeScriptOp { private: void ProcessArgs(); + void ProcessAttrs(); + void AddAttrForArg(const string& attr, int arg_index); + string InputForAttr(const OpDef::AttrDef& op_def_attr); void AddMethodSignature(); + void AddOpAttrs(); void AddMethodReturnAndClose(); const OpDef& op_def_; @@ -62,6 +75,13 @@ class GenTypeScriptOp { // Holds in-order vector of Op inputs: std::vector<ArgDefs> input_op_args_; + // Holds in-order vector of Op attributes: + std::vector<OpAttrs> op_attrs_; + + // Stores attributes-to-arguments by name: + typedef std::unordered_map<string, std::vector<int>> AttrArgIdxMap; + AttrArgIdxMap attr_arg_idx_map_; + // Holds number of outputs: int num_outputs_; }; @@ -73,9 +93,11 @@ GenTypeScriptOp::~GenTypeScriptOp() {} string GenTypeScriptOp::Code() { ProcessArgs(); + ProcessAttrs(); // Generate exported function for Op: AddMethodSignature(); + AddOpAttrs(); AddMethodReturnAndClose(); strings::StrAppend(&result_, "\n"); @@ -96,12 +118,52 @@ void GenTypeScriptOp::ProcessArgs() { << api_def_.arg_order(i); continue; } + + // Map attr names to arg indexes: + if (!op_def_arg->type_attr().empty()) { + AddAttrForArg(op_def_arg->type_attr(), i); + } else if (!op_def_arg->type_list_attr().empty()) { + AddAttrForArg(op_def_arg->type_list_attr(), i); + } + 
if (!op_def_arg->number_attr().empty()) { + AddAttrForArg(op_def_arg->number_attr(), i); + } + input_op_args_.push_back(ArgDefs(*op_def_arg, *api_def_arg)); } num_outputs_ = api_def_.out_arg_size(); } +void GenTypeScriptOp::ProcessAttrs() { + for (int i = 0; i < op_def_.attr_size(); i++) { + op_attrs_.push_back(OpAttrs(op_def_.attr(i), api_def_.attr(i))); + } +} + +void GenTypeScriptOp::AddAttrForArg(const string& attr, int arg_index) { + // Keep track of attributes-to-arguments by name. These will be used for + // construction Op attributes that require information about the inputs. + auto iter = attr_arg_idx_map_.find(attr); + if (iter == attr_arg_idx_map_.end()) { + attr_arg_idx_map_.insert(AttrArgIdxMap::value_type(attr, {arg_index})); + } else { + iter->second.push_back(arg_index); + } +} + +string GenTypeScriptOp::InputForAttr(const OpDef::AttrDef& op_def_attr) { + string inputs; + auto arg_list = attr_arg_idx_map_.find(op_def_attr.name()); + if (arg_list != attr_arg_idx_map_.end()) { + for (auto iter = arg_list->second.begin(); iter != arg_list->second.end(); + ++iter) { + strings::StrAppend(&inputs, input_op_args_[*iter].op_def_arg.name()); + } + } + return inputs; +} + void GenTypeScriptOp::AddMethodSignature() { strings::StrAppend(&result_, "export function ", api_def_.endpoint(0).name(), "("); @@ -131,6 +193,35 @@ void GenTypeScriptOp::AddMethodSignature() { } } +void GenTypeScriptOp::AddOpAttrs() { + strings::StrAppend(&result_, " const opAttrs = [\n"); + + bool is_first = true; + for (auto& attr : op_attrs_) { + if (is_first) { + is_first = false; + } else { + strings::StrAppend(&result_, ",\n"); + } + + // Append 4 spaces to start: + strings::StrAppend(&result_, " "); + + if (attr.op_def_attr.type() == "type") { + // Type OpAttributes can be generated from a helper function: + strings::StrAppend(&result_, "createTensorsTypeOpAttr('", + attr.op_def_attr.name(), "', ", + InputForAttr(attr.op_def_attr), ")"); + } else if (attr.op_def_attr.type() == 
"int") { + strings::StrAppend(&result_, "{name: '", attr.op_def_attr.name(), "', "); + strings::StrAppend(&result_, "type: nodeBackend().binding.TF_ATTR_INT, "); + strings::StrAppend(&result_, "value: ", InputForAttr(attr.op_def_attr), + ".length}"); + } + } + strings::StrAppend(&result_, "\n ];\n"); +} + void GenTypeScriptOp::AddMethodReturnAndClose() { strings::StrAppend(&result_, " return null;\n}\n"); } @@ -162,7 +253,7 @@ void StartFile(WritableFile* ts_file) { // This file is MACHINE GENERATED! Do not edit import * as tfc from '@tensorflow/tfjs-core'; -import {createTypeOpAttr, getTFDTypeForInputs, nodeBackend} from './op_utils'; +import {createTensorsTypeOpAttr, nodeBackend} from './op_utils'; )header"; diff --git a/tensorflow/js/ops/ts_op_gen_test.cc b/tensorflow/js/ops/ts_op_gen_test.cc index 9a85c021b0..03241689b5 100644 --- a/tensorflow/js/ops/ts_op_gen_test.cc +++ b/tensorflow/js/ops/ts_op_gen_test.cc @@ -36,7 +36,6 @@ void ExpectDoesNotContainStr(StringPiece s, StringPiece expected) { << "'" << s << "' does not contain '" << expected << "'"; } -// TODO(kreeger): Add multiple outputs here? constexpr char kBaseOpDef[] = R"( op { name: "Foo" @@ -79,50 +78,15 @@ op { summary: "Summary for op Foo." description: "Description for op Foo." } -op { - name: "DeprecatedFoo" - input_arg { - name: "input" - description: "Description for input." - type: DT_FLOAT - } - output_arg { - name: "output" - description: "Description for output." - type: DT_FLOAT - } - deprecation { - explanation: "Deprecated." - } -} -op { - name: "MultiOutputFoo" - input_arg { - name: "input" - description: "Description for input." - type: DT_FLOAT - } - output_arg { - name: "output1" - description: "Description for output 1." - type: DT_FLOAT - } - output_arg { - name: "output2" - description: "Description for output 2." - type: DT_FLOAT - } - summary: "Summary for op MultiOutputFoo." - description: "Description for op MultiOutputFoo." 
-} )"; // Generate TypeScript code -// @param api_def_str TODO doc me. -void GenerateTsOpFileText(const string& api_def_str, string* ts_file_text) { +void GenerateTsOpFileText(const string& op_def_str, const string& api_def_str, + string* ts_file_text) { Env* env = Env::Default(); OpList op_defs; - protobuf::TextFormat::ParseFromString(kBaseOpDef, &op_defs); + protobuf::TextFormat::ParseFromString( + op_def_str.empty() ? kBaseOpDef : op_def_str, &op_defs); ApiDefMap api_def_map(op_defs); if (!api_def_str.empty()) { @@ -138,11 +102,11 @@ void GenerateTsOpFileText(const string& api_def_str, string* ts_file_text) { TEST(TsOpGenTest, TestImports) { string ts_file_text; - GenerateTsOpFileText("", &ts_file_text); + GenerateTsOpFileText("", "", &ts_file_text); const string expected = R"( import * as tfc from '@tensorflow/tfjs-core'; -import {createTypeOpAttr, getTFDTypeForInputs, nodeBackend} from './op_utils'; +import {createTensorsTypeOpAttr, nodeBackend} from './op_utils'; )"; ExpectContainsStr(ts_file_text, expected); } @@ -160,12 +124,10 @@ op { )"; string ts_file_text; - GenerateTsOpFileText(api_def, &ts_file_text); + GenerateTsOpFileText("", api_def, &ts_file_text); const string expected = R"( export function Foo(images: tfc.Tensor[], dim: tfc.Tensor): tfc.Tensor { - return null; -} )"; ExpectContainsStr(ts_file_text, expected); } @@ -179,34 +141,106 @@ op { )"; string ts_file_text; - GenerateTsOpFileText(api_def, &ts_file_text); + GenerateTsOpFileText("", api_def, &ts_file_text); const string expected = R"( export function Foo(images: tfc.Tensor[], dim: tfc.Tensor): tfc.Tensor { - return null; -} )"; ExpectDoesNotContainStr(ts_file_text, expected); } TEST(TsOpGenTest, SkipDeprecated) { + const string op_def = R"( +op { + name: "DeprecatedFoo" + input_arg { + name: "input" + type_attr: "T" + description: "Description for input." + } + output_arg { + name: "output" + description: "Description for output." 
+ type: DT_FLOAT + } + attr { + name: "T" + type: "type" + description: "Type for input" + allowed_values { + list { + type: DT_FLOAT + } + } + } + deprecation { + explanation: "Deprecated." + } +} +)"; + string ts_file_text; - GenerateTsOpFileText("", &ts_file_text); + GenerateTsOpFileText(op_def, "", &ts_file_text); ExpectDoesNotContainStr(ts_file_text, "DeprecatedFoo"); } TEST(TsOpGenTest, MultiOutput) { + const string op_def = R"( +op { + name: "MultiOutputFoo" + input_arg { + name: "input" + description: "Description for input." + type_attr: "T" + } + output_arg { + name: "output1" + description: "Description for output 1." + type: DT_FLOAT + } + output_arg { + name: "output2" + description: "Description for output 2." + type: DT_FLOAT + } + attr { + name: "T" + type: "type" + description: "Type for input" + allowed_values { + list { + type: DT_FLOAT + } + } + } + summary: "Summary for op MultiOutputFoo." + description: "Description for op MultiOutputFoo." +} +)"; + string ts_file_text; - GenerateTsOpFileText("", &ts_file_text); + GenerateTsOpFileText(op_def, "", &ts_file_text); const string expected = R"( export function MultiOutputFoo(input: tfc.Tensor): tfc.Tensor[] { - return null; -} )"; ExpectContainsStr(ts_file_text, expected); } +TEST(TsOpGenTest, OpAttrs) { + string ts_file_text; + GenerateTsOpFileText("", "", &ts_file_text); + + const string expectedFooAttrs = R"( + const opAttrs = [ + createTensorsTypeOpAttr('T', images), + {name: 'N', type: nodeBackend().binding.TF_ATTR_INT, value: images.length} + ]; +)"; + + ExpectContainsStr(ts_file_text, expectedFooAttrs); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index e1d3422730..40f98474b5 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -723,7 +723,6 @@ py_library( srcs_version = "PY2AND3", deps = [ ":array_ops", - ":cond_v2_impl", ":dtypes", ":framework_ops", ":graph_to_function_def", @@ -2620,8 +2619,10 @@ py_test( 
srcs_version = "PY2AND3", deps = [ ":constant_op", + ":dtypes", ":framework_test_lib", ":sparse_ops", + ":sparse_tensor", ], ) @@ -3245,7 +3246,6 @@ py_library( ), srcs_version = "PY2AND3", deps = [ - "saver", ":array_ops", ":array_ops_gen", ":checkpoint_management", @@ -3269,6 +3269,7 @@ py_library( ":random_ops", ":resource_variable_ops", ":resources", + ":saver", ":sdca_ops", ":session", ":sparse_ops", diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index 6642a5bfb1..e0826a7945 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -26,7 +26,7 @@ import datetime from tensorflow.python.util import tf_contextlib from tensorflow.python.util.tf_export import tf_export -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2018, 8, 23) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2018, 8, 24) @tf_export("compat.forward_compatible") diff --git a/tensorflow/python/debug/BUILD b/tensorflow/python/debug/BUILD index 8a4ac6aaef..55d2709845 100644 --- a/tensorflow/python/debug/BUILD +++ b/tensorflow/python/debug/BUILD @@ -576,7 +576,6 @@ py_test( srcs_version = "PY2AND3", tags = [ "no_windows", - "nomac", "oss_serial", ], deps = [ @@ -1047,7 +1046,6 @@ cuda_py_test( tags = [ "no_oss", # Incompatible with bazel_pip. "no_windows", - "nomac", # TODO(cais): Install of futures and grpcio on all macs. 
"notsan", ], ) diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD index bdabbf4ea3..6f48d38b58 100644 --- a/tensorflow/python/eager/BUILD +++ b/tensorflow/python/eager/BUILD @@ -237,6 +237,7 @@ py_library( visibility = ["//tensorflow:internal"], deps = [ ":graph_only_ops", + "//tensorflow/python:cond_v2_impl", "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python:framework_ops", diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index dba9779488..3171ef9d62 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -21,6 +21,7 @@ from __future__ import print_function import collections import functools +import sys import threading import numpy as np @@ -38,6 +39,7 @@ from tensorflow.python.framework import dtypes as dtypes_module from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_spec from tensorflow.python.ops import array_ops +from tensorflow.python.ops import cond_v2_impl from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import functional_ops from tensorflow.python.ops import gradients_impl @@ -49,6 +51,10 @@ from tensorflow.python.util import nest from tensorflow.python.util import tf_decorator from tensorflow.python.util import tf_inspect +# This is to avoid a circular dependency with cond_v2_impl +# (function -> gradients_impl -> control_flow_ops -> cond_v2_impl). +cond_v2_impl._function = sys.modules[__name__] # pylint: disable=protected-access + def create_substitute_placeholder(value, name, dtype=None): """Creates a placeholder for `value` and propagates shape info to it.""" @@ -113,10 +119,6 @@ class CapturingGraph(ops.Graph): # for resource tensors. self._last_op_using_resource_tensor = {} - # TODO(apassos) remove once the C API is used by default. 
- def _use_c_api_hack(self): - return True - def clear_resource_control_flow_state(self): self._last_op_using_resource_tensor = {} @@ -203,6 +205,8 @@ class FuncGraph(CapturingGraph): by this function. The Tensors in this structure are the same as those of self.outputs. Note that this structure might contain Python `None`s. variables: Variables that should be watched during function execution. + outer_graph: The graph this function is defined in. May be another FuncGraph + or the global default Graph. seed: The graph-level random seed. """ @@ -222,8 +226,9 @@ class FuncGraph(CapturingGraph): self.outputs = [] self.structured_outputs = None self.variables = [] + self.outer_graph = ops.get_default_graph() - graph = ops.get_default_graph() + graph = self.outer_graph if context.executing_eagerly(): self.seed = context.global_seed() @@ -259,6 +264,16 @@ class FuncGraph(CapturingGraph): return internal_tensor + @property + def external_captures(self): + """External tensors captured by this function.""" + return list(self.captures.keys()) + + @property + def internal_captures(self): + """Placeholders in this function corresponding captured tensors.""" + return list(self.captures.values()) + def _forward_name(n): """The name of a generated forward defun named n.""" @@ -695,7 +710,7 @@ def _get_defun_inputs_from_args(args): return nest.pack_sequence_as(args, function_inputs) -def _func_graph_from_py_func(name, python_func, args, kwds, signature=None): +def func_graph_from_py_func(name, python_func, args, kwds, signature=None): """Returns a `FuncGraph` generated from `python_func`. 
Args: @@ -1069,8 +1084,8 @@ class _PolymorphicFunction(object): if graph_function is None: graph_function = GraphCallable( - _func_graph_from_py_func(self._name, self._python_function, args, - kwds, self._input_signature)) + func_graph_from_py_func(self._name, self._python_function, args, + kwds, self._input_signature)) self._variables.extend( [v for v in graph_function.variables if v not in self._variables]) self._arguments_to_functions[cache_key] = graph_function @@ -1469,8 +1484,7 @@ def make_defun_op(func, *args, **kwds): and which can be called directly the way a `@defun` wrapped function can. """ - return GraphCallable( - _func_graph_from_py_func(func.__name__, func, args, kwds)) + return GraphCallable(func_graph_from_py_func(func.__name__, func, args, kwds)) class AutomaticControlDependencies(object): diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index 3e9bb91d54..4f23b3c4da 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -358,6 +358,47 @@ class FunctionTest(test.TestCase): self.assertEqual(3.0, float(test_assign_add())) + @test_util.run_in_graph_and_eager_modes + def testTensorInitializationInFunctionRaisesError(self): + error_msg = ('Tensor-typed variable initializers must either be ' + 'wrapped in an init_scope or callable.*') + + @function.defun + def tensor_init(): + with self.assertRaisesRegexp(ValueError, error_msg): + resource_variable_ops.ResourceVariable(constant_op.constant(2.0)) + + tensor_init() + + @test_util.run_in_graph_and_eager_modes + def testCallableTensorInitializationInFunction(self): + + @function.defun + def tensor_init(): + v = resource_variable_ops.ResourceVariable( + lambda: constant_op.constant(2.0)) + return v.read_value() + + value = tensor_init() + if not context.executing_eagerly(): + self.evaluate(variables.global_variables_initializer()) + self.assertEqual(self.evaluate(value), 2.0) + + 
@test_util.run_in_graph_and_eager_modes + def testInitScopeTensorInitializationInFunction(self): + + @function.defun + def tensor_init(): + with ops.init_scope(): + const = constant_op.constant(2.0) + v = resource_variable_ops.ResourceVariable(const) + return v.read_value() + + value = tensor_init() + if not context.executing_eagerly(): + self.evaluate(variables.global_variables_initializer()) + self.assertEqual(self.evaluate(value), 2.0) + def testDefunShapeInferenceWithCapturedResourceVariable(self): v = resource_variable_ops.ResourceVariable([[1, 2], [3, 4]]) diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py index f7ee42c7f6..bcbd7b7933 100644 --- a/tensorflow/python/estimator/estimator.py +++ b/tensorflow/python/estimator/estimator.py @@ -431,7 +431,11 @@ class Estimator(object): Returns: A dict containing the evaluation metrics specified in `model_fn` keyed by name, as well as an entry `global_step` which contains the value of the - global step for which this evaluation was performed. + global step for which this evaluation was performed. For canned + estimators, the dict contains the `loss` (mean loss per mini-batch) and + the `average_loss` (mean loss per sample). Canned classifiers also return + the `accuracy`. Canned regressors also return the `label/mean` and the + `prediction/mean`. Raises: ValueError: If `steps <= 0`. diff --git a/tensorflow/python/estimator/export/export.py b/tensorflow/python/estimator/export/export.py index 7723fcce74..55aace5fa9 100644 --- a/tensorflow/python/estimator/export/export.py +++ b/tensorflow/python/estimator/export/export.py @@ -311,13 +311,33 @@ def build_parsing_serving_input_receiver_fn(feature_spec, def _placeholder_from_tensor(t, default_batch_size=None): + """Creates a placeholder that matches the dtype and shape of passed tensor. + + Args: + t: Tensor or EagerTensor + default_batch_size: the number of query examples expected per batch. 
+ Leave unset for variable batch size (recommended). + + Returns: + Placeholder that matches the passed tensor. + """ batch_shape = tensor_shape.TensorShape([default_batch_size]) shape = batch_shape.concatenate(t.get_shape()[1:]) # Reuse the feature tensor's op name (t.op.name) for the placeholder, # excluding the index from the tensor's name (t.name): # t.name = "%s:%d" % (t.op.name, t._value_index) - return array_ops.placeholder(dtype=t.dtype, shape=shape, name=t.op.name) + try: + name = t.op.name + except AttributeError: + # In Eager mode, tensors don't have ops or names, and while they do have + # IDs, those are not maintained across runs. The name here is used + # primarily for debugging, and is not critical to the placeholder. + # So, in order to make this Eager-compatible, continue with an empty + # name if none is available. + name = None + + return array_ops.placeholder(dtype=t.dtype, shape=shape, name=name) def _placeholders_from_receiver_tensors_dict(input_vals, diff --git a/tensorflow/python/estimator/export/export_test.py b/tensorflow/python/estimator/export/export_test.py index e87b88327f..3eed1ab163 100644 --- a/tensorflow/python/estimator/export/export_test.py +++ b/tensorflow/python/estimator/export/export_test.py @@ -416,6 +416,7 @@ class ExportTest(test_util.TensorFlowTestCase): tensor_shape.unknown_shape(), v.receiver_tensors["feature_2"].shape) + @test_util.run_in_graph_and_eager_modes def test_build_raw_serving_input_receiver_fn(self): features = {"feature_1": constant_op.constant(["hello"]), "feature_2": constant_op.constant([42])} @@ -434,6 +435,7 @@ class ExportTest(test_util.TensorFlowTestCase): dtypes.int32, serving_input_receiver.receiver_tensors["feature_2"].dtype) + @test_util.run_in_graph_and_eager_modes def test_build_raw_supervised_input_receiver_fn(self): features = {"feature_1": constant_op.constant(["hello"]), "feature_2": constant_op.constant([42])} @@ -454,6 +456,7 @@ class ExportTest(test_util.TensorFlowTestCase): 
self.assertEqual( dtypes.int32, input_receiver.receiver_tensors["feature_2"].dtype) + @test_util.run_in_graph_and_eager_modes def test_build_raw_supervised_input_receiver_fn_raw_tensors(self): features = {"feature_1": constant_op.constant(["hello"]), "feature_2": constant_op.constant([42])} @@ -477,6 +480,7 @@ class ExportTest(test_util.TensorFlowTestCase): self.assertEqual(set(["input", "label"]), set(input_receiver.receiver_tensors.keys())) + @test_util.run_in_graph_and_eager_modes def test_build_raw_supervised_input_receiver_fn_batch_size(self): features = {"feature_1": constant_op.constant(["hello"]), "feature_2": constant_op.constant([42])} @@ -489,6 +493,7 @@ class ExportTest(test_util.TensorFlowTestCase): self.assertEqual([10], input_receiver.receiver_tensors["feature_1"].shape) self.assertEqual([10], input_receiver.features["feature_1"].shape) + @test_util.run_in_graph_and_eager_modes def test_build_raw_supervised_input_receiver_fn_overlapping_keys(self): features = {"feature_1": constant_op.constant(["hello"]), "feature_2": constant_op.constant([42])} @@ -497,6 +502,7 @@ class ExportTest(test_util.TensorFlowTestCase): with self.assertRaises(ValueError): export.build_raw_supervised_input_receiver_fn(features, labels) + @test_util.run_in_graph_and_eager_modes def test_build_supervised_input_receiver_fn_from_input_fn(self): def dummy_input_fn(): return ({"x": constant_op.constant([[1], [1]]), @@ -514,6 +520,7 @@ class ExportTest(test_util.TensorFlowTestCase): self.assertEqual(set(["x", "y", "label"]), set(input_receiver.receiver_tensors.keys())) + @test_util.run_in_graph_and_eager_modes def test_build_supervised_input_receiver_fn_from_input_fn_args(self): def dummy_input_fn(feature_key="x"): return ({feature_key: constant_op.constant([[1], [1]]), diff --git a/tensorflow/python/estimator/run_config.py b/tensorflow/python/estimator/run_config.py index 220c3e58ca..12daddb044 100644 --- a/tensorflow/python/estimator/run_config.py +++ 
b/tensorflow/python/estimator/run_config.py @@ -51,6 +51,7 @@ _DEFAULT_REPLACEABLE_LIST = [ 'device_fn', 'protocol', 'eval_distribute', + 'experimental_distribute', ] _SAVE_CKPT_ERR = ( @@ -331,7 +332,8 @@ class RunConfig(object): train_distribute=None, device_fn=None, protocol=None, - eval_distribute=None): + eval_distribute=None, + experimental_distribute=None): """Constructs a RunConfig. All distributed training related properties `cluster_spec`, `is_chief`, @@ -469,6 +471,9 @@ class RunConfig(object): `tf.contrib.distribute.DistributionStrategy`. If specified, then Estimator will distribute the user's model during evaluation, according to the policy specified by that strategy. + experimental_distribute: an optional + `tf.contrib.distribute.DistributeConfig` object specifying + DistributionStrategy-related configuration. Raises: ValueError: If both `save_checkpoints_steps` and `save_checkpoints_secs` @@ -508,7 +513,8 @@ class RunConfig(object): train_distribute=train_distribute, device_fn=device_fn, protocol=protocol, - eval_distribute=eval_distribute) + eval_distribute=eval_distribute, + experimental_distribute=experimental_distribute) self._init_distributed_setting_from_environment_var(tf_config) @@ -810,6 +816,7 @@ class RunConfig(object): - `device_fn`, - `protocol`. - `eval_distribute`, + - `experimental_distribute`, In addition, either `save_checkpoints_steps` or `save_checkpoints_secs` can be set (should not be both). 
diff --git a/tensorflow/python/feature_column/feature_column_test.py b/tensorflow/python/feature_column/feature_column_test.py index 9d2babc6e0..9b482237ab 100644 --- a/tensorflow/python/feature_column/feature_column_test.py +++ b/tensorflow/python/feature_column/feature_column_test.py @@ -2747,6 +2747,62 @@ class FunctionalInputLayerTest(test.TestCase): variables_lib.Variable) self.assertAllEqual(cols_to_vars[some_embedding_column][0].shape, [5, 10]) + def test_fills_cols_to_vars_shared_embedding(self): + # Provide 5 DenseColumn's to input_layer: a NumericColumn, a + # BucketizedColumn, an EmbeddingColumn, two SharedEmbeddingColumns. The + # EmbeddingColumn creates a Variable and the two SharedEmbeddingColumns + # shared one variable. + price1 = fc.numeric_column('price1') + dense_feature = fc.numeric_column('dense_feature') + dense_feature_bucketized = fc.bucketized_column( + dense_feature, boundaries=[0.]) + some_sparse_column = fc.categorical_column_with_hash_bucket( + 'sparse_feature', hash_bucket_size=5) + some_embedding_column = fc.embedding_column( + some_sparse_column, dimension=10) + categorical_column_a = fc.categorical_column_with_identity( + key='aaa', num_buckets=3) + categorical_column_b = fc.categorical_column_with_identity( + key='bbb', num_buckets=3) + shared_embedding_a, shared_embedding_b = fc.shared_embedding_columns( + [categorical_column_a, categorical_column_b], dimension=2) + with ops.Graph().as_default(): + features = { + 'price1': [[3.], [4.]], + 'dense_feature': [[-1.], [4.]], + 'sparse_feature': [['a'], ['x']], + 'aaa': + sparse_tensor.SparseTensor( + indices=((0, 0), (1, 0), (1, 1)), + values=(0, 1, 0), + dense_shape=(2, 2)), + 'bbb': + sparse_tensor.SparseTensor( + indices=((0, 0), (1, 0), (1, 1)), + values=(1, 2, 1), + dense_shape=(2, 2)), + } + cols_to_vars = {} + all_cols = [ + price1, dense_feature_bucketized, some_embedding_column, + shared_embedding_a, shared_embedding_b + ] + fc.input_layer(features, all_cols, 
cols_to_vars=cols_to_vars) + self.assertItemsEqual(list(cols_to_vars.keys()), all_cols) + self.assertEqual(0, len(cols_to_vars[price1])) + self.assertEqual(0, len(cols_to_vars[dense_feature_bucketized])) + self.assertEqual(1, len(cols_to_vars[some_embedding_column])) + self.assertEqual(1, len(cols_to_vars[shared_embedding_a])) + # This is a bug in the current implementation and should be fixed in the + # new one. + self.assertEqual(0, len(cols_to_vars[shared_embedding_b])) + self.assertIsInstance(cols_to_vars[some_embedding_column][0], + variables_lib.Variable) + self.assertAllEqual(cols_to_vars[some_embedding_column][0].shape, [5, 10]) + self.assertIsInstance(cols_to_vars[shared_embedding_a][0], + variables_lib.Variable) + self.assertAllEqual(cols_to_vars[shared_embedding_a][0].shape, [3, 2]) + def test_fills_cols_to_vars_partitioned_variables(self): price1 = fc.numeric_column('price1') dense_feature = fc.numeric_column('dense_feature') @@ -2772,6 +2828,10 @@ class FunctionalInputLayerTest(test.TestCase): self.assertEqual(0, len(cols_to_vars[price1])) self.assertEqual(0, len(cols_to_vars[dense_feature_bucketized])) self.assertEqual(3, len(cols_to_vars[some_embedding_column])) + self.assertEqual( + 'input_from_feature_columns/input_layer/sparse_feature_embedding/' + 'embedding_weights/part_0:0', + cols_to_vars[some_embedding_column][0].name) self.assertAllEqual(cols_to_vars[some_embedding_column][0].shape, [2, 10]) self.assertAllEqual(cols_to_vars[some_embedding_column][1].shape, [2, 10]) self.assertAllEqual(cols_to_vars[some_embedding_column][2].shape, [1, 10]) @@ -5544,20 +5604,6 @@ class SharedEmbeddingColumnTest(test.TestCase): self.assertIsNone(partition_info) return embedding_values - # Expected lookup result, using combiner='mean'. 
- expected_lookups_a = ( - # example 0: - (7., 11.), # ids [2], embedding = [7, 11] - # example 1: - (2., 3.5), # ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5] - ) - expected_lookups_b = ( - # example 0: - (1., 2.), # ids [0], embedding = [1, 2] - # example 1: - (0., 0.), # ids [], embedding = [0, 0] - ) - # Build columns. categorical_column_a = fc.categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) diff --git a/tensorflow/python/feature_column/feature_column_v2.py b/tensorflow/python/feature_column/feature_column_v2.py index b6bf516286..aa66ed77e9 100644 --- a/tensorflow/python/feature_column/feature_column_v2.py +++ b/tensorflow/python/feature_column/feature_column_v2.py @@ -142,6 +142,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib from tensorflow.python.framework import tensor_shape from tensorflow.python.keras.engine import training +from tensorflow.python.keras.engine.base_layer import Layer from tensorflow.python.layers import base from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops @@ -155,7 +156,6 @@ from tensorflow.python.ops import parsing_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import sparse_ops from tensorflow.python.ops import string_ops -from tensorflow.python.ops import template from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import gfile @@ -164,67 +164,148 @@ from tensorflow.python.training import checkpoint_utils from tensorflow.python.util import nest -def _internal_input_layer(features, - feature_columns, - weight_collections=None, - trainable=True, - cols_to_vars=None, - scope=None): - """See input_layer. `scope` is a name or variable scope to use.""" +class StateManager(object): + """Manages the state associated with FeatureColumns. 
- feature_columns = fc_old._normalize_feature_columns(feature_columns) # pylint: disable=protected-access - for column in feature_columns: - if not isinstance(column, fc_old._DenseColumn): # pylint: disable=protected-access - raise ValueError( - 'Items of feature_columns must be a _DenseColumn. ' - 'You can wrap a categorical column with an ' - 'embedding_column or indicator_column. Given: {}'.format(column)) - weight_collections = list(weight_collections or []) - if ops.GraphKeys.GLOBAL_VARIABLES not in weight_collections: - weight_collections.append(ops.GraphKeys.GLOBAL_VARIABLES) - if ops.GraphKeys.MODEL_VARIABLES not in weight_collections: - weight_collections.append(ops.GraphKeys.MODEL_VARIABLES) - - # a non-None `scope` can allow for variable reuse, when, e.g., this function - # is wrapped by a `make_template`. - with variable_scope.variable_scope( - scope, default_name='input_layer', values=features.values()): - builder = fc_old._LazyBuilder(features) # pylint: disable=protected-access - output_tensors = [] - ordered_columns = [] - for column in sorted(feature_columns, key=lambda x: x.name): - ordered_columns.append(column) - with variable_scope.variable_scope( - None, default_name=column._var_scope_name): # pylint: disable=protected-access - tensor = column._get_dense_tensor( # pylint: disable=protected-access - builder, - weight_collections=weight_collections, - trainable=trainable) - num_elements = column._variable_shape.num_elements() # pylint: disable=protected-access - batch_size = array_ops.shape(tensor)[0] - output_tensors.append( - array_ops.reshape(tensor, shape=(batch_size, num_elements))) - if cols_to_vars is not None: - # Retrieve any variables created (some _DenseColumn's don't create - # variables, in which case an empty list is returned). 
- cols_to_vars[column] = ops.get_collection( - ops.GraphKeys.GLOBAL_VARIABLES, - scope=variable_scope.get_variable_scope().name) - _verify_static_batch_size_equality(output_tensors, ordered_columns) - return array_ops.concat(output_tensors, 1) + Some `FeatureColumn`s create variables or resources to assist their + computation. The `StateManager` is responsible for creating and storing these + objects since `FeatureColumn`s are supposed to be stateless configuration + only. + """ + + def create_variable(self, + feature_column, + name, + shape, + dtype=None, + trainable=True, + initializer=None): + """Creates a new variable. + + Args: + feature_column: A `FeatureColumn` object this variable corresponds to. + name: variable name. + shape: variable shape. + dtype: The type of the variable. Defaults to `self.dtype` or `float32`. + trainable: Whether this variable is trainable or not. + initializer: initializer instance (callable). + + Returns: + The created variable. + """ + del feature_column, name, shape, dtype, trainable, initializer + raise NotImplementedError('StateManager.create_variable') + + def add_variable(self, feature_column, var): + """Adds an existing variable to the state. + + Args: + feature_column: A `FeatureColumn` object to associate this variable with. + var: The variable. + """ + del feature_column, var + raise NotImplementedError('StateManager.add_variable') + + def get_variable(self, feature_column, name): + """Returns an existing variable. + + Args: + feature_column: A `FeatureColumn` object this variable corresponds to. + name: variable name. + """ + del feature_column, name + raise NotImplementedError('StateManager.get_var') + + def add_resource(self, feature_column, name, resource): + """Creates a new resource. + + Resources can be things such as tables etc. + + Args: + feature_column: A `FeatureColumn` object this resource corresponds to. + name: Name of the resource. + resource: The resource. + + Returns: + The created resource. 
+ """ + del feature_column, name, resource + raise NotImplementedError('StateManager.add_resource') + def get_resource(self, feature_column, name): + """Returns an already created resource. -def input_layer(features, - feature_columns, - weight_collections=None, - trainable=True, - cols_to_vars=None): - """Returns a dense `Tensor` as input layer based on given `feature_columns`. + Resources can be things such as tables etc. + + Args: + feature_column: A `FeatureColumn` object this variable corresponds to. + name: Name of the resource. + """ + del feature_column, name + raise NotImplementedError('StateManager.get_resource') + + +class _InputLayerStateManager(StateManager): + """Manages the state of InputLayer.""" + + def __init__(self, layer, feature_columns, trainable): + """Creates an _InputLayerStateManager object. + + Args: + layer: The input layer this state manager is associated with. + feature_columns: List of feature columns for the input layer + trainable: Whether by default, variables created are trainable or not. 
+ """ + self._trainable = trainable + self._layer = layer + self._cols_to_vars_map = {} + self._cols_to_names_map = {} + for column in sorted(feature_columns, key=lambda x: x.name): + self._cols_to_vars_map[column] = {} + base_name = column.name + if isinstance(column, SharedEmbeddingColumn): + base_name = column.shared_collection_name + with variable_scope.variable_scope(base_name) as vs: + self._cols_to_names_map[column] = _strip_leading_slashes(vs.name) + + def create_variable(self, + feature_column, + name, + shape, + dtype=None, + trainable=True, + initializer=None): + if name in self._cols_to_vars_map[feature_column]: + raise ValueError('Variable already exists.') + with variable_scope.variable_scope(self._cols_to_names_map[feature_column]): + var = self._layer.add_variable( + name=name, + shape=shape, + dtype=dtype, + initializer=initializer, + trainable=self._trainable and trainable, + # TODO(rohanj): Get rid of this hack once we have a mechanism for + # specifying a default partitioner for an entire layer. In that case, + # the default getter for Layers should work. + getter=variable_scope.get_variable) + self._cols_to_vars_map[feature_column][name] = var + return var + + def get_variable(self, feature_column, name): + if name in self._cols_to_vars_map[feature_column]: + return self._cols_to_vars_map[feature_column][name] + raise ValueError('Variable does not exist.') + + +class FeatureLayer(Layer): + """A layer that produces a dense `Tensor` based on given `feature_columns`. Generally a single example in training data is described with FeatureColumns. At the first layer of the model, this column oriented data should be converted to a single `Tensor`. + This layer can be called multiple times with different features. + Example: ```python @@ -233,105 +314,122 @@ def input_layer(features, categorical_column_with_hash_bucket("keywords", 10K), dimensions=16) columns = [price, keywords_embedded, ...] 
features = tf.parse_example(..., features=make_parse_example_spec(columns)) - dense_tensor = input_layer(features, columns) + feature_layer = FeatureLayer(columns) + dense_tensor = feature_layer(features) for units in [128, 64, 32]: dense_tensor = tf.layers.dense(dense_tensor, units, tf.nn.relu) - prediction = tf.layers.dense(dense_tensor, 1) - ``` - - Args: - features: A mapping from key to tensors. `_FeatureColumn`s look up via these - keys. For example `numeric_column('price')` will look at 'price' key in - this dict. Values can be a `SparseTensor` or a `Tensor` depends on - corresponding `_FeatureColumn`. - feature_columns: An iterable containing the FeatureColumns to use as inputs - to your model. All items should be instances of classes derived from - `_DenseColumn` such as `numeric_column`, `embedding_column`, - `bucketized_column`, `indicator_column`. If you have categorical features, - you can wrap them with an `embedding_column` or `indicator_column`. - weight_collections: A list of collection names to which the Variable will be - added. Note that variables will also be added to collections - `tf.GraphKeys.GLOBAL_VARIABLES` and `ops.GraphKeys.MODEL_VARIABLES`. - trainable: If `True` also add the variable to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - cols_to_vars: If not `None`, must be a dictionary that will be filled with a - mapping from `_FeatureColumn` to list of `Variable`s. For example, after - the call, we might have cols_to_vars = - {_EmbeddingColumn( - categorical_column=_HashedCategoricalColumn( - key='sparse_feature', hash_bucket_size=5, dtype=tf.string), - dimension=10): [<tf.Variable 'some_variable:0' shape=(5, 10), - <tf.Variable 'some_variable:1' shape=(5, 10)]} - If a column creates no variables, its value will be an empty list. - - Returns: - A `Tensor` which represents input layer of a model. Its shape - is (batch_size, first_layer_dimension) and its dtype is `float32`. 
- first_layer_dimension is determined based on given `feature_columns`. - - Raises: - ValueError: if an item in `feature_columns` is not a `_DenseColumn`. - """ - return _internal_input_layer(features, feature_columns, weight_collections, - trainable, cols_to_vars) - - -# TODO(akshayka): InputLayer should be a subclass of Layer, and it -# should implement the logic in input_layer using Layer's build-and-call -# paradigm; input_layer should create an instance of InputLayer and -# return the result of invoking its apply method, just as functional layers do. -class InputLayer(object): - """An object-oriented version of `input_layer` that reuses variables.""" + prediction = tf.layers.dense(dense_tensor, 1).""" def __init__(self, feature_columns, - weight_collections=None, trainable=True, - cols_to_vars=None): - """See `input_layer`.""" + name=None, + shared_state_manager=None, + **kwargs): + """Constructs a FeatureLayer. - self._feature_columns = feature_columns - self._weight_collections = weight_collections - self._trainable = trainable - self._cols_to_vars = cols_to_vars - self._input_layer_template = template.make_template( - 'feature_column_input_layer', - _internal_input_layer, - create_scope_now_=True) - self._scope = self._input_layer_template.variable_scope - - def __call__(self, features): - return self._input_layer_template( - features=features, - feature_columns=self._feature_columns, - weight_collections=self._weight_collections, - trainable=self._trainable, - cols_to_vars=None, - scope=self._scope) + Args: + feature_columns: An iterable containing the FeatureColumns to use as + inputs to your model. All items should be instances of classes derived + from `DenseColumn` such as `numeric_column`, `embedding_column`, + `bucketized_column`, `indicator_column`. If you have categorical + features, you can wrap them with an `embedding_column` or + `indicator_column`. 
+ trainable: If `True` also add the variable to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + name: Name to give to the FeatureLayer. + shared_state_manager: SharedEmbeddingStateManager that manages the state + of SharedEmbeddingColumns. The state of SharedEmbeddingColumns, unlike + regular embedding columns cannot be owned by the InputLayer itself since + SharedEmbeddingColumns can be shared across different InputLayers. As a + result users are expected to create a SharedEmbeddingStateManager object + which would be responsible for managing the shared state and can be + passed into different InputLayer objects to share state. For example, + + ```python + sc_1, sc_2 = shared_embedding_column_v2(...) + sc_3, sc_4 = shared_embedding_column_v2(...) + ssm = SharedEmbeddingStateManager() + feature_layer1 = FeatureLayer([sc_1, sc_3], ..., + shared_state_manager=ssm) + feature_layer2 = FeatureLayer([sc_2, sc_4], ..., + shared_state_manager=ssm) + ``` + now input_layer1 and input_layer2 will share variables across. If + sharing is not desired, one can create 2 separate + SharedEmbeddingStateManager objects + + ```python + ssm1 = SharedEmbeddingStateManager() + ssm2 = SharedEmbeddingStateManager() + feature_layer1 = FeatureLayer([sc_1, sc_3], ..., + shared_state_manager=ssm1) + feature_layer2 = FeatureLayer([sc_2, sc_4], ..., + shared_state_manager=ssm2) + ``` + **kwargs: Keyword arguments to construct a layer. - @property - def non_trainable_variables(self): - return self._input_layer_template.non_trainable_variables + Raises: + ValueError: if an item in `feature_columns` is not a `DenseColumn`. 
+ """ + super(FeatureLayer, self).__init__(name=name, trainable=trainable, **kwargs) - @property - def non_trainable_weights(self): - return self._input_layer_template.non_trainable_weights + self._feature_columns = _normalize_feature_columns(feature_columns) + self._state_manager = _InputLayerStateManager(self, self._feature_columns, + self.trainable) + self._shared_state_manager = shared_state_manager + for column in sorted(self._feature_columns, key=lambda x: x.name): + if not isinstance(column, DenseColumn): + raise ValueError( + 'Items of feature_columns must be a DenseColumn. ' + 'You can wrap a categorical column with an ' + 'embedding_column or indicator_column. Given: {}'.format(column)) - @property - def trainable_variables(self): - return self._input_layer_template.trainable_variables + def build(self, _): + for column in sorted(self._feature_columns, key=lambda x: x.name): + if isinstance(column, SharedEmbeddingColumn): + column.create_state(self._shared_state_manager) + else: + with variable_scope.variable_scope(None, default_name=self.name): + column.create_state(self._state_manager) + super(FeatureLayer, self).build(None) - @property - def trainable_weights(self): - return self._input_layer_template.trainable_weights + def call(self, features, cols_to_output_tensors=None): + """Returns a dense tensor corresponding to the `feature_columns`. - @property - def variables(self): - return self._input_layer_template.variables + Args: + features: A mapping from key to tensors. `FeatureColumn`s look up via + these keys. For example `numeric_column('price')` will look at 'price' + key in this dict. Values can be a `SparseTensor` or a `Tensor` depends + on corresponding `FeatureColumn`. + cols_to_output_tensors: If not `None`, this will be filled with a dict + mapping feature columns to output tensors created. - @property - def weights(self): - return self._input_layer_template.weights + Returns: + A `Tensor` which represents input layer of a model. 
Its shape + is (batch_size, first_layer_dimension) and its dtype is `float32`. + first_layer_dimension is determined based on given `feature_columns`. + """ + transformation_cache = FeatureTransformationCache(features) + output_tensors = [] + ordered_columns = [] + for column in sorted(self._feature_columns, key=lambda x: x.name): + ordered_columns.append(column) + if isinstance(column, SharedEmbeddingColumn): + tensor = column.get_dense_tensor(transformation_cache, + self._shared_state_manager) + else: + tensor = column.get_dense_tensor(transformation_cache, + self._state_manager) + num_elements = column.variable_shape.num_elements() + batch_size = array_ops.shape(tensor)[0] + tensor = array_ops.reshape(tensor, shape=(batch_size, num_elements)) + output_tensors.append(tensor) + if cols_to_output_tensors is not None: + cols_to_output_tensors[column] = tensor + + _verify_static_batch_size_equality(output_tensors, ordered_columns) + return array_ops.concat(output_tensors, 1) def linear_model(features, @@ -565,12 +663,15 @@ class _BiasLayer(base.Layer): return self._bias_variable -def _get_expanded_variable_list(variable): - if (isinstance(variable, variables.Variable) or - resource_variable_ops.is_resource_variable(variable)): - return [variable] # Single variable case. - else: # Must be a PartitionedVariable, so convert into a list. - return list(variable) +def _get_expanded_variable_list(var_list): + returned_list = [] + for variable in var_list: + if (isinstance(variable, variables.Variable) or + resource_variable_ops.is_resource_variable(variable)): + returned_list.append(variable) # Single variable case. + else: # Must be a PartitionedVariable, so convert into a list. 
+ returned_list.extend(list(variable)) + return returned_list def _strip_leading_slashes(name): @@ -661,7 +762,7 @@ class _LinearModel(training.Model): scope=variable_scope.get_variable_scope()), # pylint: disable=not-callable name='weighted_sum') bias = self._bias_layer.variables[0] - self._cols_to_vars['bias'] = _get_expanded_variable_list(bias) + self._cols_to_vars['bias'] = _get_expanded_variable_list([bias]) return predictions def _add_layers(self, layers): @@ -877,10 +978,15 @@ def embedding_column( trainable=trainable) -def shared_embedding_columns( - categorical_columns, dimension, combiner='mean', initializer=None, - shared_embedding_collection_name=None, ckpt_to_load_from=None, - tensor_name_in_ckpt=None, max_norm=None, trainable=True): +def shared_embedding_columns_v2(categorical_columns, + dimension, + combiner='mean', + initializer=None, + shared_embedding_collection_name=None, + ckpt_to_load_from=None, + tensor_name_in_ckpt=None, + max_norm=None, + trainable=True): """List of dense columns that convert from sparse, categorical input. This is similar to `embedding_column`, except that it produces a list of @@ -1803,51 +1909,6 @@ def crossed_column(keys, hash_bucket_size, hash_key=None): keys=tuple(keys), hash_bucket_size=hash_bucket_size, hash_key=hash_key) -class StateManager(object): - """Manages the state associated with FeatureColumns. - - Some `FeatureColumn`s create variables or resources to assist their - computation. The `StateManager` is responsible for creating and storing these - objects since `FeatureColumn`s are supposed to be stateless configuration - only. - """ - - def get_variable(self, - feature_column, - name, - shape, - dtype=None, - initializer=None): - """Creates a new variable or returns an existing one. - - Args: - feature_column: A `FeatureColumn` object this variable corresponds to. - name: variable name. - shape: variable shape. - dtype: The type of the variable. Defaults to `self.dtype` or `float32`. 
- initializer: initializer instance (callable). - - Returns: - The variable. - """ - raise NotImplementedError('StateManager.get_variable') - - def get_resource(self, feature_column, name, resource_creator): - """Creates a new resource or returns an existing one. - - Resources can be things such as tables etc. - - Args: - feature_column: A `FeatureColumn` object this variable corresponds to. - name: Name of the resource. - resource_creator: A callable that can create the resource. - - Returns: - The resource. - """ - raise NotImplementedError('StateManager.get_resource') - - class FeatureColumn(object): """Represents a feature column abstraction. @@ -2550,6 +2611,17 @@ class EmbeddingColumn( """See `DenseColumn` base class.""" return tensor_shape.vector(self.dimension) + def create_state(self, state_manager): + """Creates the embedding lookup variable.""" + embedding_shape = (self.categorical_column.num_buckets, self.dimension) + state_manager.create_variable( + self, + name='embedding_weights', + shape=embedding_shape, + dtype=dtypes.float32, + trainable=self.trainable, + initializer=self.initializer) + def _get_dense_tensor_internal(self, transformation_cache, state_manager): """Private method that follows the signature of _get_dense_tensor.""" # Get sparse IDs and weights. @@ -2558,13 +2630,8 @@ class EmbeddingColumn( sparse_ids = sparse_tensors.id_tensor sparse_weights = sparse_tensors.weight_tensor - embedding_shape = (self.categorical_column.num_buckets, self.dimension) embedding_weights = state_manager.get_variable( - self, - name='embedding_weights', - shape=embedding_shape, - dtype=dtypes.float32, - initializer=self.initializer) + self, name='embedding_weights') if self.ckpt_to_load_from is not None: to_restore = embedding_weights @@ -2637,6 +2704,68 @@ def _get_graph_for_variable(var): return var.graph +class SharedEmbeddingStateManager(Layer): + """A state manager that handle the state of shared embedding columns. 
+ + This can handle multiple sets of columns that share variables.""" + + def __init__(self, trainable=True, name=None, **kwargs): + """Constructs a `SharedEmbeddingStateManager`. + + Args: + trainable: If true, variables created are trainable. + name: Name of the State Manager. + **kwargs: Keyword arguments. + """ + super(SharedEmbeddingStateManager, self).__init__( + name=name, trainable=trainable, **kwargs) + self._var_dict = {} + + def create_variable(self, + name, + shape, + dtype=None, + trainable=True, + initializer=None): + """Creates a variable. + + Makes sure only one var is created per `shared_collection_name`. `name` is + ignored here as the variable is named `shared_collection_name` instead. + + Args: + name: Name of the variable. Not used. + shape: Variable shape. + dtype: Variable type. + trainable: If variable created should be trainable or not. + initializer: Variable initializer. + + Returns: + A variable or partitioned variable. + """ + if name in self._var_dict: + var = self._var_dict[name] + return var + with variable_scope.variable_scope( + self.name, reuse=variable_scope.AUTO_REUSE): + var = self.add_variable( + name=name, + shape=shape, + dtype=dtype, + trainable=self.trainable and trainable, + initializer=initializer, + # TODO(rohanj): Get rid of this hack once we have a mechanism for + # specifying a default partitioner for an entire layer. In that case, + # the default getter for Layers should work. 
+ getter=variable_scope.get_variable) + self._var_dict[name] = var + return var + + def get_variable(self, feature_column, name): + if name not in self._var_dict: + raise ValueError('Variable name: {} not recognized.'.format(name)) + return self._var_dict[name] + + class SharedEmbeddingColumn( DenseColumn, SequenceDenseColumn, collections.namedtuple( @@ -2675,6 +2804,16 @@ class SharedEmbeddingColumn( """See `DenseColumn` base class.""" return tensor_shape.vector(self.dimension) + def create_state(self, state_manager): + """Creates the shared embedding lookup variable.""" + embedding_shape = (self.categorical_column.num_buckets, self.dimension) + state_manager.create_variable( + name=self.shared_collection_name, + shape=embedding_shape, + dtype=dtypes.float32, + trainable=self.trainable, + initializer=self.initializer) + def _get_dense_tensor_internal(self, transformation_cache, state_manager): """Private method that follows the signature of _get_dense_tensor.""" # This method is called from a variable_scope with name _var_scope_name, @@ -2687,13 +2826,8 @@ class SharedEmbeddingColumn( sparse_ids = sparse_tensors.id_tensor sparse_weights = sparse_tensors.weight_tensor - embedding_shape = (self.categorical_column.num_buckets, self.dimension) embedding_weights = state_manager.get_variable( - self, - name='embedding_weights', - shape=embedding_shape, - dtype=dtypes.float32, - initializer=self.initializer) + self, name=self.shared_collection_name) if self.ckpt_to_load_from is not None: to_restore = embedding_weights diff --git a/tensorflow/python/feature_column/feature_column_v2_test.py b/tensorflow/python/feature_column/feature_column_v2_test.py index ad578d287a..6b343ecf3e 100644 --- a/tensorflow/python/feature_column/feature_column_v2_test.py +++ b/tensorflow/python/feature_column/feature_column_v2_test.py @@ -33,12 +33,12 @@ from tensorflow.python.eager import context from tensorflow.python.estimator.inputs import numpy_io from tensorflow.python.feature_column 
import feature_column as fc_old from tensorflow.python.feature_column import feature_column_v2 as fc +from tensorflow.python.feature_column.feature_column_v2 import _LinearModel +from tensorflow.python.feature_column.feature_column_v2 import _transform_features from tensorflow.python.feature_column.feature_column_v2 import FeatureColumn +from tensorflow.python.feature_column.feature_column_v2 import FeatureLayer from tensorflow.python.feature_column.feature_column_v2 import FeatureTransformationCache -from tensorflow.python.feature_column.feature_column_v2 import InputLayer from tensorflow.python.feature_column.feature_column_v2 import StateManager -from tensorflow.python.feature_column.feature_column_v2 import _LinearModel -from tensorflow.python.feature_column.feature_column_v2 import _transform_features from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -824,22 +824,6 @@ class HashedCategoricalColumnTest(test.TestCase): self.assertEqual( transformation_cache.get(hashed_sparse, None), id_weight_pair.id_tensor) - def DISABLED_test_get_sparse_tensors_weight_collections(self): - column = fc.categorical_column_with_hash_bucket('aaa', 10) - inputs = sparse_tensor.SparseTensor( - values=['omar', 'stringer', 'marlo'], - indices=[[0, 0], [1, 0], [1, 1]], - dense_shape=[2, 2]) - column._get_sparse_tensors( - FeatureTransformationCache({ - 'aaa': inputs - }), - weight_collections=('my_weights',)) - - self.assertItemsEqual( - [], ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)) - self.assertItemsEqual([], ops.get_collection('my_weights')) - def test_get_sparse_tensors_dense_input(self): hashed_sparse = fc.categorical_column_with_hash_bucket('wire', 10) transformation_cache = FeatureTransformationCache({ @@ -2640,13 +2624,13 @@ class _LinearModelTest(test.TestCase): sess.run(net, feed_dict={features['price']: np.array(1)}) -class InputLayerTest(test.TestCase): +class 
FeatureLayerTest(test.TestCase): @test_util.run_in_graph_and_eager_modes() def test_retrieving_input(self): features = {'a': [0.]} - input_layer = InputLayer(fc_old.numeric_column('a')) - inputs = self.evaluate(input_layer(features)) + feature_layer = FeatureLayer(fc.numeric_column('a')) + inputs = self.evaluate(feature_layer(features)) self.assertAllClose([[0.]], inputs) def test_reuses_variables(self): @@ -2657,7 +2641,7 @@ class InputLayerTest(test.TestCase): dense_shape=(3, 3)) # Create feature columns (categorical and embedding). - categorical_column = fc_old.categorical_column_with_identity( + categorical_column = fc.categorical_column_with_identity( key='a', num_buckets=3) embedding_dimension = 2 def _embedding_column_initializer(shape, dtype, partition_info): @@ -2670,16 +2654,16 @@ class InputLayerTest(test.TestCase): (1, 1)) # id 2 return embedding_values - embedding_column = fc_old.embedding_column( + embedding_column = fc.embedding_column( categorical_column, dimension=embedding_dimension, initializer=_embedding_column_initializer) - input_layer = InputLayer([embedding_column]) + feature_layer = FeatureLayer([embedding_column]) features = {'a': sparse_input} - inputs = input_layer(features) - variables = input_layer.variables + inputs = feature_layer(features) + variables = feature_layer.variables # Sanity check: test that the inputs are correct. self.assertAllEqual([[1, 0], [0, 1], [1, 1]], inputs) @@ -2687,13 +2671,13 @@ class InputLayerTest(test.TestCase): # Check that only one variable was created. 
self.assertEqual(1, len(variables)) - # Check that invoking input_layer on the same features does not create + # Check that invoking feature_layer on the same features does not create # additional variables - _ = input_layer(features) + _ = feature_layer(features) self.assertEqual(1, len(variables)) - self.assertEqual(variables[0], input_layer.variables[0]) + self.assertEqual(variables[0], feature_layer.variables[0]) - def test_feature_column_input_layer_gradient(self): + def test_feature_column_feature_layer_gradient(self): with context.eager_mode(): sparse_input = sparse_tensor.SparseTensor( indices=((0, 0), (1, 0), (2, 0)), @@ -2701,7 +2685,7 @@ class InputLayerTest(test.TestCase): dense_shape=(3, 3)) # Create feature columns (categorical and embedding). - categorical_column = fc_old.categorical_column_with_identity( + categorical_column = fc.categorical_column_with_identity( key='a', num_buckets=3) embedding_dimension = 2 @@ -2715,16 +2699,16 @@ class InputLayerTest(test.TestCase): (1, 1)) # id 2 return embedding_values - embedding_column = fc_old.embedding_column( + embedding_column = fc.embedding_column( categorical_column, dimension=embedding_dimension, initializer=_embedding_column_initializer) - input_layer = InputLayer([embedding_column]) + feature_layer = FeatureLayer([embedding_column]) features = {'a': sparse_input} def scale_matrix(): - matrix = input_layer(features) + matrix = feature_layer(features) return 2 * matrix # Sanity check: Verify that scale_matrix returns the correct output. 
@@ -2739,185 +2723,139 @@ class InputLayerTest(test.TestCase): self.assertAllEqual([0, 1, 2], indexed_slice.indices) self.assertAllEqual([[2, 2], [2, 2], [2, 2]], gradient) - -class FunctionalInputLayerTest(test.TestCase): - def test_raises_if_empty_feature_columns(self): with self.assertRaisesRegexp(ValueError, 'feature_columns must not be empty'): - fc.input_layer(features={}, feature_columns=[]) + FeatureLayer(feature_columns=[])(features={}) def test_should_be_dense_column(self): - with self.assertRaisesRegexp(ValueError, 'must be a _DenseColumn'): - fc.input_layer( - features={'a': [[0]]}, - feature_columns=[ - fc_old.categorical_column_with_hash_bucket('wire_cast', 4) - ]) + with self.assertRaisesRegexp(ValueError, 'must be a DenseColumn'): + FeatureLayer(feature_columns=[ + fc.categorical_column_with_hash_bucket('wire_cast', 4) + ])( + features={ + 'a': [[0]] + }) def test_does_not_support_dict_columns(self): with self.assertRaisesRegexp( ValueError, 'Expected feature_columns to be iterable, found dict.'): - fc.input_layer( - features={'a': [[0]]}, - feature_columns={'a': fc_old.numeric_column('a')}) + FeatureLayer(feature_columns={'a': fc.numeric_column('a')})( + features={ + 'a': [[0]] + }) def test_bare_column(self): with ops.Graph().as_default(): features = features = {'a': [0.]} - net = fc.input_layer(features, fc_old.numeric_column('a')) + net = FeatureLayer(fc.numeric_column('a'))(features) with _initialized_session(): self.assertAllClose([[0.]], net.eval()) def test_column_generator(self): with ops.Graph().as_default(): features = features = {'a': [0.], 'b': [1.]} - columns = (fc_old.numeric_column(key) for key in features) - net = fc.input_layer(features, columns) + columns = (fc.numeric_column(key) for key in features) + net = FeatureLayer(columns)(features) with _initialized_session(): self.assertAllClose([[0., 1.]], net.eval()) def test_raises_if_duplicate_name(self): with self.assertRaisesRegexp( ValueError, 'Duplicate feature column name found 
for columns'): - fc.input_layer( - features={'a': [[0]]}, - feature_columns=[ - fc_old.numeric_column('a'), - fc_old.numeric_column('a') - ]) + FeatureLayer( + feature_columns=[fc.numeric_column('a'), + fc.numeric_column('a')])( + features={ + 'a': [[0]] + }) def test_one_column(self): - price = fc_old.numeric_column('price') + price = fc.numeric_column('price') with ops.Graph().as_default(): features = {'price': [[1.], [5.]]} - net = fc.input_layer(features, [price]) + net = FeatureLayer([price])(features) with _initialized_session(): self.assertAllClose([[1.], [5.]], net.eval()) def test_multi_dimension(self): - price = fc_old.numeric_column('price', shape=2) + price = fc.numeric_column('price', shape=2) with ops.Graph().as_default(): features = {'price': [[1., 2.], [5., 6.]]} - net = fc.input_layer(features, [price]) + net = FeatureLayer([price])(features) with _initialized_session(): self.assertAllClose([[1., 2.], [5., 6.]], net.eval()) def test_raises_if_shape_mismatch(self): - price = fc_old.numeric_column('price', shape=2) + price = fc.numeric_column('price', shape=2) with ops.Graph().as_default(): features = {'price': [[1.], [5.]]} with self.assertRaisesRegexp( Exception, r'Cannot reshape a tensor with 2 elements to shape \[2,2\]'): - fc.input_layer(features, [price]) + FeatureLayer([price])(features) def test_reshaping(self): - price = fc_old.numeric_column('price', shape=[1, 2]) + price = fc.numeric_column('price', shape=[1, 2]) with ops.Graph().as_default(): features = {'price': [[[1., 2.]], [[5., 6.]]]} - net = fc.input_layer(features, [price]) + net = FeatureLayer([price])(features) with _initialized_session(): self.assertAllClose([[1., 2.], [5., 6.]], net.eval()) def test_multi_column(self): - price1 = fc_old.numeric_column('price1', shape=2) - price2 = fc_old.numeric_column('price2') + price1 = fc.numeric_column('price1', shape=2) + price2 = fc.numeric_column('price2') with ops.Graph().as_default(): features = { 'price1': [[1., 2.], [5., 6.]], 
'price2': [[3.], [4.]] } - net = fc.input_layer(features, [price1, price2]) + net = FeatureLayer([price1, price2])(features) with _initialized_session(): self.assertAllClose([[1., 2., 3.], [5., 6., 4.]], net.eval()) - def test_fills_cols_to_vars(self): - # Provide three _DenseColumn's to input_layer: a _NumericColumn, a - # _BucketizedColumn, and an _EmbeddingColumn. Only the _EmbeddingColumn - # creates a Variable. - price1 = fc_old.numeric_column('price1') - dense_feature = fc_old.numeric_column('dense_feature') - dense_feature_bucketized = fc_old.bucketized_column( - dense_feature, boundaries=[0.]) - some_sparse_column = fc_old.categorical_column_with_hash_bucket( - 'sparse_feature', hash_bucket_size=5) - some_embedding_column = fc_old.embedding_column( - some_sparse_column, dimension=10) - with ops.Graph().as_default(): - features = { - 'price1': [[3.], [4.]], - 'dense_feature': [[-1.], [4.]], - 'sparse_feature': [['a'], ['x']], - } - cols_to_vars = {} - all_cols = [price1, dense_feature_bucketized, some_embedding_column] - fc.input_layer(features, all_cols, cols_to_vars=cols_to_vars) - self.assertItemsEqual(list(cols_to_vars.keys()), all_cols) - self.assertEqual(0, len(cols_to_vars[price1])) - self.assertEqual(0, len(cols_to_vars[dense_feature_bucketized])) - self.assertEqual(1, len(cols_to_vars[some_embedding_column])) - self.assertIsInstance(cols_to_vars[some_embedding_column][0], - variables_lib.Variable) - self.assertAllEqual(cols_to_vars[some_embedding_column][0].shape, [5, 10]) - - def test_fills_cols_to_vars_partitioned_variables(self): - price1 = fc_old.numeric_column('price1') - dense_feature = fc_old.numeric_column('dense_feature') - dense_feature_bucketized = fc_old.bucketized_column( - dense_feature, boundaries=[0.]) - some_sparse_column = fc_old.categorical_column_with_hash_bucket( - 'sparse_feature', hash_bucket_size=5) - some_embedding_column = fc_old.embedding_column( - some_sparse_column, dimension=10) + def test_cols_to_output_tensors(self): 
+ price1 = fc.numeric_column('price1', shape=2) + price2 = fc.numeric_column('price2') with ops.Graph().as_default(): - features = { - 'price1': [[3.], [4.]], - 'dense_feature': [[-1.], [4.]], - 'sparse_feature': [['a'], ['x']], - } - cols_to_vars = {} - all_cols = [price1, dense_feature_bucketized, some_embedding_column] - with variable_scope.variable_scope( - 'input_from_feature_columns', - partitioner=partitioned_variables.fixed_size_partitioner(3, axis=0)): - fc.input_layer(features, all_cols, cols_to_vars=cols_to_vars) - self.assertItemsEqual(list(cols_to_vars.keys()), all_cols) - self.assertEqual(0, len(cols_to_vars[price1])) - self.assertEqual(0, len(cols_to_vars[dense_feature_bucketized])) - self.assertEqual(3, len(cols_to_vars[some_embedding_column])) - self.assertAllEqual(cols_to_vars[some_embedding_column][0].shape, [2, 10]) - self.assertAllEqual(cols_to_vars[some_embedding_column][1].shape, [2, 10]) - self.assertAllEqual(cols_to_vars[some_embedding_column][2].shape, [1, 10]) + cols_dict = {} + features = {'price1': [[1., 2.], [5., 6.]], 'price2': [[3.], [4.]]} + feature_layer = FeatureLayer([price1, price2]) + net = feature_layer(features, cols_dict) + with _initialized_session(): + self.assertAllClose([[1., 2.], [5., 6.]], cols_dict[price1].eval()) + self.assertAllClose([[3.], [4.]], cols_dict[price2].eval()) + self.assertAllClose([[1., 2., 3.], [5., 6., 4.]], net.eval()) def test_column_order(self): - price_a = fc_old.numeric_column('price_a') - price_b = fc_old.numeric_column('price_b') + price_a = fc.numeric_column('price_a') + price_b = fc.numeric_column('price_b') with ops.Graph().as_default(): features = { 'price_a': [[1.]], 'price_b': [[3.]], } - net1 = fc.input_layer(features, [price_a, price_b]) - net2 = fc.input_layer(features, [price_b, price_a]) + net1 = FeatureLayer([price_a, price_b])(features) + net2 = FeatureLayer([price_b, price_a])(features) with _initialized_session(): self.assertAllClose([[1., 3.]], net1.eval()) 
self.assertAllClose([[1., 3.]], net2.eval()) def test_fails_for_categorical_column(self): - animal = fc_old.categorical_column_with_identity('animal', num_buckets=4) + animal = fc.categorical_column_with_identity('animal', num_buckets=4) with ops.Graph().as_default(): features = { 'animal': sparse_tensor.SparseTensor( indices=[[0, 0], [0, 1]], values=[1, 2], dense_shape=[1, 2]) } - with self.assertRaisesRegexp(Exception, 'must be a _DenseColumn'): - fc.input_layer(features, [animal]) + with self.assertRaisesRegexp(Exception, 'must be a DenseColumn'): + FeatureLayer([animal])(features) def test_static_batch_size_mismatch(self): - price1 = fc_old.numeric_column('price1') - price2 = fc_old.numeric_column('price2') + price1 = fc.numeric_column('price1') + price2 = fc.numeric_column('price2') with ops.Graph().as_default(): features = { 'price1': [[1.], [5.], [7.]], # batchsize = 3 @@ -2926,12 +2864,12 @@ class FunctionalInputLayerTest(test.TestCase): with self.assertRaisesRegexp( ValueError, 'Batch size \(first dimension\) of each feature must be same.'): # pylint: disable=anomalous-backslash-in-string - fc.input_layer(features, [price1, price2]) + FeatureLayer([price1, price2])(features) def test_subset_of_static_batch_size_mismatch(self): - price1 = fc_old.numeric_column('price1') - price2 = fc_old.numeric_column('price2') - price3 = fc_old.numeric_column('price3') + price1 = fc.numeric_column('price1') + price2 = fc.numeric_column('price2') + price3 = fc.numeric_column('price3') with ops.Graph().as_default(): features = { 'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 3 @@ -2941,31 +2879,31 @@ class FunctionalInputLayerTest(test.TestCase): with self.assertRaisesRegexp( ValueError, 'Batch size \(first dimension\) of each feature must be same.'): # pylint: disable=anomalous-backslash-in-string - fc.input_layer(features, [price1, price2, price3]) + FeatureLayer([price1, price2, price3])(features) def test_runtime_batch_size_mismatch(self): - price1 = 
fc_old.numeric_column('price1') - price2 = fc_old.numeric_column('price2') + price1 = fc.numeric_column('price1') + price2 = fc.numeric_column('price2') with ops.Graph().as_default(): features = { 'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 3 'price2': [[3.], [4.]] # batchsize = 2 } - net = fc.input_layer(features, [price1, price2]) + net = FeatureLayer([price1, price2])(features) with _initialized_session() as sess: with self.assertRaisesRegexp(errors.OpError, 'Dimensions of inputs should match'): sess.run(net, feed_dict={features['price1']: [[1.], [5.], [7.]]}) def test_runtime_batch_size_matches(self): - price1 = fc_old.numeric_column('price1') - price2 = fc_old.numeric_column('price2') + price1 = fc.numeric_column('price1') + price2 = fc.numeric_column('price2') with ops.Graph().as_default(): features = { 'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 2 'price2': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 2 } - net = fc.input_layer(features, [price1, price2]) + net = FeatureLayer([price1, price2])(features) with _initialized_session() as sess: sess.run( net, @@ -2975,9 +2913,9 @@ class FunctionalInputLayerTest(test.TestCase): }) def test_multiple_layers_with_same_embedding_column(self): - some_sparse_column = fc_old.categorical_column_with_hash_bucket( + some_sparse_column = fc.categorical_column_with_hash_bucket( 'sparse_feature', hash_bucket_size=5) - some_embedding_column = fc_old.embedding_column( + some_embedding_column = fc.embedding_column( some_sparse_column, dimension=10) with ops.Graph().as_default(): @@ -2985,28 +2923,30 @@ class FunctionalInputLayerTest(test.TestCase): 'sparse_feature': [['a'], ['x']], } all_cols = [some_embedding_column] - fc.input_layer(features, all_cols) - fc.input_layer(features, all_cols) + FeatureLayer(all_cols)(features) + FeatureLayer(all_cols)(features) # Make sure that 2 variables get created in this case. 
self.assertEqual(2, len( ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))) expected_var_names = [ - 'input_layer/sparse_feature_embedding/embedding_weights:0', - 'input_layer_1/sparse_feature_embedding/embedding_weights:0' + 'feature_layer/sparse_feature_embedding/embedding_weights:0', + 'feature_layer_1/sparse_feature_embedding/embedding_weights:0' ] self.assertItemsEqual( expected_var_names, [v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)]) def test_multiple_layers_with_same_shared_embedding_column(self): - categorical_column_a = fc_old.categorical_column_with_identity( + categorical_column_a = fc.categorical_column_with_identity( key='aaa', num_buckets=3) - categorical_column_b = fc_old.categorical_column_with_identity( + categorical_column_b = fc.categorical_column_with_identity( key='bbb', num_buckets=3) embedding_dimension = 2 - embedding_column_b, embedding_column_a = fc_old.shared_embedding_columns( + embedding_column_b, embedding_column_a = fc.shared_embedding_columns_v2( [categorical_column_b, categorical_column_a], dimension=embedding_dimension) + shared_state_manager = fc.SharedEmbeddingStateManager( + name='shared_feature_layer') with ops.Graph().as_default(): features = { @@ -3022,27 +2962,33 @@ class FunctionalInputLayerTest(test.TestCase): dense_shape=(2, 2)), } all_cols = [embedding_column_a, embedding_column_b] - fc.input_layer(features, all_cols) - fc.input_layer(features, all_cols) + FeatureLayer( + all_cols, shared_state_manager=shared_state_manager)( + features) + FeatureLayer( + all_cols, shared_state_manager=shared_state_manager)( + features) # Make sure that only 1 variable gets created in this case. 
self.assertEqual(1, len( ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))) self.assertItemsEqual( - ['input_layer/aaa_bbb_shared_embedding/embedding_weights:0'], + ['shared_feature_layer/aaa_bbb_shared_embedding:0'], [v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)]) def test_multiple_layers_with_same_shared_embedding_column_diff_graphs(self): - categorical_column_a = fc_old.categorical_column_with_identity( + categorical_column_a = fc.categorical_column_with_identity( key='aaa', num_buckets=3) - categorical_column_b = fc_old.categorical_column_with_identity( + categorical_column_b = fc.categorical_column_with_identity( key='bbb', num_buckets=3) embedding_dimension = 2 - embedding_column_b, embedding_column_a = fc_old.shared_embedding_columns( + embedding_column_b, embedding_column_a = fc.shared_embedding_columns_v2( [categorical_column_b, categorical_column_a], dimension=embedding_dimension) all_cols = [embedding_column_a, embedding_column_b] with ops.Graph().as_default(): + shared_state_manager1 = fc.SharedEmbeddingStateManager( + name='shared_feature_layer') features = { 'aaa': sparse_tensor.SparseTensor( @@ -3055,12 +3001,16 @@ class FunctionalInputLayerTest(test.TestCase): values=(1, 2, 1), dense_shape=(2, 2)), } - fc.input_layer(features, all_cols) + FeatureLayer( + all_cols, shared_state_manager=shared_state_manager1)( + features) # Make sure that only 1 variable gets created in this case. self.assertEqual(1, len( ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))) with ops.Graph().as_default(): + shared_state_manager2 = fc.SharedEmbeddingStateManager( + name='shared_feature_layer') features1 = { 'aaa': sparse_tensor.SparseTensor( @@ -3074,12 +3024,14 @@ class FunctionalInputLayerTest(test.TestCase): dense_shape=(2, 2)), } - fc.input_layer(features1, all_cols) + FeatureLayer( + all_cols, shared_state_manager=shared_state_manager2)( + features1) # Make sure that only 1 variable gets created in this case. 
self.assertEqual(1, len( ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))) self.assertItemsEqual( - ['input_layer/aaa_bbb_shared_embedding/embedding_weights:0'], + ['shared_feature_layer/aaa_bbb_shared_embedding:0'], [v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)]) def test_with_numpy_input_fn(self): @@ -3092,14 +3044,14 @@ class FunctionalInputLayerTest(test.TestCase): del shape, dtype, partition_info return embedding_values - # price has 1 dimension in input_layer - price = fc_old.numeric_column('price') - body_style = fc_old.categorical_column_with_vocabulary_list( + # price has 1 dimension in feature_layer + price = fc.numeric_column('price') + body_style = fc.categorical_column_with_vocabulary_list( 'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan']) - # one_hot_body_style has 3 dims in input_layer. - one_hot_body_style = fc_old.indicator_column(body_style) - # embedded_body_style has 5 dims in input_layer. - embedded_body_style = fc_old.embedding_column( + # one_hot_body_style has 3 dims in feature_layer. + one_hot_body_style = fc.indicator_column(body_style) + # embedded_body_style has 5 dims in feature_layer. 
+ embedded_body_style = fc.embedding_column( body_style, dimension=5, initializer=_initializer) input_fn = numpy_io.numpy_input_fn( @@ -3110,8 +3062,8 @@ class FunctionalInputLayerTest(test.TestCase): batch_size=2, shuffle=False) features = input_fn() - net = fc.input_layer(features, - [price, one_hot_body_style, embedded_body_style]) + net = FeatureLayer([price, one_hot_body_style, embedded_body_style])( + features) self.assertEqual(1 + 3 + 5, net.shape[1]) with _initialized_session() as sess: coord = coordinator.Coordinator() @@ -3137,18 +3089,18 @@ class FunctionalInputLayerTest(test.TestCase): del shape, dtype, partition_info return embedding_values - # price has 1 dimension in input_layer - price = fc_old.numeric_column('price') + # price has 1 dimension in feature_layer + price = fc.numeric_column('price') - # one_hot_body_style has 3 dims in input_layer. - body_style = fc_old.categorical_column_with_vocabulary_list( + # one_hot_body_style has 3 dims in feature_layer. + body_style = fc.categorical_column_with_vocabulary_list( 'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan']) - one_hot_body_style = fc_old.indicator_column(body_style) + one_hot_body_style = fc.indicator_column(body_style) - # embedded_body_style has 5 dims in input_layer. - country = fc_old.categorical_column_with_vocabulary_list( + # embedded_body_style has 5 dims in feature_layer. + country = fc.categorical_column_with_vocabulary_list( 'country', vocabulary_list=['US', 'JP', 'CA']) - embedded_country = fc_old.embedding_column( + embedded_country = fc.embedding_column( country, dimension=5, initializer=_initializer) # Provides 1-dim tensor and dense tensor. 
@@ -3165,8 +3117,7 @@ class FunctionalInputLayerTest(test.TestCase): self.assertEqual(1, features['body-style'].dense_shape.get_shape()[0]) self.assertEqual(1, features['country'].shape.ndims) - net = fc.input_layer(features, - [price, one_hot_body_style, embedded_country]) + net = FeatureLayer([price, one_hot_body_style, embedded_country])(features) self.assertEqual(1 + 3 + 5, net.shape[1]) with _initialized_session() as sess: @@ -3187,18 +3138,18 @@ class FunctionalInputLayerTest(test.TestCase): del shape, dtype, partition_info return embedding_values - # price has 1 dimension in input_layer - price = fc_old.numeric_column('price') + # price has 1 dimension in feature_layer + price = fc.numeric_column('price') - # one_hot_body_style has 3 dims in input_layer. - body_style = fc_old.categorical_column_with_vocabulary_list( + # one_hot_body_style has 3 dims in feature_layer. + body_style = fc.categorical_column_with_vocabulary_list( 'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan']) - one_hot_body_style = fc_old.indicator_column(body_style) + one_hot_body_style = fc.indicator_column(body_style) - # embedded_body_style has 5 dims in input_layer. - country = fc_old.categorical_column_with_vocabulary_list( + # embedded_body_style has 5 dims in feature_layer. + country = fc.categorical_column_with_vocabulary_list( 'country', vocabulary_list=['US', 'JP', 'CA']) - embedded_country = fc_old.embedding_column( + embedded_country = fc.embedding_column( country, dimension=2, initializer=_initializer) # Provides 1-dim tensor and dense tensor. 
@@ -3219,8 +3170,7 @@ class FunctionalInputLayerTest(test.TestCase): dense_shape=(2,)) country_data = np.array([['US'], ['CA']]) - net = fc.input_layer(features, - [price, one_hot_body_style, embedded_country]) + net = FeatureLayer([price, one_hot_body_style, embedded_country])(features) self.assertEqual(1 + 3 + 2, net.shape[1]) with _initialized_session() as sess: @@ -3237,8 +3187,8 @@ class FunctionalInputLayerTest(test.TestCase): })) def test_with_rank_0_feature(self): - # price has 1 dimension in input_layer - price = fc_old.numeric_column('price') + # price has 1 dimension in feature_layer + price = fc.numeric_column('price') features = { 'price': constant_op.constant(0), } @@ -3246,13 +3196,13 @@ class FunctionalInputLayerTest(test.TestCase): # Static rank 0 should fail with self.assertRaisesRegexp(ValueError, 'Feature .* cannot have rank 0'): - fc.input_layer(features, [price]) + FeatureLayer([price])(features) # Dynamic rank 0 should fail features = { 'price': array_ops.placeholder(dtypes.float32), } - net = fc.input_layer(features, [price]) + net = FeatureLayer([price])(features) self.assertEqual(1, net.shape[1]) with _initialized_session() as sess: with self.assertRaisesOpError('Feature .* cannot have rank 0'): @@ -3267,7 +3217,7 @@ class MakeParseExampleSpecTest(test.TestCase): @property def name(self): - return "_TestFeatureColumn" + return '_TestFeatureColumn' def transform_feature(self, transformation_cache, state_manager): pass @@ -3593,25 +3543,6 @@ class VocabularyFileCategoricalColumnTest(test.TestCase): dense_shape=inputs.dense_shape), id_tensor.eval()) - def DISABLED_test_get_sparse_tensors_weight_collections(self): - column = fc.categorical_column_with_vocabulary_file( - key='aaa', - vocabulary_file=self._wire_vocabulary_file_name, - vocabulary_size=self._wire_vocabulary_size) - inputs = sparse_tensor.SparseTensor( - values=['omar', 'stringer', 'marlo'], - indices=[[0, 0], [1, 0], [1, 1]], - dense_shape=[2, 2]) - column.get_sparse_tensors( - 
FeatureTransformationCache({ - 'aaa': inputs - }), - weight_collections=('my_weights',)) - - self.assertItemsEqual( - [], ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)) - self.assertItemsEqual([], ops.get_collection('my_weights')) - def test_get_sparse_tensors_dense_input(self): column = fc.categorical_column_with_vocabulary_file( key='aaa', @@ -4043,24 +3974,6 @@ class VocabularyListCategoricalColumnTest(test.TestCase): dense_shape=inputs.dense_shape), id_tensor.eval()) - def DISABLED_test_get_sparse_tensors_weight_collections(self): - column = fc.categorical_column_with_vocabulary_list( - key='aaa', - vocabulary_list=('omar', 'stringer', 'marlo')) - inputs = sparse_tensor.SparseTensor( - values=['omar', 'stringer', 'marlo'], - indices=[[0, 0], [1, 0], [1, 1]], - dense_shape=[2, 2]) - column.get_sparse_tensors( - FeatureTransformationCache({ - 'aaa': inputs - }), - weight_collections=('my_weights',)) - - self.assertItemsEqual( - [], ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)) - self.assertItemsEqual([], ops.get_collection('my_weights')) - def test_get_sparse_tensors_dense_input(self): column = fc.categorical_column_with_vocabulary_list( key='aaa', @@ -4356,22 +4269,6 @@ class IdentityCategoricalColumnTest(test.TestCase): dense_shape=inputs.dense_shape), id_tensor.eval()) - def DISABLED_test_get_sparse_tensors_weight_collections(self): - column = fc.categorical_column_with_identity(key='aaa', num_buckets=3) - inputs = sparse_tensor.SparseTensorValue( - indices=((0, 0), (1, 0), (1, 1)), - values=(0, 1, 0), - dense_shape=(2, 2)) - column.get_sparse_tensors( - FeatureTransformationCache({ - 'aaa': inputs - }), - weight_collections=('my_weights',)) - - self.assertItemsEqual( - [], ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)) - self.assertItemsEqual([], ops.get_collection('my_weights')) - def test_get_sparse_tensors_dense_input(self): column = fc.categorical_column_with_identity(key='aaa', num_buckets=3) id_weight_pair = column.get_sparse_tensors( @@ 
-4765,16 +4662,16 @@ class IndicatorColumnTest(test.TestCase): weight_var.assign([[1.], [2.], [3.], [4.]]).eval() self.assertAllClose([[2. + 3.]], predictions.eval()) - def test_input_layer(self): - animal = fc_old.indicator_column( - fc_old.categorical_column_with_identity('animal', num_buckets=4)) + def test_feature_layer(self): + animal = fc.indicator_column( + fc.categorical_column_with_identity('animal', num_buckets=4)) with ops.Graph().as_default(): features = { 'animal': sparse_tensor.SparseTensor( indices=[[0, 0], [0, 1]], values=[1, 2], dense_shape=[1, 2]) } - net = fc.input_layer(features, [animal]) + net = FeatureLayer([animal])(features) with _initialized_session(): self.assertAllClose([[0., 1., 1., 0.]], net.eval()) @@ -4786,12 +4683,13 @@ class _TestStateManager(StateManager): self._all_variables = {} self._trainable = trainable - def get_variable(self, - feature_column, - name, - shape, - dtype=None, - initializer=None): + def create_variable(self, + feature_column, + name, + shape, + dtype=None, + trainable=True, + initializer=None): if feature_column not in self._all_variables: self._all_variables[feature_column] = {} var_dict = self._all_variables[feature_column] @@ -4801,11 +4699,19 @@ class _TestStateManager(StateManager): var = variable_scope.get_variable( name=name, shape=shape, - initializer=initializer, - trainable=self._trainable) + dtype=dtype, + trainable=self._trainable and trainable, + initializer=initializer) var_dict[name] = var return var + def get_variable(self, feature_column, name): + if feature_column not in self._all_variables: + raise ValueError('Do not recognize FeatureColumn.') + if name in self._all_variables[feature_column]: + return self._all_variables[feature_column][name] + raise ValueError('Could not find variable.') + class EmbeddingColumnTest(test.TestCase): @@ -4967,6 +4873,7 @@ class EmbeddingColumnTest(test.TestCase): categorical_column, dimension=embedding_dimension, initializer=_initializer) state_manager = 
_TestStateManager() + embedding_column.create_state(state_manager) # Provide sparse input and get dense result. embedding_lookup = embedding_column.get_dense_tensor( @@ -5028,6 +4935,7 @@ class EmbeddingColumnTest(test.TestCase): categorical_column, dimension=embedding_dimension, initializer=_initializer) state_manager = _TestStateManager() + embedding_column.create_state(state_manager) # Provide sparse input and get dense result. embedding_lookup = embedding_column.get_dense_tensor( @@ -5043,36 +4951,6 @@ class EmbeddingColumnTest(test.TestCase): self.assertAllEqual(embedding_values, global_vars[0].eval()) self.assertAllEqual(expected_lookups, embedding_lookup.eval()) - def DISABLED_test_get_dense_tensor_weight_collections(self): - sparse_input = sparse_tensor.SparseTensorValue( - # example 0, ids [2] - # example 1, ids [0, 1] - # example 2, ids [] - # example 3, ids [1] - indices=((0, 0), (1, 0), (1, 4), (3, 0)), - values=(2, 0, 1, 1), - dense_shape=(4, 5)) - - # Build columns. - categorical_column = fc.categorical_column_with_identity( - key='aaa', num_buckets=3) - embedding_column = fc.embedding_column(categorical_column, dimension=2) - - # Provide sparse input and get dense result. - embedding_column.get_dense_tensor( - FeatureTransformationCache({ - 'aaa': sparse_input - }), - weight_collections=('my_vars',)) - - # Assert expected embedding variable and lookups. - global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) - self.assertItemsEqual(('embedding_weights:0',), - tuple([v.name for v in global_vars])) - my_vars = ops.get_collection('my_vars') - self.assertItemsEqual( - ('embedding_weights:0',), tuple([v.name for v in my_vars])) - def test_get_dense_tensor_placeholder_inputs(self): # Inputs. 
vocabulary_size = 3 @@ -5117,6 +4995,7 @@ class EmbeddingColumnTest(test.TestCase): categorical_column, dimension=embedding_dimension, initializer=_initializer) state_manager = _TestStateManager() + embedding_column.create_state(state_manager) # Provide sparse input and get dense result. input_indices = array_ops.placeholder(dtype=dtypes.int64) @@ -5187,6 +5066,7 @@ class EmbeddingColumnTest(test.TestCase): ckpt_to_load_from=ckpt_path, tensor_name_in_ckpt=ckpt_tensor) state_manager = _TestStateManager() + embedding_column.create_state(state_manager) # Provide sparse input and get dense result. embedding_lookup = embedding_column.get_dense_tensor( @@ -5354,7 +5234,7 @@ class EmbeddingColumnTest(test.TestCase): # = [4*7 + 6*11, 4*2 + 6*3.5, 4*0 + 6*0, 4*3 + 6*5] = [94, 29, 0, 42] self.assertAllClose(((94.,), (29.,), (0.,), (42.,)), predictions.eval()) - def test_input_layer(self): + def test_feature_layer(self): # Inputs. vocabulary_size = 3 sparse_input = sparse_tensor.SparseTensorValue( @@ -5392,30 +5272,29 @@ class EmbeddingColumnTest(test.TestCase): ) # Build columns. - categorical_column = fc_old.categorical_column_with_identity( + categorical_column = fc.categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - embedding_column = fc_old.embedding_column( + embedding_column = fc.embedding_column( categorical_column, dimension=embedding_dimension, initializer=_initializer) # Provide sparse input and get dense result. - input_layer = fc.input_layer({'aaa': sparse_input}, (embedding_column,)) + l = FeatureLayer((embedding_column,)) + feature_layer = l({'aaa': sparse_input}) # Assert expected embedding variable and lookups. 
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) - self.assertItemsEqual( - ('input_layer/aaa_embedding/embedding_weights:0',), - tuple([v.name for v in global_vars])) + self.assertItemsEqual(('feature_layer/aaa_embedding/embedding_weights:0',), + tuple([v.name for v in global_vars])) trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) - self.assertItemsEqual( - ('input_layer/aaa_embedding/embedding_weights:0',), - tuple([v.name for v in trainable_vars])) + self.assertItemsEqual(('feature_layer/aaa_embedding/embedding_weights:0',), + tuple([v.name for v in trainable_vars])) with _initialized_session(): self.assertAllEqual(embedding_values, trainable_vars[0].eval()) - self.assertAllEqual(expected_lookups, input_layer.eval()) + self.assertAllEqual(expected_lookups, feature_layer.eval()) - def test_input_layer_not_trainable(self): + def test_feature_layer_not_trainable(self): # Inputs. vocabulary_size = 3 sparse_input = sparse_tensor.SparseTensorValue( @@ -5453,65 +5332,26 @@ class EmbeddingColumnTest(test.TestCase): ) # Build columns. - categorical_column = fc_old.categorical_column_with_identity( + categorical_column = fc.categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - embedding_column = fc_old.embedding_column( + embedding_column = fc.embedding_column( categorical_column, dimension=embedding_dimension, initializer=_initializer, trainable=False) # Provide sparse input and get dense result. - input_layer = fc.input_layer({'aaa': sparse_input}, (embedding_column,)) + feature_layer = FeatureLayer((embedding_column,))({'aaa': sparse_input}) # Assert expected embedding variable and lookups. 
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) - self.assertItemsEqual( - ('input_layer/aaa_embedding/embedding_weights:0',), - tuple([v.name for v in global_vars])) + self.assertItemsEqual(('feature_layer/aaa_embedding/embedding_weights:0',), + tuple([v.name for v in global_vars])) self.assertItemsEqual( [], ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)) with _initialized_session(): self.assertAllEqual(embedding_values, global_vars[0].eval()) - self.assertAllEqual(expected_lookups, input_layer.eval()) - - -class _TestSharedEmbeddingStateManager(StateManager): - """Manages the state for shared embedding columns. - - This can handle multiple groups of shared embedding columns. - """ - - def __init__(self, trainable=True): - # Dict of shared_embedding_collection_name to a dict of variables. - self._all_variables = {} - self._trainable = trainable - - def get_variable(self, - feature_column, - name, - shape, - dtype=None, - initializer=None): - if not isinstance(feature_column, fc.SharedEmbeddingColumn): - raise ValueError( - 'SharedEmbeddingStateManager can only handle SharedEmbeddingColumns. 
' - 'Given type: {} '.format(type(feature_column))) - - collection_name = feature_column.shared_collection_name - if collection_name not in self._all_variables: - self._all_variables[collection_name] = {} - var_dict = self._all_variables[collection_name] - if name in var_dict: - return var_dict[name] - else: - var = variable_scope.get_variable( - name=name, - shape=shape, - initializer=initializer, - trainable=self._trainable) - var_dict[name] = var - return var + self.assertAllEqual(expected_lookups, feature_layer.eval()) class SharedEmbeddingColumnTest(test.TestCase): @@ -5522,7 +5362,7 @@ class SharedEmbeddingColumnTest(test.TestCase): categorical_column_b = fc.categorical_column_with_identity( key='bbb', num_buckets=3) embedding_dimension = 2 - embedding_column_b, embedding_column_a = fc.shared_embedding_columns( + embedding_column_b, embedding_column_a = fc.shared_embedding_columns_v2( [categorical_column_b, categorical_column_a], dimension=embedding_dimension) self.assertIs(categorical_column_a, embedding_column_a.categorical_column) @@ -5560,7 +5400,7 @@ class SharedEmbeddingColumnTest(test.TestCase): categorical_column_b = fc.categorical_column_with_identity( key='bbb', num_buckets=3) embedding_dimension = 2 - embedding_column_a, embedding_column_b = fc.shared_embedding_columns( + embedding_column_a, embedding_column_b = fc.shared_embedding_columns_v2( [categorical_column_a, categorical_column_b], dimension=embedding_dimension, combiner='my_combiner', @@ -5605,7 +5445,7 @@ class SharedEmbeddingColumnTest(test.TestCase): categorical_column_b = fc.categorical_column_with_identity( key='bbb', num_buckets=3) embedding_dimension = 2 - original_a, _ = fc.shared_embedding_columns( + original_a, _ = fc.shared_embedding_columns_v2( [categorical_column_a, categorical_column_b], dimension=embedding_dimension, combiner='my_combiner', @@ -5613,7 +5453,8 @@ class SharedEmbeddingColumnTest(test.TestCase): 
shared_embedding_collection_name='shared_embedding_collection_name', ckpt_to_load_from='my_ckpt', tensor_name_in_ckpt='my_ckpt_tensor', - max_norm=42., trainable=False) + max_norm=42., + trainable=False) for embedding_column_a in (original_a, copy.deepcopy(original_a)): self.assertEqual('aaa', embedding_column_a.categorical_column.name) self.assertEqual(3, embedding_column_a.categorical_column.num_buckets) @@ -5642,8 +5483,9 @@ class SharedEmbeddingColumnTest(test.TestCase): categorical_column_b = fc.categorical_column_with_identity( key='bbb', num_buckets=3) with self.assertRaisesRegexp(ValueError, 'initializer must be callable'): - fc.shared_embedding_columns( - [categorical_column_a, categorical_column_b], dimension=2, + fc.shared_embedding_columns_v2( + [categorical_column_a, categorical_column_b], + dimension=2, initializer='not_fn') def test_incompatible_column_type(self): @@ -5656,7 +5498,7 @@ class SharedEmbeddingColumnTest(test.TestCase): with self.assertRaisesRegexp( ValueError, 'all categorical_columns must have the same type.*' 'IdentityCategoricalColumn.*HashedCategoricalColumn'): - fc.shared_embedding_columns( + fc.shared_embedding_columns_v2( [categorical_column_a, categorical_column_b, categorical_column_c], dimension=2) @@ -5669,11 +5511,11 @@ class SharedEmbeddingColumnTest(test.TestCase): key='bbb', num_buckets=3) weighted_categorical_column_b = fc.weighted_categorical_column( categorical_column_b, weight_feature_key='bbb_weights') - fc.shared_embedding_columns( + fc.shared_embedding_columns_v2( [weighted_categorical_column_a, categorical_column_b], dimension=2) - fc.shared_embedding_columns( + fc.shared_embedding_columns_v2( [categorical_column_a, weighted_categorical_column_b], dimension=2) - fc.shared_embedding_columns( + fc.shared_embedding_columns_v2( [weighted_categorical_column_a, weighted_categorical_column_b], dimension=2) @@ -5682,8 +5524,7 @@ class SharedEmbeddingColumnTest(test.TestCase): key='aaa', vocabulary_list=('omar', 
'stringer', 'marlo')) b = fc.categorical_column_with_vocabulary_list( key='bbb', vocabulary_list=('omar', 'stringer', 'marlo')) - a_embedded, b_embedded = fc.shared_embedding_columns( - [a, b], dimension=2) + a_embedded, b_embedded = fc.shared_embedding_columns_v2([a, b], dimension=2) data = example_pb2.Example(features=feature_pb2.Features( feature={ 'aaa': @@ -5717,8 +5558,7 @@ class SharedEmbeddingColumnTest(test.TestCase): def test_transform_feature(self): a = fc.categorical_column_with_identity(key='aaa', num_buckets=3) b = fc.categorical_column_with_identity(key='bbb', num_buckets=3) - a_embedded, b_embedded = fc.shared_embedding_columns( - [a, b], dimension=2) + a_embedded, b_embedded = fc.shared_embedding_columns_v2([a, b], dimension=2) features = { 'aaa': sparse_tensor.SparseTensor( indices=((0, 0), (1, 0), (1, 1)), @@ -5788,10 +5628,13 @@ class SharedEmbeddingColumnTest(test.TestCase): key='aaa', num_buckets=vocabulary_size) categorical_column_b = fc.categorical_column_with_identity( key='bbb', num_buckets=vocabulary_size) - embedding_column_a, embedding_column_b = fc.shared_embedding_columns( + embedding_column_a, embedding_column_b = fc.shared_embedding_columns_v2( [categorical_column_a, categorical_column_b], - dimension=embedding_dimension, initializer=_initializer) - state_manager = _TestSharedEmbeddingStateManager() + dimension=embedding_dimension, + initializer=_initializer) + state_manager = fc.SharedEmbeddingStateManager(name='shared_feature_layer') + embedding_column_a.create_state(state_manager) + embedding_column_b.create_state(state_manager) # Provide sparse input and get dense result. embedding_lookup_a = embedding_column_a.get_dense_tensor( @@ -5801,7 +5644,7 @@ class SharedEmbeddingColumnTest(test.TestCase): # Assert expected embedding variable and lookups. 
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) - self.assertItemsEqual(('embedding_weights:0',), + self.assertItemsEqual(('shared_feature_layer/aaa_bbb_shared_embedding:0',), tuple([v.name for v in global_vars])) embedding_var = global_vars[0] with _initialized_session(): @@ -5809,58 +5652,6 @@ class SharedEmbeddingColumnTest(test.TestCase): self.assertAllEqual(expected_lookups_a, embedding_lookup_a.eval()) self.assertAllEqual(expected_lookups_b, embedding_lookup_b.eval()) - def DISABLED_test_get_dense_tensor_weight_collections(self): - # Inputs. - vocabulary_size = 3 - # -1 values are ignored. - input_a = np.array([ - [2, -1, -1], # example 0, ids [2] - [0, 1, -1] - ]) # example 1, ids [0, 1] - input_b = np.array([ - [0, -1, -1], # example 0, ids [0] - [-1, -1, -1] - ]) # example 1, ids [] - input_features = {'aaa': input_a, 'bbb': input_b} - - # Embedding variable. - embedding_dimension = 2 - embedding_values = ( - (1., 2.), # id 0 - (3., 5.), # id 1 - (7., 11.) # id 2 - ) - - def _initializer(shape, dtype, partition_info): - self.assertAllEqual((vocabulary_size, embedding_dimension), shape) - self.assertEqual(dtypes.float32, dtype) - self.assertIsNone(partition_info) - return embedding_values - - # Build columns. - categorical_column_a = fc.categorical_column_with_identity( - key='aaa', num_buckets=vocabulary_size) - categorical_column_b = fc.categorical_column_with_identity( - key='bbb', num_buckets=vocabulary_size) - embedding_column_a, embedding_column_b = fc.shared_embedding_columns( - [categorical_column_a, categorical_column_b], - dimension=embedding_dimension, - initializer=_initializer) - - fc.input_layer( - input_features, [embedding_column_a, embedding_column_b], - weight_collections=('my_vars',)) - - # Assert expected embedding variable and lookups. 
- global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) - self.assertItemsEqual( - ('input_layer/aaa_bbb_shared_embedding/embedding_weights:0',), - tuple(v.name for v in global_vars)) - my_vars = ops.get_collection('my_vars') - self.assertItemsEqual( - ('input_layer/aaa_bbb_shared_embedding/embedding_weights:0',), - tuple(v.name for v in my_vars)) - def test_get_dense_tensor_placeholder_inputs(self): # Inputs. vocabulary_size = 3 @@ -5903,10 +5694,13 @@ class SharedEmbeddingColumnTest(test.TestCase): key='aaa', num_buckets=vocabulary_size) categorical_column_b = fc.categorical_column_with_identity( key='bbb', num_buckets=vocabulary_size) - embedding_column_a, embedding_column_b = fc.shared_embedding_columns( + embedding_column_a, embedding_column_b = fc.shared_embedding_columns_v2( [categorical_column_a, categorical_column_b], - dimension=embedding_dimension, initializer=_initializer) - state_manager = _TestSharedEmbeddingStateManager() + dimension=embedding_dimension, + initializer=_initializer) + state_manager = fc.SharedEmbeddingStateManager() + embedding_column_a.create_state(state_manager) + embedding_column_b.create_state(state_manager) # Provide sparse input and get dense result. embedding_lookup_a = embedding_column_a.get_dense_tensor( @@ -6096,7 +5890,7 @@ class SharedEmbeddingColumnTest(test.TestCase): # = [3*1 + 5*2, 3*0 +5*0] = [13, 0] self.assertAllClose([[94. + 13.], [29.]], predictions.eval()) - def _test_input_layer(self, trainable=True): + def _test_feature_layer(self, trainable=True): # Inputs. 
vocabulary_size = 3 sparse_input_a = sparse_tensor.SparseTensorValue( @@ -6111,6 +5905,18 @@ class SharedEmbeddingColumnTest(test.TestCase): indices=((0, 0),), values=(0,), dense_shape=(2, 5)) + sparse_input_c = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + indices=((0, 1), (1, 1), (1, 3)), + values=(2, 0, 1), + dense_shape=(2, 5)) + sparse_input_d = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [] + indices=((0, 1),), + values=(2,), + dense_shape=(2, 5)) # Embedding variable. embedding_dimension = 2 @@ -6130,51 +5936,127 @@ class SharedEmbeddingColumnTest(test.TestCase): # example 0: # A ids [2], embedding = [7, 11] # B ids [0], embedding = [1, 2] - (7., 11., 1., 2.), + # C ids [2], embedding = [7, 11] + # D ids [2], embedding = [7, 11] + (7., 11., 1., 2., 7., 11., 7., 11.), # example 1: # A ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5] # B ids [], embedding = [0, 0] - (2., 3.5, 0., 0.), + # C ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5] + # D ids [], embedding = [0, 0] + (2., 3.5, 0., 0., 2., 3.5, 0., 0.), ) # Build columns. 
- categorical_column_a = fc_old.categorical_column_with_identity( + categorical_column_a = fc.categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - categorical_column_b = fc_old.categorical_column_with_identity( + categorical_column_b = fc.categorical_column_with_identity( key='bbb', num_buckets=vocabulary_size) - embedding_column_a, embedding_column_b = fc_old.shared_embedding_columns( + categorical_column_c = fc.categorical_column_with_identity( + key='ccc', num_buckets=vocabulary_size) + categorical_column_d = fc.categorical_column_with_identity( + key='ddd', num_buckets=vocabulary_size) + + embedding_column_a, embedding_column_b = fc.shared_embedding_columns_v2( [categorical_column_a, categorical_column_b], dimension=embedding_dimension, initializer=_initializer, trainable=trainable) + embedding_column_c, embedding_column_d = fc.shared_embedding_columns_v2( + [categorical_column_c, categorical_column_d], + dimension=embedding_dimension, + initializer=_initializer, + trainable=trainable) + shared_state_manager = fc.SharedEmbeddingStateManager( + name='shared_feature_layer') + + features = { + 'aaa': sparse_input_a, + 'bbb': sparse_input_b, + 'ccc': sparse_input_c, + 'ddd': sparse_input_d + } # Provide sparse input and get dense result. - input_layer = fc.input_layer( - features={'aaa': sparse_input_a, 'bbb': sparse_input_b}, - feature_columns=(embedding_column_b, embedding_column_a)) + feature_layer = FeatureLayer( + feature_columns=(embedding_column_b, embedding_column_a, + embedding_column_c, embedding_column_d), + shared_state_manager=shared_state_manager)( + features) # Assert expected embedding variable and lookups. 
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) - self.assertItemsEqual( - ['input_layer/aaa_bbb_shared_embedding/embedding_weights:0'], - tuple([v.name for v in global_vars])) + self.assertItemsEqual([ + 'shared_feature_layer/aaa_bbb_shared_embedding:0', + 'shared_feature_layer/ccc_ddd_shared_embedding:0' + ], tuple([v.name for v in global_vars])) trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) if trainable: - self.assertItemsEqual( - ['input_layer/aaa_bbb_shared_embedding/embedding_weights:0'], - tuple([v.name for v in trainable_vars])) + self.assertItemsEqual([ + 'shared_feature_layer/aaa_bbb_shared_embedding:0', + 'shared_feature_layer/ccc_ddd_shared_embedding:0' + ], tuple([v.name for v in trainable_vars])) else: self.assertItemsEqual([], tuple([v.name for v in trainable_vars])) shared_embedding_vars = global_vars with _initialized_session(): self.assertAllEqual(embedding_values, shared_embedding_vars[0].eval()) - self.assertAllEqual(expected_lookups, input_layer.eval()) + self.assertAllEqual(expected_lookups, feature_layer.eval()) + + def test_feature_layer(self): + self._test_feature_layer() + + def test_feature_layer_no_trainable(self): + self._test_feature_layer(trainable=False) + - def test_input_layer(self): - self._test_input_layer() +class SharedEmbeddingStateManagerTest(test.TestCase): - def test_input_layer_no_trainable(self): - self._test_input_layer(trainable=False) + def test_basic(self): + categorical_column_a = fc.categorical_column_with_identity( + key='aaa', num_buckets=3) + categorical_column_b = fc.categorical_column_with_identity( + key='bbb', num_buckets=3) + fc.shared_embedding_columns_v2( + [categorical_column_a, categorical_column_b], dimension=2) + shared_state_manager = fc.SharedEmbeddingStateManager( + name='shared_feature_layer') + var_a = shared_state_manager.create_variable('aaa_bbb_shared_embedding', + [5, 10]) + var_b = shared_state_manager.create_variable('aaa_bbb_shared_embedding', + [5, 
10]) + self.assertEqual(var_a, var_b) + self.assertEqual('shared_feature_layer/aaa_bbb_shared_embedding:0', + var_a.name) + self.assertIsInstance(var_a, variables_lib.Variable) + + def test_multiple_sets(self): + categorical_column_a = fc.categorical_column_with_identity( + key='aaa', num_buckets=3) + categorical_column_b = fc.categorical_column_with_identity( + key='bbb', num_buckets=3) + categorical_column_c = fc.categorical_column_with_identity( + key='ccc', num_buckets=3) + categorical_column_d = fc.categorical_column_with_identity( + key='ddd', num_buckets=3) + + fc.shared_embedding_columns_v2( + [categorical_column_a, categorical_column_b], dimension=2) + fc.shared_embedding_columns_v2( + [categorical_column_c, categorical_column_d], dimension=2) + shared_state_manager = fc.SharedEmbeddingStateManager( + name='shared_feature_layer') + var_a = shared_state_manager.create_variable('aaa_bbb_shared_embedding', + [5, 10]) + var_c = shared_state_manager.create_variable('ccc_ddd_shared_embedding', + [5, 10]) + self.assertIsInstance(var_a, variables_lib.Variable) + self.assertIsInstance(var_c, variables_lib.Variable) + self.assertNotEquals(var_a, var_c) + self.assertEqual('shared_feature_layer/aaa_bbb_shared_embedding:0', + var_a.name) + self.assertEqual('shared_feature_layer/ccc_ddd_shared_embedding:0', + var_c.name) class WeightedCategoricalColumnTest(test.TestCase): diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py index f47c0d8a5e..a8aef3a009 100644 --- a/tensorflow/python/framework/function.py +++ b/tensorflow/python/framework/function.py @@ -23,7 +23,6 @@ from __future__ import print_function import collections import hashlib -import sys from tensorflow.core.framework import attr_value_pb2 from tensorflow.core.framework import function_pb2 @@ -34,7 +33,6 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import graph_to_function_def from tensorflow.python.framework import ops from 
tensorflow.python.ops import array_ops -from tensorflow.python.ops import cond_v2_impl from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope as vs from tensorflow.python.util import compat @@ -42,9 +40,6 @@ from tensorflow.python.util import function_utils from tensorflow.python.util import tf_contextlib from tensorflow.python.util import tf_inspect -# This is to avoid a circular dependency with cond_v2_impl. -cond_v2_impl._function = sys.modules[__name__] # pylint: disable=protected-access - class Defun(object): """Decorator used to define TensorFlow functions. @@ -1029,20 +1024,10 @@ def _from_definition(fdef, grad_func=None): result = _DefinedFunction(func, argnames, input_types, func_name, grad_func, python_grad_func, out_names) # pylint: disable=protected-access - if ops._USE_C_API: - serialized = fdef.SerializeToString() - c_func = c_api.TF_FunctionImportFunctionDef(serialized) - result._c_func = c_api_util.ScopedTFFunction(c_func) - result._extra_inputs = [] - else: - result._definition = fdef - # Captured inputs are added as regular inputs to a function when it's - # serialized, i.e. 
any extra inputs from the original function are now - # included in `result`._args - result._extra_inputs = [] - result._hash_str = result._create_hash_str( - result._definition.signature.input_arg, - result._definition.signature.output_arg, result._definition.node_def) + serialized = fdef.SerializeToString() + c_func = c_api.TF_FunctionImportFunctionDef(serialized) + result._c_func = c_api_util.ScopedTFFunction(c_func) + result._extra_inputs = [] # pylint: enable=protected-access return result diff --git a/tensorflow/python/framework/function_def_to_graph.py b/tensorflow/python/framework/function_def_to_graph.py index 1b09506662..a04fa369ae 100644 --- a/tensorflow/python/framework/function_def_to_graph.py +++ b/tensorflow/python/framework/function_def_to_graph.py @@ -23,7 +23,7 @@ import sys from tensorflow.core.framework import graph_pb2 from tensorflow.core.framework import types_pb2 from tensorflow.core.framework import versions_pb2 -from tensorflow.python.framework import function +from tensorflow.python.eager import function from tensorflow.python.framework import importer from tensorflow.python.framework import ops from tensorflow.python.framework import versions @@ -34,13 +34,13 @@ cond_v2_impl._function_def_to_graph = sys.modules[__name__] # pylint: disable=p def function_def_to_graph(fdef, input_shapes=None): - """Converts a FunctionDef to a function._FuncGraph (sub-class Graph). + """Converts a FunctionDef to a function.FuncGraph (sub-class Graph). - The returned _FuncGraph's `name`, `inputs` and `outputs` fields will be set. + The returned FuncGraph's `name`, `inputs` and `outputs` fields will be set. The input tensors are represented as placeholders. - Note: `_FuncGraph.inputs` and `_FuncGraph._captured` are not set and may be - set by the caller. + Note: `FuncGraph.inputs` and `FuncGraph.captures` are not set and may be set + by the caller. Args: fdef: FunctionDef. 
@@ -50,9 +50,9 @@ def function_def_to_graph(fdef, input_shapes=None): placeholder will have unknown shape. Returns: - A _FuncGraph. + A FuncGraph. """ - func_graph = function._FuncGraph(fdef.signature.name, capture_by_value=False) # pylint: disable=protected-access + func_graph = function.FuncGraph(fdef.signature.name) graph_def, nested_to_flat_tensor_name = function_def_to_graph_def( fdef, input_shapes) @@ -60,7 +60,7 @@ def function_def_to_graph(fdef, input_shapes=None): # Add all function nodes to the graph. importer.import_graph_def(graph_def, name="") - # Initialize fields specific to _FuncGraph. + # Initialize fields specific to FuncGraph. # inputs input_tensor_names = [ @@ -144,6 +144,8 @@ def function_def_to_graph_def(fdef, input_shapes=None): for arg_def in fdef.signature.input_arg: nested_to_flat_tensor_name[arg_def.name] = "{}:0".format(arg_def.name) + control_name = "^" + arg_def.name + nested_to_flat_tensor_name[control_name] = control_name for node_def in fdef.node_def: op_def = ops.get_default_graph()._get_op_def(node_def.op) # pylint: disable=protected-access @@ -172,6 +174,8 @@ def function_def_to_graph_def(fdef, input_shapes=None): flat_name = "{}:{}".format(node_def.name, flattened_index) nested_to_flat_tensor_name[nested_name] = flat_name flattened_index += 1 + control_name = "^" + node_def.name + nested_to_flat_tensor_name[control_name] = control_name # Update inputs of all nodes in graph. 
for node_def in graph_def.node: diff --git a/tensorflow/python/framework/function_def_to_graph_test.py b/tensorflow/python/framework/function_def_to_graph_test.py index 21d2c7d990..938814f1d0 100644 --- a/tensorflow/python/framework/function_def_to_graph_test.py +++ b/tensorflow/python/framework/function_def_to_graph_test.py @@ -18,9 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.eager import function from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes -from tensorflow.python.framework import function from tensorflow.python.framework import function_def_to_graph from tensorflow.python.framework import graph_to_function_def from tensorflow.python.framework import ops @@ -154,14 +154,20 @@ class FunctionDefToGraphDefTest(test.TestCase): self.assertDictEqual( tensor_name_map, { "x": "x:0", + "^x": "^x", "y": "y:0", + "^y": "^y", "z": "z:0", + "^z": "^z", "foo_1:d:0": "foo_1:0", "foo_1:e:0": "foo_1:1", + "^foo_1": "^foo_1", "list_output:a:0": "list_output:0", "list_output:a:1": "list_output:1", + "^list_output": "^list_output", "foo_2:d:0": "foo_2:0", "foo_2:e:0": "foo_2:1", + "^foo_2": "^foo_2", }) def testShapes(self): @@ -184,23 +190,25 @@ class FunctionDefToGraphDefTest(test.TestCase): x = constant_op.constant(5.0) y = constant_op.constant(10.0) - @function.Defun() + @function.defun def fn(): - @function.Defun() + @function.defun def inner_fn(): return x + y return inner_fn() - # Instantiate the function in this graph so that - # `function_def_to_graph` can find it. - fn() - def fn2(): return 2 * fn() - fdef = function._DefinedFunction(fn2, [], []).definition + fn2_defun = function.make_defun_op(fn2) + + # Call `fn2` to make sure `fn` is correctly instantiated so + # `function_def_to_graph` can find it. 
+ fn2_defun() + + fdef = fn2_defun._inference_function.definition func_graph = function_def_to_graph.function_def_to_graph(fdef) with func_graph.as_default(): x_ph, y_ph = func_graph.inputs @@ -211,6 +219,25 @@ class FunctionDefToGraphDefTest(test.TestCase): y_ph: 10.0 }), 30.0) + def testControlDependencies(self): + + def fn(inp): + x = constant_op.constant(2.0, name="x") + # TODO(b/79881896): Test external control dependency once that's + # supported. + with ops.control_dependencies([x, inp]): + constant_op.constant(3.0, name="y") + return 4.0 + + inp = constant_op.constant(1.0) + fdef = function.make_defun_op(fn, inp)._inference_function.definition + func_graph = function_def_to_graph.function_def_to_graph(fdef) + + op = func_graph.get_operation_by_name("y") + self.assertEqual(len(op.control_inputs), 2) + self.assertEqual(op.control_inputs[0].name, "x") + self.assertEqual(op.control_inputs[1].name, "placeholder") + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 5ebe43ff93..8c85a422e7 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -20,7 +20,6 @@ from __future__ import print_function import collections import copy -import os import re import sys import threading @@ -67,7 +66,7 @@ from tensorflow.python.util.tf_export import tf_export # Temporary global switches determining if we should enable the work-in-progress # calls to the C API. These will be removed once all functionality is supported. _USE_C_API = True -_USE_C_SHAPES = os.getenv("TF_C_API_GRAPH_CONSTRUCTION_SHAPES", "1") != "0" +_USE_C_SHAPES = True def tensor_id(tensor): @@ -2859,19 +2858,11 @@ class Graph(object): # TODO(skyewm): fold as much of the above as possible into the C # implementation - if self._use_c_api_hack(): - self._scoped_c_graph = c_api_util.ScopedTFGraph() - # The C API requires all ops to have shape functions. 
Disable this - # requirement (many custom ops do not have shape functions, and we don't - # want to break these existing cases). - c_api.SetRequireShapeInferenceFns(self._c_graph, False) - else: - self._scoped_c_graph = None - - # TODO(apassos) remove once the C API is used by default. - def _use_c_api_hack(self): - """Temporary hack; can be overridden to force C API usage.""" - return _USE_C_API + self._scoped_c_graph = c_api_util.ScopedTFGraph() + # The C API requires all ops to have shape functions. Disable this + # requirement (many custom ops do not have shape functions, and we don't + # want to break these existing cases). + c_api.SetRequireShapeInferenceFns(self._c_graph, False) # Note: this method is private because the API of tf.Graph() is public and # frozen, and this functionality is still not ready for public visibility. @@ -3121,7 +3112,7 @@ class Graph(object): Returns: bool indicating whether or not 'name' is registered in function library. """ - return name in self._functions + return compat.as_str(name) in self._functions def _get_function(self, name): """Returns the function definition for 'name'. @@ -3131,7 +3122,7 @@ class Graph(object): Returns: The function def proto. """ - return self._functions.get(name, None) + return self._functions.get(compat.as_str(name), None) def _add_function(self, function): """Adds a function to the graph. @@ -3167,7 +3158,7 @@ class Graph(object): c_api.TF_GraphCopyFunction(self._c_graph, function._c_func.func, gradient) # pylint: enable=protected-access - self._functions[name] = function + self._functions[compat.as_str(name)] = function # Need a new-enough consumer to support the functions we add to the graph. 
if self._graph_def_versions.min_consumer < 12: diff --git a/tensorflow/python/framework/smart_cond.py b/tensorflow/python/framework/smart_cond.py index 48a834392b..7ee2b5b347 100644 --- a/tensorflow/python/framework/smart_cond.py +++ b/tensorflow/python/framework/smart_cond.py @@ -77,11 +77,9 @@ def smart_constant_value(pred): pred_value = pred elif isinstance(pred, ops.Tensor): pred_value = tensor_util.constant_value(pred) - # TODO(skyewm): consider folding this into tensor_util.constant_value when - # _USE_C_API is removed (there may be performance and correctness bugs, so I - # wanted to limit the change hidden behind _USE_C_API). + # TODO(skyewm): consider folding this into tensor_util.constant_value. # pylint: disable=protected-access - if pred_value is None and ops._USE_C_API: + if pred_value is None: pred_value = c_api.TF_TryEvaluateConstant_wrapper(pred.graph._c_graph, pred._as_tf_output()) # pylint: enable=protected-access diff --git a/tensorflow/python/framework/subscribe.py b/tensorflow/python/framework/subscribe.py index cee7398974..00759eb611 100644 --- a/tensorflow/python/framework/subscribe.py +++ b/tensorflow/python/framework/subscribe.py @@ -137,12 +137,7 @@ def _subscribe_new(tensor, side_effects, control_cache): # are subscribed at the same time, we remove the control dependency from # the original op only once and we add the dependencies to all the # new identities. - if ops._USE_C_API: # pylint: disable=protected-access - new_control_inputs = consumer_op.control_inputs - else: - # Make a copy so we don't modify the actual control inputs (this is fixed - # in the C API). 
- new_control_inputs = list(consumer_op.control_inputs) + new_control_inputs = consumer_op.control_inputs if tensor.op in new_control_inputs: new_control_inputs.remove(tensor.op) new_control_inputs.append(out.op) diff --git a/tensorflow/python/keras/backend_test.py b/tensorflow/python/keras/backend_test.py index ec0daeaddb..266af56611 100644 --- a/tensorflow/python/keras/backend_test.py +++ b/tensorflow/python/keras/backend_test.py @@ -1409,8 +1409,8 @@ class TestCTC(test.TestCase): np.array([seq_len_0], dtype=np.int32)) # batch_size length vector of negative log probabilities log_prob_truth = np.array([ - 0.584855, # output beam 0 - 0.389139 # output beam 1 + -3.5821197, # output beam 0 + -3.777835 # output beam 1 ], np.float32)[np.newaxis, :] decode_truth = [np.array([1, 0]), np.array([0, 1, 0])] diff --git a/tensorflow/python/keras/engine/sequential.py b/tensorflow/python/keras/engine/sequential.py index cf6fb44275..9f4019e29c 100644 --- a/tensorflow/python/keras/engine/sequential.py +++ b/tensorflow/python/keras/engine/sequential.py @@ -332,6 +332,7 @@ class Sequential(Model): else: name = None build_input_shape = None + layer_configs = config model = cls(name=name) for layer_config in layer_configs: layer = layer_module.deserialize(layer_config, diff --git a/tensorflow/python/keras/layers/recurrent.py b/tensorflow/python/keras/layers/recurrent.py index 65171acfb6..cff612a8de 100644 --- a/tensorflow/python/keras/layers/recurrent.py +++ b/tensorflow/python/keras/layers/recurrent.py @@ -73,19 +73,27 @@ class StackedRNNCells(Layer): '`state_size` attribute. ' 'received cells:', cells) self.cells = cells + # reverse_state_order determines whether the state size will be in a reverse + # order of the cells' state. User might want to set this to True to keep the + # existing behavior. This is only useful when use RNN(return_state=True) + # since the state will be returned as the same order of state_size. 
+ self.reverse_state_order = kwargs.pop('reverse_state_order', False) + if self.reverse_state_order: + logging.warning('reverse_state_order=True in StackedRNNCells will soon ' + 'be deprecated. Please update the code to work with the ' + 'natural order of states if you reply on the RNN states, ' + 'eg RNN(return_state=True).') super(StackedRNNCells, self).__init__(**kwargs) @property def state_size(self): - # States are a flat list - # in reverse order of the cell stack. - # This allows to preserve the requirement - # `stack.state_size[0] == output_dim`. - # e.g. states of a 2-layer LSTM would be - # `[h2, c2, h1, c1]` + # States are a flat list of the individual cell state size. + # e.g. states of a 2-layer LSTM would be `[h1, c1, h2, c2]`. # (assuming one LSTM has states [h, c]) + # In the case of reverse_state_order=True, the state_size will be + # [h2, c2, h1, c1]. state_size = [] - for cell in self.cells[::-1]: + for cell in self.cells[::-1] if self.reverse_state_order else self.cells: if _is_multiple_state(cell.state_size): state_size += list(cell.state_size) else: @@ -96,15 +104,16 @@ class StackedRNNCells(Layer): def output_size(self): if getattr(self.cells[-1], 'output_size', None) is not None: return self.cells[-1].output_size + elif _is_multiple_state(self.cells[-1].state_size): + return self.cells[-1].state_size[0] else: - return self.state_size[0] + return self.cells[-1].state_size def get_initial_state(self, inputs=None, batch_size=None, dtype=None): - # The init state is in reverse order of cell's initial state since the - # state_size is in reverse order. It is flattened into a list also because - # the state_size is a flattened list. + # The init state is flattened into a list because state_size is a flattened + # list. 
initial_states = [] - for cell in self.cells[::-1]: + for cell in self.cells[::-1] if self.reverse_state_order else self.cells: get_initial_state_fn = getattr(cell, 'get_initial_state', None) if get_initial_state_fn: initial_states.append(get_initial_state_fn( @@ -118,14 +127,15 @@ class StackedRNNCells(Layer): def call(self, inputs, states, constants=None, **kwargs): # Recover per-cell states. nested_states = [] - for cell in self.cells[::-1]: + for cell in self.cells[::-1] if self.reverse_state_order else self.cells: if _is_multiple_state(cell.state_size): nested_states.append(states[:len(cell.state_size)]) states = states[len(cell.state_size):] else: nested_states.append([states[0]]) states = states[1:] - nested_states = nested_states[::-1] + if self.reverse_state_order: + nested_states = nested_states[::-1] # Call the cells in order and store the returned states. new_nested_states = [] @@ -139,11 +149,12 @@ class StackedRNNCells(Layer): new_nested_states.append(states) # Format the new states as a flat list - # in reverse cell order. 
- states = [] - for cell_states in new_nested_states[::-1]: - states += cell_states - return inputs, states + new_states = [] + if self.reverse_state_order: + new_nested_states = new_nested_states[::-1] + for cell_states in new_nested_states: + new_states += cell_states + return inputs, new_states @tf_utils.shape_type_conversion def build(self, input_shape): @@ -156,7 +167,9 @@ class StackedRNNCells(Layer): cell.build([input_shape] + constants_shape) else: cell.build(input_shape) - if _is_multiple_state(cell.state_size): + if getattr(cell, 'output_size', None) is not None: + output_dim = cell.output_size + elif _is_multiple_state(cell.state_size): output_dim = cell.state_size[0] else: output_dim = cell.state_size diff --git a/tensorflow/python/keras/layers/recurrent_test.py b/tensorflow/python/keras/layers/recurrent_test.py index b52bfc05a5..a3861e44d5 100644 --- a/tensorflow/python/keras/layers/recurrent_test.py +++ b/tensorflow/python/keras/layers/recurrent_test.py @@ -103,7 +103,8 @@ class RNNTest(test.TestCase): MinimalRNNCell(16, 8), MinimalRNNCell(32, 16)] layer = keras.layers.RNN(cells) - assert layer.cell.state_size == (32, 32, 16, 16, 8, 8) + self.assertEqual(layer.cell.state_size, (8, 8, 16, 16, 32, 32)) + self.assertEqual(layer.cell.output_size, 32) y = layer(x) model = keras.models.Model(x, y) model.compile(optimizer='rmsprop', loss='mse') @@ -551,6 +552,21 @@ class RNNTest(test.TestCase): layer = keras.layers.RNN(cells, return_state=True, return_sequences=True) output_shape = layer.compute_output_shape((None, timesteps, embedding_dim)) expected_output_shape = [(None, timesteps, 6), + (None, 3), + (None, 3), + (None, 6), + (None, 6)] + self.assertEqual( + [tuple(o.as_list()) for o in output_shape], + expected_output_shape) + + # Test reverse_state_order = True for stacked cell. 
+ stacked_cell = keras.layers.StackedRNNCells( + cells, reverse_state_order=True) + layer = keras.layers.RNN( + stacked_cell, return_state=True, return_sequences=True) + output_shape = layer.compute_output_shape((None, timesteps, embedding_dim)) + expected_output_shape = [(None, timesteps, 6), (None, 6), (None, 6), (None, 3), diff --git a/tensorflow/python/keras/models.py b/tensorflow/python/keras/models.py index 6bc256d2ec..39b6042597 100644 --- a/tensorflow/python/keras/models.py +++ b/tensorflow/python/keras/models.py @@ -33,6 +33,7 @@ from tensorflow.python.keras.utils.generic_utils import CustomObjectScope from tensorflow.python.training import training_util from tensorflow.python.training.checkpointable import base as checkpointable from tensorflow.python.training.checkpointable import data_structures +from tensorflow.python.util.tf_export import tf_export # API entries importable from `keras.models`: Model = training.Model # pylint: disable=invalid-name @@ -226,6 +227,7 @@ def _clone_sequential_model(model, input_tensors=None): return Sequential(layers=[input_layer] + layers, name=model.name) +@tf_export('keras.models.clone_model') def clone_model(model, input_tensors=None): """Clone any `Model` instance. 
diff --git a/tensorflow/python/kernel_tests/cond_v2_test.py b/tensorflow/python/kernel_tests/cond_v2_test.py index b9910133d8..0dc3c53bc0 100644 --- a/tensorflow/python/kernel_tests/cond_v2_test.py +++ b/tensorflow/python/kernel_tests/cond_v2_test.py @@ -20,9 +20,9 @@ from __future__ import division from __future__ import print_function from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.eager import function from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes -from tensorflow.python.framework import function from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import cond_v2 @@ -158,7 +158,7 @@ class CondV2Test(test.TestCase): def true_fn(): - @function.Defun() + @function.defun def fn(): return x * y * 2.0 @@ -172,6 +172,8 @@ class CondV2Test(test.TestCase): self._testCond(true_fn, false_fn, [y]) def testNestedDefunInCond(self): + self.skipTest("b/110550782") + x = constant_op.constant(1.0, name="x") y = constant_op.constant(2.0, name="y") @@ -180,10 +182,10 @@ class CondV2Test(test.TestCase): def false_fn(): - @function.Defun() + @function.defun def fn(): - @function.Defun() + @function.defun def nested_fn(): return x * y * 2.0 @@ -196,18 +198,20 @@ class CondV2Test(test.TestCase): self._testCond(true_fn, false_fn, [y]) def testDoubleNestedDefunInCond(self): + self.skipTest("b/110550782") + x = constant_op.constant(1.0, name="x") y = constant_op.constant(2.0, name="y") def true_fn(): - @function.Defun() + @function.defun def fn(): - @function.Defun() + @function.defun def nested_fn(): - @function.Defun() + @function.defun def nested_nested_fn(): return x * y * 2.0 @@ -368,7 +372,7 @@ class CondV2Test(test.TestCase): pred_outer, true_fn, false_fn, name="outer_cond") # Compute grads inside a Defun. 
- @function.Defun() + @function.defun def nesting_fn(): return gradients_impl.gradients(cond_outer, [x, y]) @@ -426,10 +430,10 @@ class CondV2Test(test.TestCase): pred_outer, true_fn, false_fn, name="outer_cond") # Compute grads inside a Defun. - @function.Defun() + @function.defun def nesting_fn(): - @function.Defun() + @function.defun def inner_nesting_fn(): return gradients_impl.gradients(cond_outer, [x, y]) @@ -464,6 +468,7 @@ class CondV2Test(test.TestCase): }), [5., 0.]) def testBuildCondAndGradientInsideDefun(self): + self.skipTest("b/110550782") def build_graph(): pred_outer = array_ops.placeholder(dtypes.bool, name="pred_outer") @@ -472,7 +477,7 @@ class CondV2Test(test.TestCase): y = constant_op.constant(2.0, name="y") # Build cond and its gradient inside a Defun. - @function.Defun() + @function.defun def fn(): def true_fn(): @@ -718,6 +723,7 @@ class CondV2ContainerTest(test.TestCase): Make sure the containers are set correctly for both variable creation (tested by variables.Variable) and for stateful ops (tested by FIFOQueue) """ + self.skipTest("b/113048653") with ops.Graph().as_default() as g: with self.test_session(graph=g): @@ -795,6 +801,7 @@ class CondV2ContainerTest(test.TestCase): class CondV2ColocationGroupAndDeviceTest(test.TestCase): def testColocateWithBeforeCond(self): + self.skipTest("b/112414483") with ops.Graph().as_default() as g: with self.test_session(graph=g): @@ -819,6 +826,7 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase): self.assertEquals(cond_v2.cond_v2(True, fn2, fn2)[0].eval(), 3) def testColocateWithInAndOutOfCond(self): + self.skipTest("b/112414483") with ops.Graph().as_default() as g: with self.test_session(graph=g): @@ -866,6 +874,7 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase): self.assertTrue(len(run_metadata.partition_graphs) >= 2) def testDeviceBeforeCond(self): + self.skipTest("b/112166045") with ops.Graph().as_default() as g: with self.test_session(graph=g): def fn(): diff --git 
a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py index 5e0447e4ff..4a3e767f4d 100644 --- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py +++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py @@ -32,6 +32,7 @@ from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import device_lib from tensorflow.python.client import session from tensorflow.python.eager import context +from tensorflow.python.eager import function as _ # pylint: disable=unused-import from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl diff --git a/tensorflow/python/kernel_tests/ctc_decoder_ops_test.py b/tensorflow/python/kernel_tests/ctc_decoder_ops_test.py index e1920eb568..41ae0b456f 100644 --- a/tensorflow/python/kernel_tests/ctc_decoder_ops_test.py +++ b/tensorflow/python/kernel_tests/ctc_decoder_ops_test.py @@ -188,11 +188,11 @@ class CTCGreedyDecoderTest(test.TestCase): ], dtype=np.float32) # Add arbitrary offset - this is fine - input_log_prob_matrix_0 = np.log(input_prob_matrix_0) + 2.0 + input_prob_matrix_0 = input_prob_matrix_0 + 2.0 # len max_time_steps array of batch_size x depth matrices inputs = ([ - input_log_prob_matrix_0[t, :][np.newaxis, :] for t in range(seq_len_0) + input_prob_matrix_0[t, :][np.newaxis, :] for t in range(seq_len_0) ] # Pad to max_time_steps = 8 + 2 * [np.zeros( (1, depth), dtype=np.float32)]) @@ -200,11 +200,11 @@ class CTCGreedyDecoderTest(test.TestCase): # batch_size length vector of sequence_lengths seq_lens = np.array([seq_len_0], dtype=np.int32) - # batch_size length vector of negative log probabilities + # batch_size length vector of log probabilities log_prob_truth = np.array( [ - 0.584855, # output beam 0 - 0.389139 # output beam 1 + -5.811451, # output beam 0 + -6.63339 # output beam 1 ], np.float32)[np.newaxis, :] @@ -215,11 +215,11 @@ 
class CTCGreedyDecoderTest(test.TestCase): [[0, 0], [0, 1]], dtype=np.int64), np.array( [1, 0], dtype=np.int64), np.array( [1, 2], dtype=np.int64)), - # beam 1, batch 0, three outputs decoded + # beam 1, batch 0, one output decoded (np.array( - [[0, 0], [0, 1], [0, 2]], dtype=np.int64), np.array( - [0, 1, 0], dtype=np.int64), np.array( - [1, 3], dtype=np.int64)), + [[0, 0]], dtype=np.int64), np.array( + [1], dtype=np.int64), np.array( + [1, 1], dtype=np.int64)), ] # Test correct decoding. diff --git a/tensorflow/python/kernel_tests/matmul_op_test.py b/tensorflow/python/kernel_tests/matmul_op_test.py index 9eaafb4435..b167278984 100644 --- a/tensorflow/python/kernel_tests/matmul_op_test.py +++ b/tensorflow/python/kernel_tests/matmul_op_test.py @@ -142,7 +142,7 @@ class MatMulStatsTest(test_lib.TestCase): for op in g.get_operations(): flops = ops.get_stats_for_node_def(g, op.node_def, "flops").value if op.name == "MatMul": - self.assertEqual(6975, flops) + self.assertEqual(7200, flops) def testTransposedStatistics(self): g = ops.Graph() @@ -153,7 +153,7 @@ class MatMulStatsTest(test_lib.TestCase): for op in g.get_operations(): flops = ops.get_stats_for_node_def(g, op.node_def, "flops").value if op.name == "MatMul": - self.assertEqual(6975, flops) + self.assertEqual(7200, flops) try: diff --git a/tensorflow/python/kernel_tests/partitioned_variables_test.py b/tensorflow/python/kernel_tests/partitioned_variables_test.py index 1d0c2dceba..15d5702252 100644 --- a/tensorflow/python/kernel_tests/partitioned_variables_test.py +++ b/tensorflow/python/kernel_tests/partitioned_variables_test.py @@ -27,15 +27,12 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import init_ops -from tensorflow.python.ops import math_ops from tensorflow.python.ops import 
partitioned_variables from tensorflow.python.ops import random_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import test -from tensorflow.python.training import gradient_descent from tensorflow.python.training import saver as saver_lib @@ -549,6 +546,32 @@ class PartitionedVariablesTestCase(test.TestCase): partitioned_variables.create_partitioned_variables( [10, 43], [1, 50], rnd.initialized_value()) + def testControlDepsNone(self): + with self.test_session() as session: + c = constant_op.constant(1.0) + with ops.control_dependencies([c]): + # d get the control dependency. + d = constant_op.constant(2.0) + # Partitioned variables do not. + var_x = variable_scope.get_variable( + "x", + shape=[2], + initializer=init_ops.ones_initializer(), + partitioner=partitioned_variables.variable_axis_size_partitioner(4)) + + ops_before_read = session.graph.get_operations() + var_x.as_tensor() # Caches the ops for subsequent reads. + reading_ops = [ + op for op in session.graph.get_operations() + if op not in ops_before_read + ] + + self.assertEqual([c.op], d.op.control_inputs) + # Tests that no control dependencies are added to reading a partitioned + # variable which is similar to reading a variable. 
+ for op in reading_ops: + self.assertEqual([], op.control_inputs) + def testConcat(self): with self.test_session() as session: var_x = variable_scope.get_variable( @@ -574,57 +597,6 @@ class PartitionedVariablesTestCase(test.TestCase): variables.global_variables_initializer().run() self.assertAllClose(value.eval(), var_x.as_tensor().eval()) - def testVariableCreationInALoop(self): - """Tests the variable created inside a loop can be used outside the loop.""" - with self.test_session(): - with variable_scope.variable_scope("ascope") as scope: - def Body(i, _): - var_x = variable_scope.get_variable( - "x", - shape=[2], - initializer=init_ops.ones_initializer(), - partitioner=partitioned_variables.variable_axis_size_partitioner( - 4)) - return (i + 1, var_x.as_tensor()) - - cond = lambda i, _: i < 2 - _, x = control_flow_ops.while_loop( - cond, Body, (0, constant_op.constant([7, 8], dtypes.float32))) - variables.global_variables_initializer().run() - self.assertAllClose([1.0, 1.0], x.eval()) - - scope.reuse_variables() - var_x = variable_scope.get_variable( - "x", - shape=[2], - initializer=init_ops.ones_initializer(), - partitioner=partitioned_variables.variable_axis_size_partitioner(4)) - - self.assertAllClose([1.0, 1.0], var_x.as_tensor().eval()) - - def testReadInWhileLoop(self): - """Tests the value is current (not cached) when read within a loop.""" - with self.test_session(): - var_x = variable_scope.get_variable( - "x", - shape=[2], - initializer=init_ops.ones_initializer(), - partitioner=partitioned_variables.variable_axis_size_partitioner(4)) - - def Body(i, _): - # Use a SGD step to update the variable's value. 
- loss = math_ops.reduce_sum(var_x) - optimizer = gradient_descent.GradientDescentOptimizer(1.0) - minimize = optimizer.minimize(loss * 0.7) - with ops.control_dependencies([minimize]): - return (i + 1, var_x.as_tensor()) - - cond = lambda i, _: i < 2 - _, x = control_flow_ops.while_loop( - cond, Body, (0, constant_op.constant([7, 8], dtypes.float32))) - variables.global_variables_initializer().run() - self.assertAllClose([-0.4, -0.4], x.eval()) - def testMetaGraphSaveLoad(self): save_prefix = os.path.join(self.get_temp_dir(), "ckpt") save_graph = ops.Graph() diff --git a/tensorflow/python/kernel_tests/rnn_test.py b/tensorflow/python/kernel_tests/rnn_test.py index c4f200a22e..78f2993d27 100644 --- a/tensorflow/python/kernel_tests/rnn_test.py +++ b/tensorflow/python/kernel_tests/rnn_test.py @@ -441,11 +441,11 @@ class RNNTest(test.TestCase): cell, inputs, dtype=dtypes.float32) self.assertEqual(outputs.shape.as_list(), [None, timestep, output_shape]) self.assertEqual(len(state), 4) - self.assertEqual(state[0].shape.as_list(), [None, output_shape]) - self.assertEqual(state[1].shape.as_list(), [None, output_shape]) - self.assertEqual(state[2].shape.as_list(), [None, 2 * output_shape]) - self.assertEqual(state[3].shape.as_list(), [None, 2 * output_shape]) - loss = losses.softmax_cross_entropy(predict, state[0]) + self.assertEqual(state[0].shape.as_list(), [None, 2 * output_shape]) + self.assertEqual(state[1].shape.as_list(), [None, 2 * output_shape]) + self.assertEqual(state[2].shape.as_list(), [None, output_shape]) + self.assertEqual(state[3].shape.as_list(), [None, output_shape]) + loss = losses.softmax_cross_entropy(predict, state[2]) train_op = training.GradientDescentOptimizer(0.001).minimize(loss) sess.run([variables_lib.global_variables_initializer()]) diff --git a/tensorflow/python/ops/cond_v2.py b/tensorflow/python/ops/cond_v2.py index 76173e0f30..75a1a53eb7 100644 --- a/tensorflow/python/ops/cond_v2.py +++ b/tensorflow/python/ops/cond_v2.py @@ -24,7 +24,7 @@ 
from __future__ import division from __future__ import print_function # pylint: disable=unused-import -from tensorflow.python.framework import function +from tensorflow.python.eager import function from tensorflow.python.framework import function_def_to_graph from tensorflow.python.ops import gradients_impl diff --git a/tensorflow/python/ops/cond_v2_impl.py b/tensorflow/python/ops/cond_v2_impl.py index b3dacff6d6..c4e9c982b5 100644 --- a/tensorflow/python/ops/cond_v2_impl.py +++ b/tensorflow/python/ops/cond_v2_impl.py @@ -27,14 +27,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections + from tensorflow.core.framework import attr_value_pb2 -from tensorflow.python import pywrap_tensorflow as c_api -from tensorflow.python.framework import c_api_util from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_util from tensorflow.python.ops import gen_functional_ops -from tensorflow.python.util import compat # The following modules cannot be imported directly because they cause circular @@ -57,46 +56,27 @@ def cond_v2(pred, true_fn, false_fn, name="cond"): name = "cond" with ops.name_scope(name) as scope: - # Identify if there is a caller device, & get the innermost if possible. - # pylint: disable=protected-access - device_funcs = ops.get_default_graph()._device_functions_outer_to_inner - caller_device = device_funcs[-1] if device_funcs else None - - caller_colocation_stack = ops.get_default_graph()._colocation_stack - caller_container = ops.get_default_graph()._container - caller_collection_ref = ops.get_default_graph()._collections - with ops.name_scope(None): # Find the outer most graph for uniquing function names. # TODO(jpienaar): Make this work in eager mode. 
graph = ops.get_default_graph() - while isinstance(graph, _function._FuncGraph): - graph = graph._outer_graph + while isinstance(graph, _function.FuncGraph): + graph = graph.outer_graph true_name = graph.unique_name(("%strue" % scope).replace("/", "_")) false_name = graph.unique_name(("%sfalse" % scope).replace("/", "_")) - # pylint: enable=protected-access + true_graph = _function.func_graph_from_py_func( - true_fn, [], [], - name=true_name, - device=caller_device, - colocation_stack=caller_colocation_stack, - collections_ref=caller_collection_ref, - container=caller_container) + true_name, true_fn, [], {}) false_graph = _function.func_graph_from_py_func( - false_fn, [], [], - name=false_name, - device=caller_device, - colocation_stack=caller_colocation_stack, - collections_ref=caller_collection_ref, - container=caller_container) + false_name, false_fn, [], {}) _check_same_outputs(true_graph, false_graph) # Add inputs to true_graph and false_graph to make them match. Note that # this modifies true_graph and false_graph. cond_inputs = _make_inputs_match(true_graph, false_graph, - true_graph.extra_inputs, - false_graph.extra_inputs) + true_graph.external_captures, + false_graph.external_captures) # Add all intermediate tensors as function outputs so they're available for # the gradient computation. @@ -148,8 +128,8 @@ def _IfGrad(op, *grads): # pylint: disable=invalid-name true_graph, false_graph = _get_func_graphs(op) # Note: op.graph != ops.get_default_graph() when we are computing the gradient # of a nested cond. - assert true_graph._outer_graph == op.graph - assert false_graph._outer_graph == op.graph + assert true_graph.outer_graph == op.graph + assert false_graph.outer_graph == op.graph # Create grad functions that compute the gradient of the true/false forward # graphs. 
These functions will capture tensors from the forward pass @@ -164,14 +144,13 @@ def _IfGrad(op, *grads): # pylint: disable=invalid-name # Resolve references to forward graph tensors in grad graphs and ensure # they are in-scope, i.e., belong to one of outer graphs of the grad graph. - true_grad_extra_inputs = _resolve_grad_inputs(true_graph, true_grad_graph) - false_grad_extra_inputs = _resolve_grad_inputs(false_graph, false_grad_graph) + true_grad_inputs = _resolve_grad_inputs(true_graph, true_grad_graph) + false_grad_inputs = _resolve_grad_inputs(false_graph, false_grad_graph) # Make the inputs to true_grad_graph and false_grad_graph match. Note that # this modifies true_grad_graph and false_grad_graph. grad_inputs = _make_inputs_match(true_grad_graph, false_grad_graph, - true_grad_extra_inputs, - false_grad_extra_inputs) + true_grad_inputs, false_grad_inputs) # Add all intermediate tensors as function outputs so they're available for # higher-order gradient computations. @@ -211,8 +190,8 @@ def _get_func_graphs(if_op): """ def _get_func_graph_for_branch(branch_name): """Generates and returns a _FuncGraph for the given branch.""" - extra_inputs = if_op.inputs[1:] # First input is pred. - input_shapes = [t.shape for t in extra_inputs] + inputs = if_op.inputs[1:] # First input is pred. + input_shapes = [t.shape for t in inputs] func_name = if_op.get_attr(branch_name).name fdef = if_op.graph._get_function(func_name).definition # `if_op.graph` may not be the same as `ops.get_default_graph()` e.g. @@ -224,9 +203,8 @@ def _get_func_graphs(if_op): with if_op.graph.as_default(): func_graph = _function_def_to_graph.function_def_to_graph( fdef, input_shapes) - func_graph.extra_inputs = extra_inputs - func_graph.extra_args = func_graph.inputs - func_graph._captured = dict(zip(extra_inputs, func_graph.inputs)) + func_graph.captures = collections.OrderedDict(zip(inputs, + func_graph.inputs)) # Set the if op so that the gradient code can use it. 
func_graph._if = if_op return func_graph @@ -282,12 +260,12 @@ def _grad_fn(func_graph, grads): def _create_grad_func(func_graph, grads, name): """Returns the _FuncGraph representation of _grad_fn.""" - return _function.func_graph_from_py_func(lambda: _grad_fn(func_graph, grads), - [], [], name) + return _function.func_graph_from_py_func( + name, lambda: _grad_fn(func_graph, grads), [], {}) def _resolve_grad_inputs(cond_graph, grad_graph): - """Returns the tensors to pass as `extra_inputs` to `grad_graph`. + """Returns the tensors to pass as inputs to `grad_graph`. The `grad_graph` may have external references to 1. Its outer graph containing the input gradients. These references are kept @@ -305,10 +283,10 @@ def _resolve_grad_inputs(cond_graph, grad_graph): Returns: A list of inputs tensors to be passed to grad_graph. """ - new_extra_inputs = [] + new_inputs = [] - for t in grad_graph.extra_inputs: - if t.graph != grad_graph._outer_graph: + for t in grad_graph.external_captures: + if t.graph != grad_graph.outer_graph: # `t` is a tensor in `cond_graph` or one of its ancestors. We bubble this # tensor to the least common ancestor of the `cond_graph` and # `grad_graph` so that it is "in-scope" for `grad_graph`. @@ -316,19 +294,19 @@ def _resolve_grad_inputs(cond_graph, grad_graph): # common ancestor once and re-use. assert _is_ancestor(cond_graph, t.graph) while not _is_ancestor(grad_graph, t.graph): - assert isinstance(t.graph, _function._FuncGraph) - if t in t.graph.extra_args: - # TODO(srbs): Consider building a map of extra_args -> extra_inputs. - # instead of searching for `t` twice. - t = t.graph.extra_inputs[t.graph.extra_args.index(t)] + assert isinstance(t.graph, _function.FuncGraph) + if t in t.graph.internal_captures: + # TODO(srbs): Consider building a map of internal_captures -> + # external_captures instead of searching for `t` twice. 
+ t = t.graph.external_captures[t.graph.internal_captures.index(t)] else: # Note: All intermediate tensors are output by the If op. # TODO(srbs): .index() calls may be expensive. Optimize. t = t.graph._if.outputs[t.graph.outputs.index(t)] assert _is_ancestor(grad_graph, t.graph) - new_extra_inputs.append(t) + new_inputs.append(t) - return new_extra_inputs + return new_inputs def _create_new_tf_function(func_graph): @@ -340,26 +318,9 @@ def _create_new_tf_function(func_graph): Returns: The name of the new TF_Function. """ - c_func = c_api.TF_GraphToFunction_wrapper( - func_graph._c_graph, - compat.as_str(func_graph.name), - False, # append_hash_to_fn_name - None, # opers - [t._as_tf_output() for t in func_graph.inputs], - [t._as_tf_output() for t in func_graph.outputs], - [], - None, # opts - None) # description - _ = c_api_util.ScopedTFFunction(c_func) - - # TODO(b/109833212): this sucks, we're serializing the TF_Function*, - # deserializing it into a Python FunctionDef, then reserializing it to create - # a new TF_Function that we add to the graph. - fdef = _function.function_def_from_tf_function(c_func) - defined_func = _function._from_definition(fdef) - defined_func._sub_functions = func_graph._functions - defined_func.add_to_graph(func_graph._outer_graph) - + func = _function._EagerDefinedFunction( + func_graph.name, func_graph, func_graph.inputs, func_graph.outputs, {}) + func.add_to_graph(func_graph.outer_graph) return func_graph.name @@ -421,21 +382,20 @@ def _pad_params(true_graph, false_graph, true_params, false_params): return new_true_params, new_false_inputs -def _make_inputs_match(true_graph, false_graph, true_extra_inputs, - false_extra_inputs): +def _make_inputs_match(true_graph, false_graph, true_inputs, false_inputs): """Modifies true_graph and false_graph so they have the same input signature. 
This method reorders and/or adds parameters to true_graph and false_graph so - they have the same input signature, and updates the 'inputs', 'extra_inputs', - and '_captured' fields of both graphs accordingly. It uses the input tensors - from the outer graph to avoid duplicating shared arguments. + they have the same input signature, and updates the 'inputs' and 'captured' + fields of both graphs accordingly. It uses the input tensors from the outer + graph to avoid duplicating shared arguments. Args: true_graph: function._FuncGraph false_graph: function._FuncGraph - true_extra_inputs: a list of Tensors in the outer graph. The inputs for + true_inputs: a list of Tensors in the outer graph. The inputs for true_graph. - false_extra_inputs: a list of Tensors in the outer graph. The inputs for + false_inputs: a list of Tensors in the outer graph. The inputs for false_graph. Returns: @@ -444,12 +404,12 @@ def _make_inputs_match(true_graph, false_graph, true_extra_inputs, false_inputs. """ shared_inputs, true_only_inputs, false_only_inputs = _separate_unique_inputs( - true_extra_inputs, false_extra_inputs) + true_inputs, false_inputs) new_inputs = shared_inputs + true_only_inputs + false_only_inputs - true_input_to_param = dict(zip(true_extra_inputs, true_graph.inputs)) - false_input_to_param = dict(zip(false_extra_inputs, false_graph.inputs)) + true_input_to_param = dict(zip(true_inputs, true_graph.inputs)) + false_input_to_param = dict(zip(false_inputs, false_graph.inputs)) true_graph.inputs = ( [true_input_to_param[t] for t in shared_inputs] + @@ -462,14 +422,10 @@ def _make_inputs_match(true_graph, false_graph, true_extra_inputs, [false_input_to_param[t] for t in false_only_inputs]) # Rewrite the _FuncGraphs' state to reflect the new inputs. 
- true_graph.extra_inputs = new_inputs - false_graph.extra_inputs = new_inputs - - true_graph.extra_args = true_graph.inputs - false_graph.extra_args = false_graph.inputs - - true_graph._captured = dict(zip(new_inputs, true_graph.inputs)) - false_graph._captured = dict(zip(new_inputs, false_graph.inputs)) + true_graph.captures = collections.OrderedDict(zip(new_inputs, + true_graph.inputs)) + false_graph.captures = collections.OrderedDict(zip(new_inputs, + false_graph.inputs)) return new_inputs @@ -506,10 +462,10 @@ def _get_grad_fn_name(func_graph): counter = 1 has_conflict = True while has_conflict: - curr_graph = func_graph._outer_graph + curr_graph = func_graph.outer_graph has_conflict = curr_graph._is_function(name) - while not has_conflict and isinstance(curr_graph, _function._FuncGraph): - curr_graph = curr_graph._outer_graph + while not has_conflict and isinstance(curr_graph, _function.FuncGraph): + curr_graph = curr_graph.outer_graph has_conflict = curr_graph._is_function(name) if has_conflict: name = "%s_%s" % (base_name, counter) @@ -534,6 +490,6 @@ def _check_same_outputs(true_graph, false_graph): def _is_ancestor(graph, maybe_ancestor): if maybe_ancestor == graph: return True - if isinstance(graph, _function._FuncGraph): - return _is_ancestor(graph._outer_graph, maybe_ancestor) + if isinstance(graph, _function.FuncGraph): + return _is_ancestor(graph.outer_graph, maybe_ancestor) return False diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py index d1095c8954..e3c1aa3d5a 100644 --- a/tensorflow/python/ops/control_flow_ops.py +++ b/tensorflow/python/ops/control_flow_ops.py @@ -1966,8 +1966,12 @@ def cond(pred, `true_fn` and `false_fn` both return lists of output tensors. `true_fn` and `false_fn` must have the same non-zero number and type of outputs. - Note that the conditional execution applies only to the operations defined in - `true_fn` and `false_fn`. 
Consider the following simple program: + **WARNING**: Any Tensors or Operations created outside of `true_fn` and + `false_fn` will be executed regardless of which branch is selected at runtime. + + Although this behavior is consistent with the dataflow model of TensorFlow, + it has frequently surprised users who expected a lazier semantics. + Consider the following simple program: ```python z = tf.multiply(a, b) @@ -1978,8 +1982,6 @@ def cond(pred, operation will not be executed. Since `z` is needed for at least one branch of the `cond`, the `tf.multiply` operation is always executed, unconditionally. - Although this behavior is consistent with the dataflow model of TensorFlow, - it has occasionally surprised some users who expected a lazier semantics. Note that `cond` calls `true_fn` and `false_fn` *exactly once* (inside the call to `cond`, and not at all during `Session.run()`). `cond` diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index 70b5e9b4b7..9b0ab00c7a 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -618,7 +618,7 @@ def cast(x, dtype, name=None): """Casts a tensor to a new type. The operation casts `x` (in case of `Tensor`) or `x.values` - (in case of `SparseTensor`) to `dtype`. + (in case of `SparseTensor` or `IndexedSlices`) to `dtype`. For example: @@ -637,15 +637,16 @@ def cast(x, dtype, name=None): behavior of numpy. Args: - x: A `Tensor` or `SparseTensor` of numeric type. It could be - `uint8`, `uint16`, `uint32`, `uint64`, `int8`, `int16`, `int32`, `int64`, - `float16`, `float32`, `float64`, `complex64`, `complex128`, `bfloat16`. - dtype: The destination type. The list of supported dtypes is the same - as `x`. + x: A `Tensor` or `SparseTensor` or `IndexedSlices` of numeric type. It could + be `uint8`, `uint16`, `uint32`, `uint64`, `int8`, `int16`, `int32`, + `int64`, `float16`, `float32`, `float64`, `complex64`, `complex128`, + `bfloat16`. + dtype: The destination type. 
The list of supported dtypes is the same as + `x`. name: A name for the operation (optional). Returns: - A `Tensor` or `SparseTensor` with same shape as `x` and + A `Tensor` or `SparseTensor` or `IndexedSlices` with same shape as `x` and same type as `dtype`. Raises: @@ -659,6 +660,9 @@ def cast(x, dtype, name=None): if isinstance(x, sparse_tensor.SparseTensor): values_cast = cast(x.values, base_type, name=name) x = sparse_tensor.SparseTensor(x.indices, values_cast, x.dense_shape) + elif isinstance(x, ops.IndexedSlices): + values_cast = cast(x.values, base_type, name=name) + x = ops.IndexedSlices(values_cast, x.indices, x.dense_shape) else: # TODO(josh11b): If x is not already a Tensor, we could return # ops.convert_to_tensor(x, dtype=dtype, ...) here, but that @@ -711,11 +715,12 @@ def to_float(x, name="ToFloat"): """Casts a tensor to type `float32`. Args: - x: A `Tensor` or `SparseTensor`. + x: A `Tensor` or `SparseTensor` or `IndexedSlices`. name: A name for the operation (optional). Returns: - A `Tensor` or `SparseTensor` with same shape as `x` with type `float32`. + A `Tensor` or `SparseTensor` or `IndexedSlices` with same shape as `x` with + type `float32`. Raises: TypeError: If `x` cannot be cast to the `float32`. @@ -728,11 +733,12 @@ def to_double(x, name="ToDouble"): """Casts a tensor to type `float64`. Args: - x: A `Tensor` or `SparseTensor`. + x: A `Tensor` or `SparseTensor` or `IndexedSlices`. name: A name for the operation (optional). Returns: - A `Tensor` or `SparseTensor` with same shape as `x` with type `float64`. + A `Tensor` or `SparseTensor` or `IndexedSlices` with same shape as `x` with + type `float64`. Raises: TypeError: If `x` cannot be cast to the `float64`. @@ -745,11 +751,12 @@ def to_int32(x, name="ToInt32"): """Casts a tensor to type `int32`. Args: - x: A `Tensor` or `SparseTensor`. + x: A `Tensor` or `SparseTensor` or `IndexedSlices`. name: A name for the operation (optional). 
Returns: - A `Tensor` or `SparseTensor` with same shape as `x` with type `int32`. + A `Tensor` or `SparseTensor` or `IndexedSlices` with same shape as `x` with + type `int32`. Raises: TypeError: If `x` cannot be cast to the `int32`. @@ -762,11 +769,12 @@ def to_int64(x, name="ToInt64"): """Casts a tensor to type `int64`. Args: - x: A `Tensor` or `SparseTensor`. + x: A `Tensor` or `SparseTensor` or `IndexedSlices`. name: A name for the operation (optional). Returns: - A `Tensor` or `SparseTensor` with same shape as `x` with type `int64`. + A `Tensor` or `SparseTensor` or `IndexedSlices` with same shape as `x` with + type `int64`. Raises: TypeError: If `x` cannot be cast to the `int64`. @@ -779,11 +787,12 @@ def to_bfloat16(x, name="ToBFloat16"): """Casts a tensor to type `bfloat16`. Args: - x: A `Tensor` or `SparseTensor`. + x: A `Tensor` or `SparseTensor` or `IndexedSlices`. name: A name for the operation (optional). Returns: - A `Tensor` or `SparseTensor` with same shape as `x` with type `bfloat16`. + A `Tensor` or `SparseTensor` or `IndexedSlices` with same shape as `x` with + type `bfloat16`. Raises: TypeError: If `x` cannot be cast to the `bfloat16`. @@ -796,11 +805,12 @@ def to_complex64(x, name="ToComplex64"): """Casts a tensor to type `complex64`. Args: - x: A `Tensor` or `SparseTensor`. + x: A `Tensor` or `SparseTensor` or `IndexedSlices`. name: A name for the operation (optional). Returns: - A `Tensor` or `SparseTensor` with same shape as `x` with type `complex64`. + A `Tensor` or `SparseTensor` or `IndexedSlices` with same shape as `x` with + type `complex64`. Raises: TypeError: If `x` cannot be cast to the `complex64`. @@ -813,11 +823,12 @@ def to_complex128(x, name="ToComplex128"): """Casts a tensor to type `complex128`. Args: - x: A `Tensor` or `SparseTensor`. + x: A `Tensor` or `SparseTensor` or `IndexedSlices`. name: A name for the operation (optional). Returns: - A `Tensor` or `SparseTensor` with same shape as `x` with type `complex128`. 
+ A `Tensor` or `SparseTensor` or `IndexedSlices` with same shape as `x` with + type `complex128`. Raises: TypeError: If `x` cannot be cast to the `complex128`. @@ -2061,7 +2072,7 @@ def _calc_mat_mul_flops(graph, node): output_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name) output_shape.assert_is_fully_defined() output_count = np.prod(output_shape.as_list()) - return ops.OpStats("flops", ((2 * k - 1) * output_count)) + return ops.OpStats("flops", (k * output_count * 2)) def _as_indexed_slices(x, optimize=True): diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py index baba5d4093..4800352ac2 100644 --- a/tensorflow/python/ops/resource_variable_ops.py +++ b/tensorflow/python/ops/resource_variable_ops.py @@ -355,6 +355,15 @@ class ResourceVariable(variables.RefVariable): raise ValueError("initial_value must be specified.") init_from_fn = callable(initial_value) + if isinstance(initial_value, ops.Tensor) and hasattr( + initial_value, "graph") and initial_value.graph.building_function: + raise ValueError("Tensor-typed variable initializers must either be " + "wrapped in an init_scope or callable " + "(e.g., `tf.Variable(lambda : " + "tf.truncated_normal([10, 40]))`) when building " + "functions. Please file a feature request if this " + "restriction inconveniences you.") + if collections is None: collections = [ops.GraphKeys.GLOBAL_VARIABLES] if not isinstance(collections, (list, tuple, set)): diff --git a/tensorflow/python/ops/sparse_ops.py b/tensorflow/python/ops/sparse_ops.py index d990386b9a..38ce5236e3 100644 --- a/tensorflow/python/ops/sparse_ops.py +++ b/tensorflow/python/ops/sparse_ops.py @@ -96,6 +96,60 @@ def _make_int64_tensor(value, name): return math_ops.cast(value, dtypes.int64) +@tf_export("sparse.expand_dims") +def sparse_expand_dims(sp_input, axis=None, name=None): + """Inserts a dimension of 1 into a tensor's shape. 
+ + Given a tensor `sp_input`, this operation inserts a dimension of 1 at the + dimension index `axis` of `sp_input`'s shape. The dimension index `axis` + starts at zero; if you specify a negative number for `axis` it is counted + backwards from the end. + + Args: + sp_input: A `SparseTensor`. + axis: 0-D (scalar). Specifies the dimension index at which to expand the + shape of `input`. Must be in the range `[-rank(sp_input) - 1, + rank(sp_input)]`. + name: The name of the output `SparseTensor`. + + Returns: + A `SparseTensor` with the same data as `sp_input`, but its shape has an + additional dimension of size 1 added. + """ + rank = sp_input.dense_shape.get_shape()[0] + axis = -1 if axis is None else axis + + with ops.name_scope(name, default_name="expand_dims", values=[sp_input]): + if isinstance(axis, compat.integral_types): + axis = ops.convert_to_tensor(axis, name="axis", dtype=dtypes.int32) + elif not isinstance(axis, ops.Tensor): + raise TypeError("axis must be an integer value in range [-rank(sp_input)" + " - 1, rank(sp_input)]") + + # Convert axis to a positive value if it is negative. + axis = array_ops.where(axis >= 0, axis, axis + rank + 1) + + # Create the new column of indices for the sparse tensor by slicing + # the indices and inserting a new column of indices for the new dimension. + column_size = array_ops.shape(sp_input.indices)[0] + new_index = array_ops.zeros([column_size, 1], dtype=dtypes.int64) + indices_before = array_ops.slice(sp_input.indices, [0, 0], [-1, axis]) + indices_after = array_ops.slice(sp_input.indices, [0, axis], [-1, -1]) + indices = array_ops.concat( + [indices_before, new_index, indices_after], axis=1) + + # Create the new dense shape by splicing the tensor [1] in the correct + # dimension of the existing shape. 
+ shape_before = array_ops.slice(sp_input.dense_shape, [0], [axis]) + shape_after = array_ops.slice(sp_input.dense_shape, [axis], [-1]) + new_shape = ops.convert_to_tensor([1], name="new_shape", dtype=dtypes.int64) + shape = array_ops.concat([shape_before, new_shape, shape_after], axis=0) + + # Create the output sparse tensor. + return sparse_tensor.SparseTensor( + indices=indices, values=sp_input.values, dense_shape=shape) + + @tf_export("sparse.eye") def sparse_eye(num_rows, num_columns=None, diff --git a/tensorflow/python/ops/sparse_ops_test.py b/tensorflow/python/ops/sparse_ops_test.py index b10c3c2187..4ee1569249 100644 --- a/tensorflow/python/ops/sparse_ops_test.py +++ b/tensorflow/python/ops/sparse_ops_test.py @@ -21,6 +21,8 @@ from __future__ import print_function import numpy as np from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import test_util from tensorflow.python.ops import sparse_ops from tensorflow.python.platform import googletest @@ -45,5 +47,35 @@ class SparseOpsTest(test_util.TensorFlowTestCase): test_one(n, m, True) test_one(n, m, False) + def testSparseExpandDims(self): + for rank in range(1, 4): + # Create a dummy input. When rank=3, shape=[2, 4, 6]. + shape = np.arange(1, rank + 1) * 2 + before = np.arange(np.prod(shape)).reshape(shape) + + # Make entries sparse. + before *= np.random.binomial(1, .2, before.shape) + dense_shape = before.shape + indices = np.array(np.where(before)).T + values = before[before != 0] + + # Try every possible valid value of axis. 
+ for axis in range(-rank - 1, rank): + expected_after = np.expand_dims(before, axis) + + for axis_as_tensor in [False, True]: + dense_shape_t = constant_op.constant(dense_shape, dtype=dtypes.int64) + indices_t = constant_op.constant(indices) + values_t = constant_op.constant(values) + before_t = sparse_tensor.SparseTensor( + indices=indices_t, values=values_t, dense_shape=dense_shape_t) + + if axis_as_tensor: + axis = constant_op.constant(axis) + + s = sparse_ops.sparse_expand_dims(before_t, axis) + d = sparse_ops.sparse_to_dense(s.indices, s.dense_shape, s.values) + self.assertAllEqual(self.evaluate(d), expected_after) + if __name__ == '__main__': googletest.main() diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py index 571265665b..f7da3f7d64 100644 --- a/tensorflow/python/ops/variables.py +++ b/tensorflow/python/ops/variables.py @@ -2336,10 +2336,15 @@ class PartitionedVariable(object): def as_tensor(self): """Returns the overall concatenated value as a `Tensor`. + The returned tensor will not inherit the control dependencies from the scope + where the value is used, which is similar to getting the value of + `Variable`. + Returns: `Tensor` containing the concatenated value. """ - return self._concat() + with ops.control_dependencies(None): + return self._concat() @staticmethod def _TensorConversionFunction(v, dtype=None, name=None, as_ref=False): diff --git a/tensorflow/python/training/checkpoint_management.py b/tensorflow/python/training/checkpoint_management.py index 85f2904318..b7aa8264b0 100644 --- a/tensorflow/python/training/checkpoint_management.py +++ b/tensorflow/python/training/checkpoint_management.py @@ -510,7 +510,10 @@ class CheckpointManager(object): max_to_keep: An integer, the number of checkpoints to keep. Unless preserved by `keep_checkpoint_every_n_hours`, checkpoints will be deleted from the active set, oldest first, until only `max_to_keep` - checkpoints remain. + checkpoints remain. 
If `None`, no checkpoints are deleted and everything + stays in the active set. Note that `max_to_keep=None` will keep all + checkpoint paths in memory and in the checkpoint state protocol buffer + on disk. keep_checkpoint_every_n_hours: Upon removal from the active set, a checkpoint will be preserved if it has been at least `keep_checkpoint_every_n_hours` since the last preserved checkpoint. The @@ -521,9 +524,10 @@ class CheckpointManager(object): """ self._checkpoint = checkpoint self._save_counter_assign = None - if not max_to_keep or max_to_keep < 0: + if max_to_keep is not None and max_to_keep <= 0: raise ValueError( - "Expected a positive integer for `max_to_max_to_keep`, got %d." + ("Expected a positive integer or `None` for `max_to_max_to_keep`, " + "got %d.") % (max_to_keep,)) self._max_to_keep = max_to_keep self._keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours @@ -586,6 +590,10 @@ class CheckpointManager(object): def _sweep(self): """Deletes or preserves managed checkpoints.""" + if not self._max_to_keep: + # Does not update self._last_preserved_timestamp, since everything is kept + # in the active set. 
+ return while len(self._maybe_delete) > self._max_to_keep: filename, timestamp = self._maybe_delete.popitem(last=False) # Even if we're keeping this checkpoint due to diff --git a/tensorflow/python/training/checkpoint_management_test.py b/tensorflow/python/training/checkpoint_management_test.py index 22c2cc678a..d7162265e6 100644 --- a/tensorflow/python/training/checkpoint_management_test.py +++ b/tensorflow/python/training/checkpoint_management_test.py @@ -26,6 +26,7 @@ import tempfile from google.protobuf import text_format from tensorflow.core.protobuf import saver_pb2 +from tensorflow.python.eager import context from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops as ops_lib from tensorflow.python.framework import test_util @@ -333,6 +334,49 @@ class CheckpointManagerTest(test.TestCase): self.assertFalse(checkpoint_management.checkpoint_exists(first_path)) @test_util.run_in_graph_and_eager_modes + def testKeepAll(self): + checkpoint = util.Checkpoint() + directory = os.path.join( + self.get_temp_dir(), + # Avoid sharing directories between eager and graph + # TODO(allenl): stop run_in_graph_and_eager_modes reusing directories + str(context.executing_eagerly())) + manager = checkpoint_management.CheckpointManager( + checkpoint, directory, max_to_keep=None) + first_path = manager.save() + second_path = manager.save() + third_path = manager.save() + self.assertTrue(checkpoint_management.checkpoint_exists(third_path)) + self.assertTrue(checkpoint_management.checkpoint_exists(second_path)) + self.assertTrue(checkpoint_management.checkpoint_exists(first_path)) + self.assertEqual(third_path, manager.latest_checkpoint) + self.assertEqual([first_path, second_path, third_path], + manager.checkpoints) + del manager + manager = checkpoint_management.CheckpointManager( + checkpoint, directory, max_to_keep=None) + fourth_path = manager.save() + self.assertEqual([first_path, second_path, third_path, fourth_path], + manager.checkpoints) + 
del manager + manager = checkpoint_management.CheckpointManager( + checkpoint, directory, max_to_keep=3) + self.assertEqual([first_path, second_path, third_path, fourth_path], + manager.checkpoints) + self.assertTrue(checkpoint_management.checkpoint_exists(fourth_path)) + self.assertTrue(checkpoint_management.checkpoint_exists(third_path)) + self.assertTrue(checkpoint_management.checkpoint_exists(second_path)) + self.assertTrue(checkpoint_management.checkpoint_exists(first_path)) + fifth_path = manager.save() + self.assertEqual([third_path, fourth_path, fifth_path], + manager.checkpoints) + self.assertTrue(checkpoint_management.checkpoint_exists(fifth_path)) + self.assertTrue(checkpoint_management.checkpoint_exists(fourth_path)) + self.assertTrue(checkpoint_management.checkpoint_exists(third_path)) + self.assertFalse(checkpoint_management.checkpoint_exists(second_path)) + self.assertFalse(checkpoint_management.checkpoint_exists(first_path)) + + @test_util.run_in_graph_and_eager_modes @test.mock.patch.object(checkpoint_management, "time") def testSaveRestoreState(self, mock_time): directory = self.get_temp_dir() diff --git a/tensorflow/python/training/checkpointable/util_test.py b/tensorflow/python/training/checkpointable/util_test.py index ecadc56871..697b44c3ff 100644 --- a/tensorflow/python/training/checkpointable/util_test.py +++ b/tensorflow/python/training/checkpointable/util_test.py @@ -384,8 +384,8 @@ class CheckpointingTests(test.TestCase): saver = saver_lib.Saver(var_list=[v]) test_dir = self.get_temp_dir() prefix = os.path.join(test_dir, "ckpt") - self.evaluate(v.non_dep_variable.assign(42.)) with self.test_session() as sess: + self.evaluate(v.non_dep_variable.assign(42.)) save_path = saver.save(sess, prefix) self.evaluate(v.non_dep_variable.assign(43.)) self.evaluate(v.mirrored.assign(44.)) diff --git a/tensorflow/python/training/moving_averages.py b/tensorflow/python/training/moving_averages.py index 4b91d1e963..177a7ddfa5 100644 --- 
a/tensorflow/python/training/moving_averages.py +++ b/tensorflow/python/training/moving_averages.py @@ -363,10 +363,12 @@ class ExponentialMovingAverage(object): `GraphKeys.ALL_VARIABLES` collection. They will be returned by calls to `tf.global_variables()`. - Returns an op that updates all shadow variables as described above. + Returns an op that updates all shadow variables from the current value of + their associated variables. - Note that `apply()` can be called multiple times with different lists of - variables. + Note that `apply()` can be called multiple times. When eager execution is + enabled each call to apply will update the variables once, so this needs to + be called in a loop. Args: var_list: A list of Variable or Tensor objects. The variables @@ -389,31 +391,30 @@ class ExponentialMovingAverage(object): dtypes.float64]: raise TypeError("The variables must be half, float, or double: %s" % var.name) - if var in self._averages: - raise ValueError("Moving average already computed for: %s" % var.name) - # For variables: to lower communication bandwidth across devices we keep - # the moving averages on the same device as the variables. For other - # tensors, we rely on the existing device allocation mechanism. - with ops.init_scope(): - if isinstance(var, variables.Variable): - avg = slot_creator.create_slot(var, - var.initialized_value(), - self.name, - colocate_with_primary=True) - # NOTE(mrry): We only add `tf.Variable` objects to the - # `MOVING_AVERAGE_VARIABLES` collection. - ops.add_to_collection(ops.GraphKeys.MOVING_AVERAGE_VARIABLES, var) - else: - avg = slot_creator.create_zeros_slot( - var, - self.name, - colocate_with_primary=(var.op.type in ["Variable", - "VariableV2", - "VarHandleOp"])) - if self._zero_debias: - zero_debias_true.add(avg) - self._averages[var] = avg + if var not in self._averages: + # For variables: to lower communication bandwidth across devices we keep + # the moving averages on the same device as the variables. 
For other + # tensors, we rely on the existing device allocation mechanism. + with ops.init_scope(): + if isinstance(var, variables.Variable): + avg = slot_creator.create_slot(var, + var.initialized_value(), + self.name, + colocate_with_primary=True) + # NOTE(mrry): We only add `tf.Variable` objects to the + # `MOVING_AVERAGE_VARIABLES` collection. + ops.add_to_collection(ops.GraphKeys.MOVING_AVERAGE_VARIABLES, var) + else: + avg = slot_creator.create_zeros_slot( + var, + self.name, + colocate_with_primary=(var.op.type in ["Variable", + "VariableV2", + "VarHandleOp"])) + if self._zero_debias: + zero_debias_true.add(avg) + self._averages[var] = avg with ops.name_scope(self.name) as scope: decay = ops.convert_to_tensor(self._decay, name="decay") diff --git a/tensorflow/python/training/moving_averages_test.py b/tensorflow/python/training/moving_averages_test.py index 3e85e6bfa7..fdb8d795c3 100644 --- a/tensorflow/python/training/moving_averages_test.py +++ b/tensorflow/python/training/moving_averages_test.py @@ -18,9 +18,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_state_ops from tensorflow.python.ops import variable_scope @@ -254,6 +256,25 @@ class ExponentialMovingAverageTest(test.TestCase): self.assertEqual(1, sess.run(v0)) self.assertEqual([17.5], sess.run(v1_avg)) + @test_util.run_in_graph_and_eager_modes + def testBasicEager(self): + v0 = variables.Variable(1.0) + v1 = variables.Variable(2.0) + + ema = moving_averages.ExponentialMovingAverage(0.25) + op = ema.apply([v0, v1]) + if not context.executing_eagerly(): + self.evaluate(variables.global_variables_initializer()) + 
self.evaluate(op) + + self.evaluate(v0.assign(2.0)) + self.evaluate(v1.assign(4.0)) + + self.evaluate(ema.apply([v0, v1])) + + self.assertAllEqual(self.evaluate(ema.average(v0)), 1.75) + self.assertAllEqual(self.evaluate(ema.average(v1)), 3.5) + def averageVariablesNamesHelper(self, zero_debias): with self.test_session(): v0 = variables.Variable(10.0, name="v0") diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py index b46095d458..f5b2a22327 100644 --- a/tensorflow/python/training/saver_test.py +++ b/tensorflow/python/training/saver_test.py @@ -2853,8 +2853,8 @@ class CheckpointableCompatibilityTests(test.TestCase): saver = saver_module.Saver(var_list=[v]) test_dir = self.get_temp_dir() prefix = os.path.join(test_dir, "ckpt") - self.evaluate(v.non_dep_variable.assign(42.)) with self.test_session() as sess: + self.evaluate(v.non_dep_variable.assign(42.)) save_path = saver.save(sess, prefix) self.evaluate(v.non_dep_variable.assign(43.)) saver.restore(sess, save_path) diff --git a/tensorflow/python/util/tf_export.py b/tensorflow/python/util/tf_export.py index 2be4dbb283..a5ac430ce7 100644 --- a/tensorflow/python/util/tf_export.py +++ b/tensorflow/python/util/tf_export.py @@ -136,11 +136,14 @@ class api_export(object): # pylint: disable=invalid-name has no effect on exporting a constant. api_name: Name of the API you want to generate (e.g. `tensorflow` or `estimator`). Default is `tensorflow`. + allow_multiple_exports: Allow symbol to be exported multiple time under + different names. """ self._names = args self._names_v1 = kwargs.get('v1', args) self._api_name = kwargs.get('api_name', TENSORFLOW_API_NAME) self._overrides = kwargs.get('overrides', []) + self._allow_multiple_exports = kwargs.get('allow_multiple_exports', False) def __call__(self, func): """Calls this decorator. 
@@ -173,9 +176,10 @@ class api_export(object): # pylint: disable=invalid-name # __dict__ instead of using hasattr to verify that subclasses have # their own _tf_api_names as opposed to just inheriting it. if api_names_attr in func.__dict__: - raise SymbolAlreadyExposedError( - 'Symbol %s is already exposed as %s.' % - (func.__name__, getattr(func, api_names_attr))) # pylint: disable=protected-access + if not self._allow_multiple_exports: + raise SymbolAlreadyExposedError( + 'Symbol %s is already exposed as %s.' % + (func.__name__, getattr(func, api_names_attr))) # pylint: disable=protected-access setattr(func, api_names_attr, names) def export_constant(self, module_name, name): @@ -213,4 +217,5 @@ class api_export(object): # pylint: disable=invalid-name tf_export = functools.partial(api_export, api_name=TENSORFLOW_API_NAME) -estimator_export = functools.partial(api_export, api_name=ESTIMATOR_API_NAME) +estimator_export = functools.partial( + api_export, api_name=ESTIMATOR_API_NAME, allow_multiple_exports=True) diff --git a/tensorflow/stream_executor/lib/env.h b/tensorflow/stream_executor/lib/env.h index 3ef8deb72e..d78bbfd425 100644 --- a/tensorflow/stream_executor/lib/env.h +++ b/tensorflow/stream_executor/lib/env.h @@ -32,7 +32,7 @@ inline Status FileExists(const string& filename) { } inline Status FileExists(const port::StringPiece& filename) { - return Env::Default()->FileExists(std::string(filename)); + return Env::Default()->FileExists(string(filename)); } } // namespace port diff --git a/tensorflow/stream_executor/lib/path.cc b/tensorflow/stream_executor/lib/path.cc index 58a862206c..3d3da103e1 100644 --- a/tensorflow/stream_executor/lib/path.cc +++ b/tensorflow/stream_executor/lib/path.cc @@ -33,7 +33,7 @@ string JoinPathImpl(std::initializer_list<port::StringPiece> paths) { if (path.empty()) continue; if (result.empty()) { - result = std::string(path); + result = string(path); continue; } diff --git a/tensorflow/stream_executor/lib/str_util.h 
b/tensorflow/stream_executor/lib/str_util.h index b02fe4f56f..e77dfcef76 100644 --- a/tensorflow/stream_executor/lib/str_util.h +++ b/tensorflow/stream_executor/lib/str_util.h @@ -31,7 +31,7 @@ inline string StripSuffixString(port::StringPiece str, port::StringPiece suffix) if (tensorflow::str_util::EndsWith(str, suffix)) { str.remove_suffix(suffix.size()); } - return std::string(str); + return string(str); } using tensorflow::str_util::Lowercase; diff --git a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-run-config.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-run-config.pbtxt index bf1f94b6ae..269e18a0a7 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-run-config.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-run-config.pbtxt @@ -96,7 +96,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'model_dir\', \'tf_random_seed\', \'save_summary_steps\', \'save_checkpoints_steps\', \'save_checkpoints_secs\', \'session_config\', \'keep_checkpoint_max\', \'keep_checkpoint_every_n_hours\', \'log_step_count_steps\', \'train_distribute\', \'device_fn\', \'protocol\', \'eval_distribute\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'100\', \'<object object instance>\', \'<object object instance>\', \'None\', \'5\', \'10000\', \'100\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'model_dir\', \'tf_random_seed\', \'save_summary_steps\', \'save_checkpoints_steps\', \'save_checkpoints_secs\', \'session_config\', \'keep_checkpoint_max\', \'keep_checkpoint_every_n_hours\', \'log_step_count_steps\', \'train_distribute\', \'device_fn\', \'protocol\', \'eval_distribute\', \'experimental_distribute\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'100\', \'<object object instance>\', \'<object object instance>\', \'None\', \'5\', \'10000\', \'100\', \'None\', \'None\', \'None\', \'None\', \'None\'], " } member_method { name: "replace" diff 
--git a/tensorflow/tools/api/golden/v1/tensorflow.keras.models.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.models.pbtxt index 8ba0e7480b..7ad4a32d43 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.models.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.models.pbtxt @@ -9,6 +9,10 @@ tf_module { mtype: "<type \'type\'>" } member_method { + name: "clone_model" + argspec: "args=[\'model\', \'input_tensors\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { name: "load_model" argspec: "args=[\'filepath\', \'custom_objects\', \'compile\'], varargs=None, keywords=None, defaults=[\'None\', \'True\'], " } diff --git a/tensorflow/tools/api/golden/v1/tensorflow.sparse.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.sparse.pbtxt index 3f54bc33e7..ba9e651b34 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.sparse.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.sparse.pbtxt @@ -9,6 +9,10 @@ tf_module { argspec: "args=[\'inputs\', \'num_buckets\', \'hash_key\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\', \'None\'], " } member_method { + name: "expand_dims" + argspec: "args=[\'sp_input\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { name: "eye" argspec: "args=[\'num_rows\', \'num_columns\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\", \'None\'], " } diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-run-config.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-run-config.pbtxt index bf1f94b6ae..269e18a0a7 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-run-config.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-run-config.pbtxt @@ -96,7 +96,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'model_dir\', \'tf_random_seed\', \'save_summary_steps\', \'save_checkpoints_steps\', 
\'save_checkpoints_secs\', \'session_config\', \'keep_checkpoint_max\', \'keep_checkpoint_every_n_hours\', \'log_step_count_steps\', \'train_distribute\', \'device_fn\', \'protocol\', \'eval_distribute\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'100\', \'<object object instance>\', \'<object object instance>\', \'None\', \'5\', \'10000\', \'100\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'model_dir\', \'tf_random_seed\', \'save_summary_steps\', \'save_checkpoints_steps\', \'save_checkpoints_secs\', \'session_config\', \'keep_checkpoint_max\', \'keep_checkpoint_every_n_hours\', \'log_step_count_steps\', \'train_distribute\', \'device_fn\', \'protocol\', \'eval_distribute\', \'experimental_distribute\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'100\', \'<object object instance>\', \'<object object instance>\', \'None\', \'5\', \'10000\', \'100\', \'None\', \'None\', \'None\', \'None\', \'None\'], " } member_method { name: "replace" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.models.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.pbtxt index 8ba0e7480b..7ad4a32d43 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.models.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.pbtxt @@ -9,6 +9,10 @@ tf_module { mtype: "<type \'type\'>" } member_method { + name: "clone_model" + argspec: "args=[\'model\', \'input_tensors\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { name: "load_model" argspec: "args=[\'filepath\', \'custom_objects\', \'compile\'], varargs=None, keywords=None, defaults=[\'None\', \'True\'], " } diff --git a/tensorflow/tools/api/golden/v2/tensorflow.sparse.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.sparse.pbtxt index 3f54bc33e7..ba9e651b34 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.sparse.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.sparse.pbtxt @@ -9,6 +9,10 @@ tf_module { 
argspec: "args=[\'inputs\', \'num_buckets\', \'hash_key\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\', \'None\'], " } member_method { + name: "expand_dims" + argspec: "args=[\'sp_input\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { name: "eye" argspec: "args=[\'num_rows\', \'num_columns\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\", \'None\'], " } diff --git a/tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh b/tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh index 0482cf619a..27b350e13e 100644 --- a/tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh +++ b/tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh @@ -27,7 +27,7 @@ function run_configure_for_gpu_build { } function set_remote_cache_options { - echo "build --remote_instance_name=projects/tensorflow-testing-cpu" >> "${TMP_BAZELRC}" + echo "build --remote_instance_name=projects/tensorflow-testing/instances/default_instance" >> "${TMP_BAZELRC}" echo "build --experimental_remote_platform_override='properties:{name:\"build\" value:\"windows-x64\"}'" >> "${TMP_BAZELRC}" echo "build --remote_cache=remotebuildexecution.googleapis.com" >> "${TMP_BAZELRC}" echo "build --tls_enabled=true" >> "${TMP_BAZELRC}" diff --git a/tensorflow/tools/docker/README.md b/tensorflow/tools/docker/README.md index a286e8a212..263f25bc48 100644 --- a/tensorflow/tools/docker/README.md +++ b/tensorflow/tools/docker/README.md @@ -1,3 +1,10 @@ +# WARNING: THESE IMAGES ARE DEPRECATED. + +TensorFlow's Dockerfiles are now located in +[`tensorflow/tools/dockerfiles/`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/dockerfiles). + +This directory will eventually be removed. 
+ # Using TensorFlow via Docker This directory contains `Dockerfile`s to make it easy to get up and running with diff --git a/tensorflow/tools/dockerfiles/README.md b/tensorflow/tools/dockerfiles/README.md new file mode 100644 index 0000000000..c484c162cb --- /dev/null +++ b/tensorflow/tools/dockerfiles/README.md @@ -0,0 +1,67 @@ +# TensorFlow Dockerfiles + +This directory houses TensorFlow's Dockerfiles. **DO NOT EDIT THE DOCKERFILES +MANUALLY!** They are maintained by `assembler.py`, which builds Dockerfiles from +the files in `partials/` and the rules in `spec.yml`. See [the Maintaining +section](#maintaining) for more information. + +## Building + +The Dockerfiles in the `dockerfiles` directory must have their build context set +to **the directory with this README.md** to copy in helper files. For example: + +```bash +$ docker build -f ./dockerfiles/cpu.Dockerfile -t tf . +``` + +Each Dockerfile has its own set of available `--build-arg`s which are documented +in the Dockerfile itself. + +## Running + +After building the image with the tag `tf` (for example), use `docker run` to +run the images. Examples are below. + +Note for new Docker users: the `-v` and `-u` flags share directories between +the Docker container and your machine, and very important. Without +`-v`, your work will be wiped once the container quits, and without `-u`, files +created by the container will have the wrong file permissions on your host +machine. If you are confused, check out the [Docker run +documentation](https://docs.docker.com/engine/reference/run/). + +```bash +# Volume mount (-v) is optional but highly recommended, especially for Jupyter. +# User permissions (-u) are required if you use (-v). 
+ +# CPU-based images +$ docker run -u $(id -u):$(id -g) -v $(PWD):/my-devel -it tf + +# GPU-based images (set up nvidia-docker2 first) +$ docker run --runtime=nvidia -u $(id -u):$(id -g) -v $(PWD):/my-devel -it tf + +# Images with Jupyter run on port 8888, and needs a volume for notebooks +$ docker run --user $(id -u):$(id -g) -p 8888:8888 -v $(PWD):/notebooks -it tf +``` + +These images do not come with the TensorFlow source code -- but the development +images have git included, so you can `git clone` it yourself. + +## Contributing + +To make changes to TensorFlow's Dockerfiles, you'll update `spec.yml` and the +`*.partial.Dockerfile` files in the `partials` directory, then run +`assembler.py` to re-generate the full Dockerfiles before creating a pull +request. + +You can use the `Dockerfile` in this directory to build an editing environment +that has all of the Python dependencies you'll need: + +```bash +$ docker build -t tf-assembler -f assembler.Dockerfile . + +# Set --user to set correct permissions on generated files +$ docker run --user $(id -u):$(id -g) -it -v $(pwd):/tf tf-assembler bash + +# In the container... +/tf $ python3 ./assembler.py -o dockerfiles -s spec.yml +``` diff --git a/tensorflow/tools/dockerfiles/assembler.Dockerfile b/tensorflow/tools/dockerfiles/assembler.Dockerfile new file mode 100644 index 0000000000..7a8e07fced --- /dev/null +++ b/tensorflow/tools/dockerfiles/assembler.Dockerfile @@ -0,0 +1,30 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# +# TensorFlow Dockerfile Development Container +# +# You can use this image to quickly develop changes to the Dockerfile assembler +# or set of TF Docker partials. See README.md for usage instructions. +FROM debian:stretch +LABEL maintainer="Austin Anderson <angerson@google.com>" + +RUN apt-get update && apt-get install -y python3 python3-pip bash +RUN pip3 install --upgrade pip setuptools pyyaml absl-py cerberus + +WORKDIR /tf +VOLUME ["/tf"] + +COPY bashrc /etc/bash.bashrc +RUN chmod a+rwx /etc/bash.bashrc diff --git a/tensorflow/tools/dockerfiles/assembler.py b/tensorflow/tools/dockerfiles/assembler.py new file mode 100644 index 0000000000..9cdd9bb0cb --- /dev/null +++ b/tensorflow/tools/dockerfiles/assembler.py @@ -0,0 +1,554 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Assemble common TF Dockerfiles from many parts. + +This script constructs TF's Dockerfiles by aggregating partial +Dockerfiles. See README.md for usage examples. 
+""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import copy +import errno +import os +import os.path +import re +import shutil +import textwrap + +from absl import app +from absl import flags +import cerberus +import yaml + +FLAGS = flags.FLAGS + +flags.DEFINE_boolean( + 'dry_run', False, 'Do not actually generate Dockerfiles', short_name='n') + +flags.DEFINE_string( + 'spec_file', + './spec.yml', + 'Path to a YAML specification file', + short_name='s') + +flags.DEFINE_string( + 'output_dir', + './dockerfiles', ('Path to an output directory for Dockerfiles. ' + 'Will be created if it doesn\'t exist.'), + short_name='o') + +flags.DEFINE_string( + 'partial_dir', + './partials', + 'Path to a directory containing foo.partial.Dockerfile partial files.', + short_name='p') + +flags.DEFINE_boolean( + 'quiet_dry_run', + True, + 'Do not print contents of dry run Dockerfiles.', + short_name='q') + +flags.DEFINE_boolean( + 'validate', True, 'Validate generated Dockerfiles', short_name='c') + +# Schema to verify the contents of spec.yml with Cerberus. +# Must be converted to a dict from yaml to work. +# Note: can add python references with e.g. 
+# !!python/name:builtins.str +# !!python/name:__main__.funcname +SCHEMA_TEXT = """ +header: + type: string + +partials: + type: dict + keyschema: + type: string + valueschema: + type: dict + schema: + desc: + type: string + args: + type: dict + keyschema: + type: string + valueschema: + anyof: + - type: [ boolean, number, string ] + - type: dict + schema: + default: + type: [ boolean, number, string ] + desc: + type: string + options: + type: list + schema: + type: string + +images: + keyschema: + type: string + valueschema: + type: dict + schema: + desc: + type: string + arg-defaults: + type: list + schema: + anyof: + - type: dict + keyschema: + type: string + arg_in_use: true + valueschema: + type: string + - type: string + isimage: true + create-dockerfile: + type: boolean + partials: + type: list + schema: + anyof: + - type: dict + keyschema: + type: string + regex: image + valueschema: + type: string + isimage: true + - type: string + ispartial: true +""" + + +class TfDockerValidator(cerberus.Validator): + """Custom Cerberus validator for TF dockerfile spec. + + Note: Each _validate_foo function's docstring must end with a segment + describing its own validation schema, e.g. "The rule's arguments are...". If + you add a new validator, you can copy/paste that section. + """ + + def _validate_ispartial(self, ispartial, field, value): + """Validate that a partial references an existing partial spec. + + Args: + ispartial: Value of the rule, a bool + field: The field being validated + value: The field's value + + The rule's arguments are validated against this schema: + {'type': 'boolean'} + """ + if ispartial and value not in self.root_document.get('partials', dict()): + self._error(field, '{} is not an existing partial.'.format(value)) + + def _validate_isimage(self, isimage, field, value): + """Validate that an image references an existing partial spec. 
+ + Args: + isimage: Value of the rule, a bool + field: The field being validated + value: The field's value + + The rule's arguments are validated against this schema: + {'type': 'boolean'} + """ + if isimage and value not in self.root_document.get('images', dict()): + self._error(field, '{} is not an existing image.'.format(value)) + + def _validate_arg_in_use(self, arg_in_use, field, value): + """Validate that an arg references an existing partial spec's args. + + Args: + arg_in_use: Value of the rule, a bool + field: The field being validated + value: The field's value + + The rule's arguments are validated against this schema: + {'type': 'boolean'} + """ + if arg_in_use: + for partial in self.root_document.get('partials', dict()).values(): + if value in partial.get('args', tuple()): + return + + self._error(field, '{} is not an arg used in any partial.'.format(value)) + + +def build_partial_description(partial_spec): + """Create the documentation lines for a specific partial. + + Generates something like this: + + # This is the partial's description, from spec.yml. + # --build-arg ARG_NAME=argdefault + # this is one of the args. + # --build-arg ANOTHER_ARG=(some|choices) + # another arg. + + Args: + partial_spec: A dict representing one of the partials from spec.yml. Doesn't + include the name of the partial; is a dict like { desc: ..., args: ... }. + + Returns: + A commented string describing this partial. 
+ """ + + # Start from linewrapped desc field + lines = [] + wrapper = textwrap.TextWrapper( + initial_indent='# ', subsequent_indent='# ', width=80) + description = wrapper.fill(partial_spec.get('desc', '( no comments )')) + lines.extend(['#', description]) + + # Document each arg + for arg, arg_data in partial_spec.get('args', dict()).items(): + # Wrap arg description with comment lines + desc = arg_data.get('desc', '( no description )') + desc = textwrap.fill( + desc, + initial_indent='# ', + subsequent_indent='# ', + width=80, + drop_whitespace=False) + + # Document (each|option|like|this) + if 'options' in arg_data: + arg_options = ' ({})'.format('|'.join(arg_data['options'])) + else: + arg_options = '' + + # Add usage sample + arg_use = '# --build-arg {}={}{}'.format(arg, + arg_data.get('default', '(unset)'), + arg_options) + lines.extend([arg_use, desc]) + + return '\n'.join(lines) + + +def construct_contents(partial_specs, image_spec): + """Assemble the dockerfile contents for an image spec. + + It assembles a concrete list of partial references into a single, large + string. + Also expands argument defaults, so that the resulting Dockerfile doesn't have + to be configured with --build-arg=... every time. That is, any ARG directive + will be updated with a new default value. + + Args: + partial_specs: The dict from spec.yml["partials"]. + image_spec: One of the dict values from spec.yml["images"]. + + Returns: + A string containing a valid Dockerfile based on the partials listed in + image_spec. 
+ """ + processed_partial_strings = [] + for partial_name in image_spec['partials']: + # Apply image arg-defaults to existing arg defaults + partial_spec = copy.deepcopy(partial_specs[partial_name]) + args = partial_spec.get('args', dict()) + for k_v in image_spec.get('arg-defaults', []): + arg, value = list(k_v.items())[0] + if arg in args: + args[arg]['default'] = value + + # Read partial file contents + filename = partial_spec.get('file', partial_name) + partial_path = os.path.join(FLAGS.partial_dir, + '{}.partial.Dockerfile'.format(filename)) + with open(partial_path, 'r') as f_partial: + partial_contents = f_partial.read() + + # Replace ARG FOO=BAR with ARG FOO=[new-default] + for arg, arg_data in args.items(): + if 'default' in arg_data and arg_data['default']: + default = '={}'.format(arg_data['default']) + else: + default = '' + partial_contents = re.sub(r'ARG {}.*'.format(arg), 'ARG {}{}'.format( + arg, default), partial_contents) + + # Store updated partial contents + processed_partial_strings.append(partial_contents) + + # Join everything together + return '\n'.join(processed_partial_strings) + + +def mkdir_p(path): + """Create a directory and its parents, even if it already exists.""" + try: + os.makedirs(path) + except OSError as e: + if e.errno != errno.EEXIST: + raise + + +def construct_documentation(header, partial_specs, image_spec): + """Assemble all of the documentation for a single dockerfile. + + Builds explanations of included partials and available build args. + + Args: + header: The string from spec.yml["header"]; will be commented and wrapped. + partial_specs: The dict from spec.yml["partials"]. + image_spec: The spec for the dockerfile being built. + + Returns: + A string containing a commented header that documents the contents of the + dockerfile. 
+ + """ + # Comment and wrap header and image description + commented_header = '\n'.join( + [('# ' + l).rstrip() for l in header.splitlines()]) + commented_desc = '\n'.join( + ['# ' + l for l in image_spec.get('desc', '').splitlines()]) + partial_descriptions = [] + + # Build documentation for each partial in the image + for partial in image_spec['partials']: + # Copy partial data for default args unique to this image + partial_spec = copy.deepcopy(partial_specs[partial]) + args = partial_spec.get('args', dict()) + + # Overwrite any existing arg defaults + for k_v in image_spec.get('arg-defaults', []): + arg, value = list(k_v.items())[0] + if arg in args: + args[arg]['default'] = value + + # Build the description from new args + partial_description = build_partial_description(partial_spec) + partial_descriptions.append(partial_description) + + contents = [commented_header, '#', commented_desc] + partial_descriptions + return '\n'.join(contents) + '\n' + + +def normalize_partial_args(partial_specs): + """Normalize the shorthand form of a partial's args specification. + + Turns this: + + partial: + args: + SOME_ARG: arg_value + + Into this: + + partial: + args: + SOME_ARG: + default: arg_value + + Args: + partial_specs: The dict from spec.yml["partials"]. This dict is modified in + place. + + Returns: + The modified contents of partial_specs. + + """ + for _, partial in partial_specs.items(): + args = partial.get('args', dict()) + for arg, value in args.items(): + if not isinstance(value, dict): + new_value = {'default': value} + args[arg] = new_value + + return partial_specs + + +def flatten_args_references(image_specs): + """Resolve all default-args in each image spec to a concrete dict. 
+ + Turns this: + + example-image: + arg-defaults: + - MY_ARG: ARG_VALUE + + another-example: + arg-defaults: + - ANOTHER_ARG: ANOTHER_VALUE + - example_image + + Into this: + + example-image: + arg-defaults: + - MY_ARG: ARG_VALUE + + another-example: + arg-defaults: + - ANOTHER_ARG: ANOTHER_VALUE + - MY_ARG: ARG_VALUE + + Args: + image_specs: A dict of image_spec dicts; should be the contents of the + "images" key in the global spec.yaml. This dict is modified in place and + then returned. + + Returns: + The modified contents of image_specs. + """ + for _, image_spec in image_specs.items(): + too_deep = 0 + while str in map(type, image_spec.get('arg-defaults', [])) and too_deep < 5: + new_args = [] + for arg in image_spec['arg-defaults']: + if isinstance(arg, str): + new_args.extend(image_specs[arg]['arg-defaults']) + else: + new_args.append(arg) + + image_spec['arg-defaults'] = new_args + too_deep += 1 + + return image_specs + + +def flatten_partial_references(image_specs): + """Resolve all partial references in each image spec to a concrete list. + + Turns this: + + example-image: + partials: + - foo + + another-example: + partials: + - bar + - image: example-image + - bat + + Into this: + + example-image: + partials: + - foo + + another-example: + partials: + - bar + - foo + - bat + Args: + image_specs: A dict of image_spec dicts; should be the contents of the + "images" key in the global spec.yaml. This dict is modified in place and + then returned. + + Returns: + The modified contents of image_specs. 
+ """ + for _, image_spec in image_specs.items(): + too_deep = 0 + while dict in map(type, image_spec['partials']) and too_deep < 5: + new_partials = [] + for partial in image_spec['partials']: + if isinstance(partial, str): + new_partials.append(partial) + else: + new_partials.extend(image_specs[partial['image']]['partials']) + + image_spec['partials'] = new_partials + too_deep += 1 + + return image_specs + + +def construct_dockerfiles(tf_spec): + """Generate a mapping of {"cpu": <cpu dockerfile contents>, ...}. + + Args: + tf_spec: The full spec.yml loaded as a python object. + + Returns: + A string:string dict of short names ("cpu-devel") to Dockerfile contents. + """ + names_to_contents = dict() + image_specs = tf_spec['images'] + image_specs = flatten_partial_references(image_specs) + image_specs = flatten_args_references(image_specs) + partial_specs = tf_spec['partials'] + partial_specs = normalize_partial_args(partial_specs) + + for name, image_spec in image_specs.items(): + if not image_spec.get('create-dockerfile', True): + continue + documentation = construct_documentation(tf_spec['header'], partial_specs, + image_spec) + contents = construct_contents(partial_specs, image_spec) + names_to_contents[name] = '\n'.join([documentation, contents]) + + return names_to_contents + + +def main(argv): + if len(argv) > 1: + raise app.UsageError('Unexpected command line args found: {}'.format(argv)) + + with open(FLAGS.spec_file, 'r') as spec_file: + tf_spec = yaml.load(spec_file) + + # Abort if spec.yaml is invalid + if FLAGS.validate: + schema = yaml.load(SCHEMA_TEXT) + v = TfDockerValidator(schema) + if not v.validate(tf_spec): + print('>> ERROR: {} is an invalid spec! The errors are:'.format( + FLAGS.spec_file)) + print(yaml.dump(v.errors, indent=2)) + exit(1) + else: + print('>> WARNING: Not validating {}'.format(FLAGS.spec_file)) + + # Generate mapping of { "cpu-devel": "<cpu-devel dockerfile contents>", ... 
} + names_to_contents = construct_dockerfiles(tf_spec) + + # Write each completed Dockerfile + if not FLAGS.dry_run: + print('>> Emptying destination dir "{}"'.format(FLAGS.output_dir)) + shutil.rmtree(FLAGS.output_dir, ignore_errors=True) + mkdir_p(FLAGS.output_dir) + else: + print('>> Skipping creation of {} (dry run)'.format(FLAGS.output_dir)) + for name, contents in names_to_contents.items(): + path = os.path.join(FLAGS.output_dir, name + '.Dockerfile') + if FLAGS.dry_run: + print('>> Skipping writing contents of {} (dry run)'.format(path)) + print(contents) + else: + mkdir_p(FLAGS.output_dir) + print('>> Writing {}'.format(path)) + with open(path, 'w') as f: + f.write(contents) + + +if __name__ == '__main__': + app.run(main) diff --git a/tensorflow/tools/dockerfiles/bashrc b/tensorflow/tools/dockerfiles/bashrc new file mode 100644 index 0000000000..48cacf20f6 --- /dev/null +++ b/tensorflow/tools/dockerfiles/bashrc @@ -0,0 +1,50 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# ============================================================================== + +export PS1="\[\e[31m\]tf-docker\[\e[m\] \[\e[33m\]\w\[\e[m\] > " +export TERM=xterm-256color +alias grep="grep --color=auto" +alias ls="ls --color=auto" + +echo -e "\e[1;31m" +cat<<TF +________ _______________ +___ __/__________________________________ ____/__ /________ __ +__ / _ _ \_ __ \_ ___/ __ \_ ___/_ /_ __ /_ __ \_ | /| / / +_ / / __/ / / /(__ )/ /_/ / / _ __/ _ / / /_/ /_ |/ |/ / +/_/ \___//_/ /_//____/ \____//_/ /_/ /_/ \____/____/|__/ + +TF +echo -e "\e[0;33m" + +if [[ $EUID -eq 0 ]]; then + cat <<WARN +WARNING: You are running this container as root, which can cause new files in +mounted volumes to be created as the root user on your host machine. + +To avoid this, run the container by specifying your user's userid: + +$ docker run -u \$(id -u):\$(id -g) args... +WARN +else + cat <<EXPL +You are running this container as user with ID $(id -u) and group $(id -g), +which should map to the ID and group for your user on the Docker host. Great! +EXPL +fi + +# Turn off colors +echo -e "\e[m" diff --git a/tensorflow/tools/dockerfiles/dockerfiles/cpu-devel-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/cpu-devel-jupyter.Dockerfile new file mode 100644 index 0000000000..dbbad7d03a --- /dev/null +++ b/tensorflow/tools/dockerfiles/dockerfiles/cpu-devel-jupyter.Dockerfile @@ -0,0 +1,100 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +# +# THIS IS A GENERATED DOCKERFILE. +# +# This file was assembled from multiple pieces, whose use is documented +# below. Please refer to the the TensorFlow dockerfiles documentation for +# more information. Build args are documented as their default value. +# +# Ubuntu-based, CPU-only environment for developing changes for TensorFlow, with Jupyter included. +# +# Start from Ubuntu, with TF development packages (no GPU support) +# --build-arg UBUNTU_VERSION=16.04 +# ( no description ) +# +# Python is required for TensorFlow and other libraries. +# --build-arg USE_PYTHON_3_NOT_2=True +# Install python 3 over Python 2 +# +# Install the latest version of Bazel and Python development tools. +# +# Configure TensorFlow's shell prompt and login tools. +# +# Launch Jupyter on execution instead of a bash prompt. 
+ +ARG UBUNTU_VERSION=16.04 +FROM ubuntu:${UBUNTU_VERSION} + +RUN apt-get update && apt-get install -y --no-install-recommends \ + build-essential \ + curl \ + git \ + libcurl3-dev \ + libfreetype6-dev \ + libhdf5-serial-dev \ + libpng12-dev \ + libzmq3-dev \ + pkg-config \ + python-dev \ + rsync \ + software-properties-common \ + unzip \ + zip \ + zlib1g-dev \ + openjdk-8-jdk \ + openjdk-8-jre-headless \ + && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* + +ARG USE_PYTHON_3_NOT_2=True +ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3} +ARG PYTHON=python${_PY_SUFFIX} +ARG PIP=pip${_PY_SUFFIX} + +RUN apt-get update && apt-get install -y \ + ${PYTHON} \ + ${PYTHON}-pip + +RUN ${PIP} install --upgrade \ + pip \ + setuptools + +RUN apt-get update && apt-get install -y \ + build-essential \ + curl \ + git \ + openjdk-8-jdk \ + ${PYTHON}-dev \ + swig + +# Install bazel +RUN echo "deb [arch=amd64] http://storage.googleapis.com/bazel-apt stable jdk1.8" | tee /etc/apt/sources.list.d/bazel.list && \ + curl https://bazel.build/bazel-release.pub.gpg | apt-key add - && \ + apt-get update && \ + apt-get install -y bazel + +COPY bashrc /etc/bash.bashrc +RUN chmod a+rwx /etc/bash.bashrc + +RUN ${PIP} install jupyter + +RUN mkdir /notebooks && chmod a+rwx /notebooks +RUN mkdir /.local && chmod a+rwx /.local +WORKDIR /notebooks +EXPOSE 8888 + +CMD ["bash", "-c", "source /etc/bash.bashrc && jupyter notebook --notebook-dir=/notebooks --ip 0.0.0.0 --no-browser --allow-root"] diff --git a/tensorflow/tools/dockerfiles/dockerfiles/cpu-devel.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/cpu-devel.Dockerfile new file mode 100644 index 0000000000..160d7c02e2 --- /dev/null +++ b/tensorflow/tools/dockerfiles/dockerfiles/cpu-devel.Dockerfile @@ -0,0 +1,89 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +# +# THIS IS A GENERATED DOCKERFILE. +# +# This file was assembled from multiple pieces, whose use is documented +# below. Please refer to the the TensorFlow dockerfiles documentation for +# more information. Build args are documented as their default value. +# +# Ubuntu-based, CPU-only environment for developing changes for TensorFlow. +# +# Start from Ubuntu, with TF development packages (no GPU support) +# --build-arg UBUNTU_VERSION=16.04 +# ( no description ) +# +# Python is required for TensorFlow and other libraries. +# --build-arg USE_PYTHON_3_NOT_2=True +# Install python 3 over Python 2 +# +# Install the latest version of Bazel and Python development tools. +# +# Configure TensorFlow's shell prompt and login tools. 
+ +ARG UBUNTU_VERSION=16.04 +FROM ubuntu:${UBUNTU_VERSION} + +RUN apt-get update && apt-get install -y --no-install-recommends \ + build-essential \ + curl \ + git \ + libcurl3-dev \ + libfreetype6-dev \ + libhdf5-serial-dev \ + libpng12-dev \ + libzmq3-dev \ + pkg-config \ + python-dev \ + rsync \ + software-properties-common \ + unzip \ + zip \ + zlib1g-dev \ + openjdk-8-jdk \ + openjdk-8-jre-headless \ + && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* + +ARG USE_PYTHON_3_NOT_2=True +ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3} +ARG PYTHON=python${_PY_SUFFIX} +ARG PIP=pip${_PY_SUFFIX} + +RUN apt-get update && apt-get install -y \ + ${PYTHON} \ + ${PYTHON}-pip + +RUN ${PIP} install --upgrade \ + pip \ + setuptools + +RUN apt-get update && apt-get install -y \ + build-essential \ + curl \ + git \ + openjdk-8-jdk \ + ${PYTHON}-dev \ + swig + +# Install bazel +RUN echo "deb [arch=amd64] http://storage.googleapis.com/bazel-apt stable jdk1.8" | tee /etc/apt/sources.list.d/bazel.list && \ + curl https://bazel.build/bazel-release.pub.gpg | apt-key add - && \ + apt-get update && \ + apt-get install -y bazel + +COPY bashrc /etc/bash.bashrc +RUN chmod a+rwx /etc/bash.bashrc diff --git a/tensorflow/tools/dockerfiles/dockerfiles/cpu-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/cpu-jupyter.Dockerfile new file mode 100644 index 0000000000..8d5d653ab7 --- /dev/null +++ b/tensorflow/tools/dockerfiles/dockerfiles/cpu-jupyter.Dockerfile @@ -0,0 +1,69 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +# +# THIS IS A GENERATED DOCKERFILE. +# +# This file was assembled from multiple pieces, whose use is documented +# below. Please refer to the the TensorFlow dockerfiles documentation for +# more information. Build args are documented as their default value. +# +# Ubuntu-based, CPU-only environment for using TensorFlow, with Jupyter included. +# +# Start from Ubuntu (no GPU support) +# --build-arg UBUNTU_VERSION=16.04 +# ( no description ) +# +# Python is required for TensorFlow and other libraries. +# --build-arg USE_PYTHON_3_NOT_2=True +# Install python 3 over Python 2 +# +# Install the TensorFlow Python package. +# --build-arg TF_PACKAGE=tensorflow (tensorflow|tensorflow-gpu|tf-nightly|tf-nightly-gpu) +# The specific TensorFlow Python package to install +# +# Configure TensorFlow's shell prompt and login tools. +# +# Launch Jupyter on execution instead of a bash prompt. 
+ +ARG UBUNTU_VERSION=16.04 +FROM ubuntu:${UBUNTU_VERSION} + +ARG USE_PYTHON_3_NOT_2=True +ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3} +ARG PYTHON=python${_PY_SUFFIX} +ARG PIP=pip${_PY_SUFFIX} + +RUN apt-get update && apt-get install -y \ + ${PYTHON} \ + ${PYTHON}-pip + +RUN ${PIP} install --upgrade \ + pip \ + setuptools + +ARG TF_PACKAGE=tensorflow +RUN ${PIP} install ${TF_PACKAGE} + +COPY bashrc /etc/bash.bashrc +RUN chmod a+rwx /etc/bash.bashrc + +RUN ${PIP} install jupyter + +RUN mkdir /notebooks && chmod a+rwx /notebooks +RUN mkdir /.local && chmod a+rwx /.local +WORKDIR /notebooks +EXPOSE 8888 + +CMD ["bash", "-c", "source /etc/bash.bashrc && jupyter notebook --notebook-dir=/notebooks --ip 0.0.0.0 --no-browser --allow-root"] diff --git a/tensorflow/tools/dockerfiles/dockerfiles/cpu.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/cpu.Dockerfile new file mode 100644 index 0000000000..35c41b49fd --- /dev/null +++ b/tensorflow/tools/dockerfiles/dockerfiles/cpu.Dockerfile @@ -0,0 +1,58 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +# +# THIS IS A GENERATED DOCKERFILE. +# +# This file was assembled from multiple pieces, whose use is documented +# below. Please refer to the the TensorFlow dockerfiles documentation for +# more information. Build args are documented as their default value. 
+# +# Ubuntu-based, CPU-only environment for using TensorFlow +# +# Start from Ubuntu (no GPU support) +# --build-arg UBUNTU_VERSION=16.04 +# ( no description ) +# +# Python is required for TensorFlow and other libraries. +# --build-arg USE_PYTHON_3_NOT_2=True +# Install python 3 over Python 2 +# +# Install the TensorFlow Python package. +# --build-arg TF_PACKAGE=tensorflow (tensorflow|tensorflow-gpu|tf-nightly|tf-nightly-gpu) +# The specific TensorFlow Python package to install +# +# Configure TensorFlow's shell prompt and login tools. + +ARG UBUNTU_VERSION=16.04 +FROM ubuntu:${UBUNTU_VERSION} + +ARG USE_PYTHON_3_NOT_2=True +ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3} +ARG PYTHON=python${_PY_SUFFIX} +ARG PIP=pip${_PY_SUFFIX} + +RUN apt-get update && apt-get install -y \ + ${PYTHON} \ + ${PYTHON}-pip + +RUN ${PIP} install --upgrade \ + pip \ + setuptools + +ARG TF_PACKAGE=tensorflow +RUN ${PIP} install ${TF_PACKAGE} + +COPY bashrc /etc/bash.bashrc +RUN chmod a+rwx /etc/bash.bashrc diff --git a/tensorflow/tools/dockerfiles/dockerfiles/nvidia-devel-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/nvidia-devel-jupyter.Dockerfile new file mode 100644 index 0000000000..0f5fedf2fe --- /dev/null +++ b/tensorflow/tools/dockerfiles/dockerfiles/nvidia-devel-jupyter.Dockerfile @@ -0,0 +1,120 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +# +# THIS IS A GENERATED DOCKERFILE. +# +# This file was assembled from multiple pieces, whose use is documented +# below. Please refer to the the TensorFlow dockerfiles documentation for +# more information. Build args are documented as their default value. +# +# Ubuntu-based, Nvidia-GPU-enabled environment for developing changes for TensorFlow, with Jupyter included. +# +# Start from Nvidia's Ubuntu base image with CUDA and CuDNN, with TF development +# packages. +# --build-arg UBUNTU_VERSION=16.04 +# ( no description ) +# +# Python is required for TensorFlow and other libraries. +# --build-arg USE_PYTHON_3_NOT_2=True +# Install python 3 over Python 2 +# +# Install the latest version of Bazel and Python development tools. +# +# Configure TensorFlow's shell prompt and login tools. +# +# Launch Jupyter on execution instead of a bash prompt. + +ARG UBUNTU_VERSION=16.04 +FROM nvidia/cuda:9.0-base-ubuntu${UBUNTU_VERSION} + +RUN apt-get update && apt-get install -y --no-install-recommends \ + build-essential \ + cuda-command-line-tools-9-0 \ + cuda-cublas-dev-9-0 \ + cuda-cudart-dev-9-0 \ + cuda-cufft-dev-9-0 \ + cuda-curand-dev-9-0 \ + cuda-cusolver-dev-9-0 \ + cuda-cusparse-dev-9-0 \ + curl \ + git \ + libcudnn7=7.1.4.18-1+cuda9.0 \ + libcudnn7-dev=7.1.4.18-1+cuda9.0 \ + libnccl2=2.2.13-1+cuda9.0 \ + libnccl-dev=2.2.13-1+cuda9.0 \ + libcurl3-dev \ + libfreetype6-dev \ + libhdf5-serial-dev \ + libpng12-dev \ + libzmq3-dev \ + pkg-config \ + rsync \ + software-properties-common \ + unzip \ + zip \ + zlib1g-dev \ + wget \ + && \ + rm -rf /var/lib/apt/lists/* && \ + find /usr/local/cuda-9.0/lib64/ -type f -name 'lib*_static.a' -not -name 'libcudart_static.a' -delete && \ + rm /usr/lib/x86_64-linux-gnu/libcudnn_static_v7.a + +# Link NCCL libray and header where the build script expects them. 
+RUN mkdir /usr/local/cuda-9.0/lib && \ + ln -s /usr/lib/x86_64-linux-gnu/libnccl.so.2 /usr/local/cuda/lib/libnccl.so.2 && \ + ln -s /usr/include/nccl.h /usr/local/cuda/include/nccl.h + +# TODO(tobyboyd): Remove after license is excluded from BUILD file. +RUN gunzip /usr/share/doc/libnccl2/NCCL-SLA.txt.gz && \ + cp /usr/share/doc/libnccl2/NCCL-SLA.txt /usr/local/cuda/ + +ARG USE_PYTHON_3_NOT_2=True +ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3} +ARG PYTHON=python${_PY_SUFFIX} +ARG PIP=pip${_PY_SUFFIX} + +RUN apt-get update && apt-get install -y \ + ${PYTHON} \ + ${PYTHON}-pip + +RUN ${PIP} install --upgrade \ + pip \ + setuptools + +RUN apt-get update && apt-get install -y \ + build-essential \ + curl \ + git \ + openjdk-8-jdk \ + ${PYTHON}-dev \ + swig + +# Install bazel +RUN echo "deb [arch=amd64] http://storage.googleapis.com/bazel-apt stable jdk1.8" | tee /etc/apt/sources.list.d/bazel.list && \ + curl https://bazel.build/bazel-release.pub.gpg | apt-key add - && \ + apt-get update && \ + apt-get install -y bazel + +COPY bashrc /etc/bash.bashrc +RUN chmod a+rwx /etc/bash.bashrc + +RUN ${PIP} install jupyter + +RUN mkdir /notebooks && chmod a+rwx /notebooks +RUN mkdir /.local && chmod a+rwx /.local +WORKDIR /notebooks +EXPOSE 8888 + +CMD ["bash", "-c", "source /etc/bash.bashrc && jupyter notebook --notebook-dir=/notebooks --ip 0.0.0.0 --no-browser --allow-root"] diff --git a/tensorflow/tools/dockerfiles/dockerfiles/nvidia-devel.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/nvidia-devel.Dockerfile new file mode 100644 index 0000000000..a6e280082e --- /dev/null +++ b/tensorflow/tools/dockerfiles/dockerfiles/nvidia-devel.Dockerfile @@ -0,0 +1,109 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +# +# THIS IS A GENERATED DOCKERFILE. +# +# This file was assembled from multiple pieces, whose use is documented +# below. Please refer to the the TensorFlow dockerfiles documentation for +# more information. Build args are documented as their default value. +# +# Ubuntu-based, Nvidia-GPU-enabled environment for developing changes for TensorFlow. +# +# Start from Nvidia's Ubuntu base image with CUDA and CuDNN, with TF development +# packages. +# --build-arg UBUNTU_VERSION=16.04 +# ( no description ) +# +# Python is required for TensorFlow and other libraries. +# --build-arg USE_PYTHON_3_NOT_2=True +# Install python 3 over Python 2 +# +# Install the latest version of Bazel and Python development tools. +# +# Configure TensorFlow's shell prompt and login tools. 
+ +ARG UBUNTU_VERSION=16.04 +FROM nvidia/cuda:9.0-base-ubuntu${UBUNTU_VERSION} + +RUN apt-get update && apt-get install -y --no-install-recommends \ + build-essential \ + cuda-command-line-tools-9-0 \ + cuda-cublas-dev-9-0 \ + cuda-cudart-dev-9-0 \ + cuda-cufft-dev-9-0 \ + cuda-curand-dev-9-0 \ + cuda-cusolver-dev-9-0 \ + cuda-cusparse-dev-9-0 \ + curl \ + git \ + libcudnn7=7.1.4.18-1+cuda9.0 \ + libcudnn7-dev=7.1.4.18-1+cuda9.0 \ + libnccl2=2.2.13-1+cuda9.0 \ + libnccl-dev=2.2.13-1+cuda9.0 \ + libcurl3-dev \ + libfreetype6-dev \ + libhdf5-serial-dev \ + libpng12-dev \ + libzmq3-dev \ + pkg-config \ + rsync \ + software-properties-common \ + unzip \ + zip \ + zlib1g-dev \ + wget \ + && \ + rm -rf /var/lib/apt/lists/* && \ + find /usr/local/cuda-9.0/lib64/ -type f -name 'lib*_static.a' -not -name 'libcudart_static.a' -delete && \ + rm /usr/lib/x86_64-linux-gnu/libcudnn_static_v7.a + +# Link NCCL libray and header where the build script expects them. +RUN mkdir /usr/local/cuda-9.0/lib && \ + ln -s /usr/lib/x86_64-linux-gnu/libnccl.so.2 /usr/local/cuda/lib/libnccl.so.2 && \ + ln -s /usr/include/nccl.h /usr/local/cuda/include/nccl.h + +# TODO(tobyboyd): Remove after license is excluded from BUILD file. 
+RUN gunzip /usr/share/doc/libnccl2/NCCL-SLA.txt.gz && \ + cp /usr/share/doc/libnccl2/NCCL-SLA.txt /usr/local/cuda/ + +ARG USE_PYTHON_3_NOT_2=True +ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3} +ARG PYTHON=python${_PY_SUFFIX} +ARG PIP=pip${_PY_SUFFIX} + +RUN apt-get update && apt-get install -y \ + ${PYTHON} \ + ${PYTHON}-pip + +RUN ${PIP} install --upgrade \ + pip \ + setuptools + +RUN apt-get update && apt-get install -y \ + build-essential \ + curl \ + git \ + openjdk-8-jdk \ + ${PYTHON}-dev \ + swig + +# Install bazel +RUN echo "deb [arch=amd64] http://storage.googleapis.com/bazel-apt stable jdk1.8" | tee /etc/apt/sources.list.d/bazel.list && \ + curl https://bazel.build/bazel-release.pub.gpg | apt-key add - && \ + apt-get update && \ + apt-get install -y bazel + +COPY bashrc /etc/bash.bashrc +RUN chmod a+rwx /etc/bash.bashrc diff --git a/tensorflow/tools/dockerfiles/dockerfiles/nvidia-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/nvidia-jupyter.Dockerfile new file mode 100644 index 0000000000..f1799113b1 --- /dev/null +++ b/tensorflow/tools/dockerfiles/dockerfiles/nvidia-jupyter.Dockerfile @@ -0,0 +1,90 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +# +# THIS IS A GENERATED DOCKERFILE. +# +# This file was assembled from multiple pieces, whose use is documented +# below. 
Please refer to the the TensorFlow dockerfiles documentation for +# more information. Build args are documented as their default value. +# +# Ubuntu-based, Nvidia-GPU-enabled environment for using TensorFlow, with Jupyter included. +# +# NVIDIA with CUDA and CuDNN, no dev stuff +# --build-arg UBUNTU_VERSION=16.04 +# ( no description ) +# +# Python is required for TensorFlow and other libraries. +# --build-arg USE_PYTHON_3_NOT_2=True +# Install python 3 over Python 2 +# +# Install the TensorFlow Python package. +# --build-arg TF_PACKAGE=tensorflow-gpu (tensorflow|tensorflow-gpu|tf-nightly|tf-nightly-gpu) +# The specific TensorFlow Python package to install +# +# Configure TensorFlow's shell prompt and login tools. +# +# Launch Jupyter on execution instead of a bash prompt. + +FROM nvidia/cuda:9.0-base-ubuntu16.04 + +# Pick up some TF dependencies +RUN apt-get update && apt-get install -y --no-install-recommends \ + build-essential \ + cuda-command-line-tools-9-0 \ + cuda-cublas-9-0 \ + cuda-cufft-9-0 \ + cuda-curand-9-0 \ + cuda-cusolver-9-0 \ + cuda-cusparse-9-0 \ + libcudnn7=7.1.4.18-1+cuda9.0 \ + libnccl2=2.2.13-1+cuda9.0 \ + libfreetype6-dev \ + libhdf5-serial-dev \ + libpng12-dev \ + libzmq3-dev \ + pkg-config \ + software-properties-common \ + unzip \ + && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* + +ARG USE_PYTHON_3_NOT_2=True +ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3} +ARG PYTHON=python${_PY_SUFFIX} +ARG PIP=pip${_PY_SUFFIX} + +RUN apt-get update && apt-get install -y \ + ${PYTHON} \ + ${PYTHON}-pip + +RUN ${PIP} install --upgrade \ + pip \ + setuptools + +ARG TF_PACKAGE=tensorflow-gpu +RUN ${PIP} install ${TF_PACKAGE} + +COPY bashrc /etc/bash.bashrc +RUN chmod a+rwx /etc/bash.bashrc + +RUN ${PIP} install jupyter + +RUN mkdir /notebooks && chmod a+rwx /notebooks +RUN mkdir /.local && chmod a+rwx /.local +WORKDIR /notebooks +EXPOSE 8888 + +CMD ["bash", "-c", "source /etc/bash.bashrc && jupyter notebook --notebook-dir=/notebooks --ip 0.0.0.0 
--no-browser --allow-root"] diff --git a/tensorflow/tools/dockerfiles/dockerfiles/nvidia.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/nvidia.Dockerfile new file mode 100644 index 0000000000..690eb68b22 --- /dev/null +++ b/tensorflow/tools/dockerfiles/dockerfiles/nvidia.Dockerfile @@ -0,0 +1,79 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +# +# THIS IS A GENERATED DOCKERFILE. +# +# This file was assembled from multiple pieces, whose use is documented +# below. Please refer to the the TensorFlow dockerfiles documentation for +# more information. Build args are documented as their default value. +# +# Ubuntu-based, Nvidia-GPU-enabled environment for using TensorFlow. +# +# NVIDIA with CUDA and CuDNN, no dev stuff +# --build-arg UBUNTU_VERSION=16.04 +# ( no description ) +# +# Python is required for TensorFlow and other libraries. +# --build-arg USE_PYTHON_3_NOT_2=True +# Install python 3 over Python 2 +# +# Install the TensorFlow Python package. +# --build-arg TF_PACKAGE=tensorflow-gpu (tensorflow|tensorflow-gpu|tf-nightly|tf-nightly-gpu) +# The specific TensorFlow Python package to install +# +# Configure TensorFlow's shell prompt and login tools. 
+ +FROM nvidia/cuda:9.0-base-ubuntu16.04 + +# Pick up some TF dependencies +RUN apt-get update && apt-get install -y --no-install-recommends \ + build-essential \ + cuda-command-line-tools-9-0 \ + cuda-cublas-9-0 \ + cuda-cufft-9-0 \ + cuda-curand-9-0 \ + cuda-cusolver-9-0 \ + cuda-cusparse-9-0 \ + libcudnn7=7.1.4.18-1+cuda9.0 \ + libnccl2=2.2.13-1+cuda9.0 \ + libfreetype6-dev \ + libhdf5-serial-dev \ + libpng12-dev \ + libzmq3-dev \ + pkg-config \ + software-properties-common \ + unzip \ + && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* + +ARG USE_PYTHON_3_NOT_2=True +ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3} +ARG PYTHON=python${_PY_SUFFIX} +ARG PIP=pip${_PY_SUFFIX} + +RUN apt-get update && apt-get install -y \ + ${PYTHON} \ + ${PYTHON}-pip + +RUN ${PIP} install --upgrade \ + pip \ + setuptools + +ARG TF_PACKAGE=tensorflow-gpu +RUN ${PIP} install ${TF_PACKAGE} + +COPY bashrc /etc/bash.bashrc +RUN chmod a+rwx /etc/bash.bashrc diff --git a/tensorflow/tools/dockerfiles/partials/bazel.partial.Dockerfile b/tensorflow/tools/dockerfiles/partials/bazel.partial.Dockerfile new file mode 100644 index 0000000000..b08d8bdd14 --- /dev/null +++ b/tensorflow/tools/dockerfiles/partials/bazel.partial.Dockerfile @@ -0,0 +1,13 @@ +RUN apt-get update && apt-get install -y \ + build-essential \ + curl \ + git \ + openjdk-8-jdk \ + ${PYTHON}-dev \ + swig + +# Install bazel +RUN echo "deb [arch=amd64] http://storage.googleapis.com/bazel-apt stable jdk1.8" | tee /etc/apt/sources.list.d/bazel.list && \ + curl https://bazel.build/bazel-release.pub.gpg | apt-key add - && \ + apt-get update && \ + apt-get install -y bazel diff --git a/tensorflow/tools/dockerfiles/partials/jupyter.partial.Dockerfile b/tensorflow/tools/dockerfiles/partials/jupyter.partial.Dockerfile new file mode 100644 index 0000000000..2c9b9f3f9a --- /dev/null +++ b/tensorflow/tools/dockerfiles/partials/jupyter.partial.Dockerfile @@ -0,0 +1,8 @@ +RUN ${PIP} install jupyter + +RUN mkdir /notebooks && chmod a+rwx 
/notebooks +RUN mkdir /.local && chmod a+rwx /.local +WORKDIR /notebooks +EXPOSE 8888 + +CMD ["bash", "-c", "source /etc/bash.bashrc && jupyter notebook --notebook-dir=/notebooks --ip 0.0.0.0 --no-browser --allow-root"] diff --git a/tensorflow/tools/dockerfiles/partials/nvidia-devel.partial.Dockerfile b/tensorflow/tools/dockerfiles/partials/nvidia-devel.partial.Dockerfile new file mode 100644 index 0000000000..f31b695e77 --- /dev/null +++ b/tensorflow/tools/dockerfiles/partials/nvidia-devel.partial.Dockerfile @@ -0,0 +1,43 @@ +ARG UBUNTU_VERSION=16.04 +FROM nvidia/cuda:9.0-base-ubuntu${UBUNTU_VERSION} + +RUN apt-get update && apt-get install -y --no-install-recommends \ + build-essential \ + cuda-command-line-tools-9-0 \ + cuda-cublas-dev-9-0 \ + cuda-cudart-dev-9-0 \ + cuda-cufft-dev-9-0 \ + cuda-curand-dev-9-0 \ + cuda-cusolver-dev-9-0 \ + cuda-cusparse-dev-9-0 \ + curl \ + git \ + libcudnn7=7.1.4.18-1+cuda9.0 \ + libcudnn7-dev=7.1.4.18-1+cuda9.0 \ + libnccl2=2.2.13-1+cuda9.0 \ + libnccl-dev=2.2.13-1+cuda9.0 \ + libcurl3-dev \ + libfreetype6-dev \ + libhdf5-serial-dev \ + libpng12-dev \ + libzmq3-dev \ + pkg-config \ + rsync \ + software-properties-common \ + unzip \ + zip \ + zlib1g-dev \ + wget \ + && \ + rm -rf /var/lib/apt/lists/* && \ + find /usr/local/cuda-9.0/lib64/ -type f -name 'lib*_static.a' -not -name 'libcudart_static.a' -delete && \ + rm /usr/lib/x86_64-linux-gnu/libcudnn_static_v7.a + +# Link NCCL libray and header where the build script expects them. +RUN mkdir /usr/local/cuda-9.0/lib && \ + ln -s /usr/lib/x86_64-linux-gnu/libnccl.so.2 /usr/local/cuda/lib/libnccl.so.2 && \ + ln -s /usr/include/nccl.h /usr/local/cuda/include/nccl.h + +# TODO(tobyboyd): Remove after license is excluded from BUILD file. 
+RUN gunzip /usr/share/doc/libnccl2/NCCL-SLA.txt.gz && \ + cp /usr/share/doc/libnccl2/NCCL-SLA.txt /usr/local/cuda/ diff --git a/tensorflow/tools/dockerfiles/partials/nvidia.partial.Dockerfile b/tensorflow/tools/dockerfiles/partials/nvidia.partial.Dockerfile new file mode 100644 index 0000000000..13d865b9d4 --- /dev/null +++ b/tensorflow/tools/dockerfiles/partials/nvidia.partial.Dockerfile @@ -0,0 +1,23 @@ +FROM nvidia/cuda:9.0-base-ubuntu16.04 + +# Pick up some TF dependencies +RUN apt-get update && apt-get install -y --no-install-recommends \ + build-essential \ + cuda-command-line-tools-9-0 \ + cuda-cublas-9-0 \ + cuda-cufft-9-0 \ + cuda-curand-9-0 \ + cuda-cusolver-9-0 \ + cuda-cusparse-9-0 \ + libcudnn7=7.1.4.18-1+cuda9.0 \ + libnccl2=2.2.13-1+cuda9.0 \ + libfreetype6-dev \ + libhdf5-serial-dev \ + libpng12-dev \ + libzmq3-dev \ + pkg-config \ + software-properties-common \ + unzip \ + && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* diff --git a/tensorflow/tools/dockerfiles/partials/python.partial.Dockerfile b/tensorflow/tools/dockerfiles/partials/python.partial.Dockerfile new file mode 100644 index 0000000000..6f346236a5 --- /dev/null +++ b/tensorflow/tools/dockerfiles/partials/python.partial.Dockerfile @@ -0,0 +1,12 @@ +ARG USE_PYTHON_3_NOT_2 +ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3} +ARG PYTHON=python${_PY_SUFFIX} +ARG PIP=pip${_PY_SUFFIX} + +RUN apt-get update && apt-get install -y \ + ${PYTHON} \ + ${PYTHON}-pip + +RUN ${PIP} install --upgrade \ + pip \ + setuptools diff --git a/tensorflow/tools/dockerfiles/partials/shell.partial.Dockerfile b/tensorflow/tools/dockerfiles/partials/shell.partial.Dockerfile new file mode 100644 index 0000000000..d641a11b06 --- /dev/null +++ b/tensorflow/tools/dockerfiles/partials/shell.partial.Dockerfile @@ -0,0 +1,2 @@ +COPY bashrc /etc/bash.bashrc +RUN chmod a+rwx /etc/bash.bashrc diff --git a/tensorflow/tools/dockerfiles/partials/tensorflow.partial.Dockerfile 
b/tensorflow/tools/dockerfiles/partials/tensorflow.partial.Dockerfile new file mode 100644 index 0000000000..96e79547f0 --- /dev/null +++ b/tensorflow/tools/dockerfiles/partials/tensorflow.partial.Dockerfile @@ -0,0 +1,2 @@ +ARG TF_PACKAGE +RUN ${PIP} install ${TF_PACKAGE} diff --git a/tensorflow/tools/dockerfiles/partials/ubuntu-devel.partial.Dockerfile b/tensorflow/tools/dockerfiles/partials/ubuntu-devel.partial.Dockerfile new file mode 100644 index 0000000000..bc79272276 --- /dev/null +++ b/tensorflow/tools/dockerfiles/partials/ubuntu-devel.partial.Dockerfile @@ -0,0 +1,24 @@ +ARG UBUNTU_VERSION=16.04 +FROM ubuntu:${UBUNTU_VERSION} + +RUN apt-get update && apt-get install -y --no-install-recommends \ + build-essential \ + curl \ + git \ + libcurl3-dev \ + libfreetype6-dev \ + libhdf5-serial-dev \ + libpng12-dev \ + libzmq3-dev \ + pkg-config \ + python-dev \ + rsync \ + software-properties-common \ + unzip \ + zip \ + zlib1g-dev \ + openjdk-8-jdk \ + openjdk-8-jre-headless \ + && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* diff --git a/tensorflow/tools/dockerfiles/partials/ubuntu.partial.Dockerfile b/tensorflow/tools/dockerfiles/partials/ubuntu.partial.Dockerfile new file mode 100644 index 0000000000..0a50735bf8 --- /dev/null +++ b/tensorflow/tools/dockerfiles/partials/ubuntu.partial.Dockerfile @@ -0,0 +1,2 @@ +ARG UBUNTU_VERSION=16.04 +FROM ubuntu:${UBUNTU_VERSION} diff --git a/tensorflow/tools/dockerfiles/spec.yml b/tensorflow/tools/dockerfiles/spec.yml new file mode 100644 index 0000000000..28bf9a55da --- /dev/null +++ b/tensorflow/tools/dockerfiles/spec.yml @@ -0,0 +1,195 @@ +# ====== +# HEADER +# ====== +# +# This is commented-out and prepended to each generated Dockerfile. +header: | + Copyright 2018 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ============================================================================ + + THIS IS A GENERATED DOCKERFILE. + + This file was assembled from multiple pieces, whose use is documented + below. Please refer to the TensorFlow dockerfiles documentation for + more information. Build args are documented as their default value. + +# ======== +# PARTIALS +# ======== +# +# Represent and document pieces of a Dockerfile. Spec: +# +# name: the name of the partial, is referenced from the images section +# desc: A description, inserted later into the Dockerfile +# file: Alternative file prefix, e.g. file.partial.Dockerfile. The default is +# the name of the partial. +# args: A dict of ARGs in the Dockerfile; each entry has the format +# ARG_NAME: VALUE where VALUE is one of: +# - a dict: +# desc: Documentation for the arg +# default: Default value for the arg; is written to the Dockerfile +# options: List of strings, part of documentation +# - a concrete value: the same as a dictionary with default: [value]. + +partials: + ubuntu: + desc: Start from Ubuntu (no GPU support) + args: + UBUNTU_VERSION: 16.04 + + ubuntu-devel: + desc: Start from Ubuntu, with TF development packages (no GPU support) + args: + UBUNTU_VERSION: 16.04 + + bazel: + desc: Install the latest version of Bazel and Python development tools. + + nvidia: + desc: NVIDIA with CUDA and CuDNN, no dev stuff + args: + UBUNTU_VERSION: 16.04 + + nvidia-devel: + desc: > + Start from Nvidia's Ubuntu base image with CUDA and CuDNN, with TF + development packages. 
+ args: + UBUNTU_VERSION: 16.04 + + python: + desc: Python is required for TensorFlow and other libraries. + args: + USE_PYTHON_3_NOT_2: + default: true + desc: Install python 3 over Python 2 + + tensorflow: + desc: Install the TensorFlow Python package. + args: + TF_PACKAGE: + default: tensorflow + options: + - tensorflow + - tensorflow-gpu + - tf-nightly + - tf-nightly-gpu + desc: The specific TensorFlow Python package to install + shell: + desc: Configure TensorFlow's shell prompt and login tools. + jupyter: + desc: Launch Jupyter on execution instead of a bash prompt. + +# ====== +# IMAGES +# ====== +# +# Represent Dockerfiles. Spec: +# +# name: the name of the image, possibly referenced by other images +# desc: A description, inserted later into the Dockerfile +# create-dockerfile: Create a dockerfile based on this. Useful for creating +# extensible base images that don't need a file. Default is true. +# partials: List of VALUEs, where a VALUE is either: +# - the name of a partial, which inserts that partial into this image +# - image: [name of another image], which inserts the partials from that +# image into this image +# arg-defaults: List of VALUEs, where a VALUE is either: +# - ARG_NAME: VALUE, which sets the ARG_NAME to VALUE wherever it appears +# in this image's partials +# - [name of another image], which loads the default args from that image +images: + + nodev: + create-dockerfile: false + partials: + - python + - tensorflow + - shell + + dev: + create-dockerfile: false + partials: + - python + - bazel + - shell + + cpu: + desc: Ubuntu-based, CPU-only environment for using TensorFlow + partials: + - ubuntu + - image: nodev + + cpu-devel: + desc: > + Ubuntu-based, CPU-only environment for developing changes for + TensorFlow. + partials: + - ubuntu-devel + - image: dev + + nvidia: + desc: Ubuntu-based, Nvidia-GPU-enabled environment for using TensorFlow. 
+ arg-defaults: + - TF_PACKAGE: tensorflow-gpu + partials: + - nvidia + - image: nodev + + nvidia-devel: + desc: > + Ubuntu-based, Nvidia-GPU-enabled environment for developing changes + for TensorFlow. + arg-defaults: + - TF_PACKAGE: tensorflow-gpu + partials: + - nvidia-devel + - image: dev + + cpu-jupyter: + desc: > + Ubuntu-based, CPU-only environment for using TensorFlow, with Jupyter + included. + partials: + - image: cpu + - jupyter + + cpu-devel-jupyter: + desc: > + Ubuntu-based, CPU-only environment for developing changes for + TensorFlow, with Jupyter included. + partials: + - image: cpu-devel + - jupyter + + nvidia-jupyter: + desc: > + Ubuntu-based, Nvidia-GPU-enabled environment for using TensorFlow, with + Jupyter included. + arg-defaults: + - nvidia + partials: + - image: nvidia + - jupyter + + nvidia-devel-jupyter: + desc: > + Ubuntu-based, Nvidia-GPU-enabled environment for developing changes for + TensorFlow, with Jupyter included. + arg-defaults: + - nvidia-devel + partials: + - image: nvidia-devel + - jupyter diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index 9d0ce34344..34b4a66c41 100644 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -493,11 +493,11 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""): tf_http_archive( name = "llvm", urls = [ - "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/6203c9bd082a877a20c218033636712135a3c2db.tar.gz", - "https://github.com/llvm-mirror/llvm/archive/6203c9bd082a877a20c218033636712135a3c2db.tar.gz", + "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/97d7bcd5c024ee6aec4eecbc723bb6d4f4c3dc3d.tar.gz", + "https://github.com/llvm-mirror/llvm/archive/97d7bcd5c024ee6aec4eecbc723bb6d4f4c3dc3d.tar.gz", ], - sha256 = "83a80f9fb2a5949ca77e526344cbd4581388c3ec7fea5c59e488d46fd38e06d9", - strip_prefix = "llvm-6203c9bd082a877a20c218033636712135a3c2db", + sha256 = "2889b79ab979e676e344974cfeefbaf2c21c7c69a015bd584e8ae67b87b136bc", + strip_prefix = 
"llvm-97d7bcd5c024ee6aec4eecbc723bb6d4f4c3dc3d", build_file = clean_dep("//third_party/llvm:llvm.autogenerated.BUILD"), ) |