diff options
Diffstat (limited to 'tensorflow/compiler')
78 files changed, 1580 insertions, 806 deletions
diff --git a/tensorflow/compiler/aot/embedded_protocol_buffers.h b/tensorflow/compiler/aot/embedded_protocol_buffers.h index cf5c04ac4b..bd270045e3 100644 --- a/tensorflow/compiler/aot/embedded_protocol_buffers.h +++ b/tensorflow/compiler/aot/embedded_protocol_buffers.h @@ -20,6 +20,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_AOT_EMBEDDED_PROTOCOL_BUFFERS_H_ #define TENSORFLOW_COMPILER_AOT_EMBEDDED_PROTOCOL_BUFFERS_H_ +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/platform/protobuf.h" diff --git a/tensorflow/compiler/aot/tfcompile_main.cc b/tensorflow/compiler/aot/tfcompile_main.cc index b95b063348..1c9d30d7b0 100644 --- a/tensorflow/compiler/aot/tfcompile_main.cc +++ b/tensorflow/compiler/aot/tfcompile_main.cc @@ -35,6 +35,7 @@ limitations under the License. #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/tensor_id.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/init_main.h" @@ -92,9 +93,8 @@ Status Main(const MainFlags& flags) { // Write output files. 
Env* env = Env::Default(); const std::vector<char>& obj = compile_result.aot->object_file_data(); - TF_RETURN_IF_ERROR( - WriteStringToFile(env, flags.out_function_object, - absl::string_view(obj.data(), obj.size()))); + TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_function_object, + StringPiece(obj.data(), obj.size()))); CodegenOpts codegen_opts; codegen_opts.gen_name_to_index = flags.gen_name_to_index; codegen_opts.gen_program_shape = flags.gen_program_shape; diff --git a/tensorflow/compiler/jit/legacy_flags/BUILD b/tensorflow/compiler/jit/legacy_flags/BUILD index 5b6692f523..07c5b23188 100644 --- a/tensorflow/compiler/jit/legacy_flags/BUILD +++ b/tensorflow/compiler/jit/legacy_flags/BUILD @@ -29,18 +29,6 @@ cc_library( ) cc_library( - name = "parallel_check_op_flags", - srcs = ["parallel_check_op_flags.cc"], - hdrs = ["parallel_check_op_flags.h"], - deps = - [ - "//tensorflow/compiler/xla/legacy_flags:parse_flags_from_env", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - ], -) - -cc_library( name = "xla_device_flags", srcs = ["xla_device_flags.cc"], hdrs = ["xla_device_flags.h"], diff --git a/tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.cc b/tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.cc deleted file mode 100644 index a61694b494..0000000000 --- a/tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.cc +++ /dev/null @@ -1,68 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Legacy flags for the XLA bridge's parallel_check_op module. - -#include <mutex> -#include <vector> - -#include "tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace tensorflow { -namespace legacy_flags { - -// Pointers to the parsed value of the flags and flag descriptors, initialized -// via flags_init. -static ParallelCheckOpFlags* flags; -static std::vector<Flag>* flag_list; -static std::once_flag flags_init; - -// Allocate *flags. Called via call_once(&flags_init,...). -static void AllocateFlags() { - flags = new ParallelCheckOpFlags; - flags->parallel_check_failfast = true; - flags->parallel_check_atol = "1e-5"; - flags->parallel_check_rtol = "1e-5"; - flag_list = new std::vector<Flag>({ - Flag("parallel_check_failfast", &flags->parallel_check_failfast, - "Fail immediately on first parallel-check comparison error."), - Flag("parallel_check_atol", &flags->parallel_check_atol, - "Absolute error tolerance for parallel-check comparison."), - Flag("parallel_check_rtol", &flags->parallel_check_rtol, - "Relative error tolerance for parallel-check comparison."), - }); - xla::legacy_flags::ParseFlagsFromEnv(*flag_list); -} - -// Append to *append_to flag definitions associated with the XLA bridge's -// parallel_check_op module. -void AppendParallelCheckOpFlags(std::vector<Flag>* append_to) { - std::call_once(flags_init, &AllocateFlags); - append_to->insert(append_to->end(), flag_list->begin(), flag_list->end()); -} - -// Return a pointer to the ParallelCheckOpFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. 
-ParallelCheckOpFlags* GetParallelCheckOpFlags() { - std::call_once(flags_init, &AllocateFlags); - return flags; -} - -} // namespace legacy_flags -} // namespace tensorflow diff --git a/tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.h b/tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.h deleted file mode 100644 index 156a2a2a71..0000000000 --- a/tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.h +++ /dev/null @@ -1,52 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_PARALLEL_CHECK_OP_FLAGS_H_ -#define TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_PARALLEL_CHECK_OP_FLAGS_H_ - -// Legacy flags for the XLA bridge's parallel_check_op module. - -#include <vector> - -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace tensorflow { -namespace legacy_flags { - -// Append to *flag_list flag definitions associated with the XLA bridge's -// parallel_check_op module. -void AppendParallelCheckOpFlags(std::vector<tensorflow::Flag>* flag_list); - -// The values of flags associated with the XLA bridge's -// parallel_check_op module. -typedef struct { - bool parallel_check_failfast; // Fail immediately on first parallel-check - // comparison error. 
- string parallel_check_atol; // Absolute error tolerance for parallel-check - // comparison. - string parallel_check_rtol; // Relative error tolerance for parallel-check - // comparison. -} ParallelCheckOpFlags; - -// Return a pointer to the ParallelCheckOpFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -ParallelCheckOpFlags* GetParallelCheckOpFlags(); - -} // namespace legacy_flags -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_PARALLEL_CHECK_OP_FLAGS_H_ diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index 9473ac0a4c..807ab51fd3 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -633,7 +633,7 @@ TEST(XlaCompilationTest, IllegalCycle_UsefulErrorMessage) { std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); Scope root = Scope::NewRootScope().ExitOnError(); { - auto BuildNoopNode = [](absl::string_view name, Graph* graph) { + auto BuildNoopNode = [](StringPiece name, Graph* graph) { NodeDefBuilder builder(name, "NoOp"); NodeDef def; TF_CHECK_OK(builder.Finalize(&def)); diff --git a/tensorflow/compiler/jit/xla_cluster_util.h b/tensorflow/compiler/jit/xla_cluster_util.h index 17ae510a0e..debd9038c7 100644 --- a/tensorflow/compiler/jit/xla_cluster_util.h +++ b/tensorflow/compiler/jit/xla_cluster_util.h @@ -18,6 +18,7 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_ #define TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_ +#include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" #include "tensorflow/core/graph/algorithm.h" diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc index af83c792e5..6d4160a968 100644 --- a/tensorflow/compiler/jit/xla_device_context.cc +++ b/tensorflow/compiler/jit/xla_device_context.cc @@ -339,11 +339,11 @@ void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor, } void XlaDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor, - absl::string_view tensor_name, + StringPiece tensor_name, Device* device, Tensor* cpu_tensor, StatusCallback done) { - manager_.CopyDeviceTensorToCPU(device_tensor, tensor_name, device, cpu_tensor, - done); + manager_.CopyDeviceTensorToCPU(device_tensor, absl::string_view(tensor_name), + device, cpu_tensor, done); } void XlaDeviceContext::CopyDeviceTensorToDevice(const Tensor& src_tensor, diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h index df82421294..1effd6628f 100644 --- a/tensorflow/compiler/jit/xla_device_context.h +++ b/tensorflow/compiler/jit/xla_device_context.h @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/stringpiece.h" namespace tensorflow { @@ -110,9 +111,12 @@ class XlaDeviceContext : public DeviceContext { void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, Tensor* device_tensor, StatusCallback done) const override; + // TODO(rlahaye): Replace StringPiece with absl::string_view when the + // StringPiece->absl::string_view change is rolled forward. 
void CopyDeviceTensorToCPU(const Tensor* device_tensor, - absl::string_view tensor_name, Device* device, - Tensor* cpu_tensor, StatusCallback done) override; + StringPiece tensor_name, // non-ABSL OK + Device* device, Tensor* cpu_tensor, + StatusCallback done) override; void CopyDeviceTensorToDevice(const Tensor& src_tensor, Tensor* dst_tensor, const StatusCallback& done); diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 22be7f048f..3821dced63 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -191,6 +191,7 @@ cc_library( ":functionalize_control_flow", ":host_compute_metadata_proto", ":sharding_util", + ":side_effect_util", ":tf2xla_util", "//tensorflow/compiler/tf2xla/lib:util", "//tensorflow/compiler/xla:literal", @@ -214,6 +215,7 @@ cc_library( "//tensorflow/core:protos_all_cc", "//tensorflow/core:stream_executor_no_cuda", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", ], alwayslink = 1, @@ -359,6 +361,7 @@ tf_cc_test( name = "xla_compiler_test", srcs = ["xla_compiler_test.cc"], deps = [ + ":side_effect_util", ":xla_compiler", "//tensorflow/cc:cc_ops", "//tensorflow/cc:function_ops", @@ -370,6 +373,7 @@ tf_cc_test( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/service:cpu_plugin", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:core_cpu_internal", @@ -631,3 +635,12 @@ tf_cc_test( "@com_google_absl//absl/strings", ], ) + +cc_library( + name = "side_effect_util", + srcs = ["side_effect_util.cc"], + hdrs = ["side_effect_util.h"], + deps = [ + "//tensorflow/core:core_cpu", + ], +) diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index 4c776fb178..46794f7b50 100644 --- 
a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -115,9 +115,6 @@ tf_kernel_library( deps = [ ":if_op", ":while_op", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/lib:batch_dot", @@ -168,14 +165,11 @@ tf_kernel_library( "//tensorflow/core/kernels:sparse_to_dense_op", "//tensorflow/core/kernels:stack_ops", "//tensorflow/core/kernels:training_ops", - ] + if_mkl( - [ - "//tensorflow/core/kernels:mkl_transpose_op", - ], - [ - "//tensorflow/core/kernels:transpose_op", - ], - ), + "//tensorflow/core/kernels:transpose_op", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], ) tf_kernel_library( @@ -184,6 +178,7 @@ tf_kernel_library( hdrs = ["while_op.h"], deps = [ "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/tf2xla:side_effect_util", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/compiler/xla:literal", @@ -201,6 +196,7 @@ tf_kernel_library( hdrs = ["if_op.h"], deps = [ "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/tf2xla:side_effect_util", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/compiler/xla:literal", diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.cc b/tensorflow/compiler/tf2xla/kernels/if_op.cc index 6e1dbf5472..56da50f140 100644 --- a/tensorflow/compiler/tf2xla/kernels/if_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/if_op.cc @@ -16,6 +16,7 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/kernels/if_op.h" #include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/side_effect_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -33,6 +34,11 @@ XlaIfOp::XlaIfOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("Tcond", &cond_type_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("Tin", &input_types_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("Tout", &output_types_)); + if (!ctx->GetAttr(kXlaTokenInputNodesAttrName, &token_input_nodes_).ok()) { + has_token_input_output_ = false; + } else { + has_token_input_output_ = !token_input_nodes_.empty(); + } } // TODO(b/35949885): There is duplication here with the handling of the @@ -90,6 +96,7 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { options.resolve_compile_time_constants = false; options.return_updated_values_for_all_resources = true; options.is_entry_computation = false; + options.add_token_input_output = has_token_input_output_; XlaCompiler* compiler = ctx->compiler(); XlaCompiler::CompilationResult then_result; @@ -191,7 +198,16 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { std::vector<xla::XlaOp> inputs(num_inputs); for (int i = 0; i < num_inputs; ++i) { int input_num = then_result.input_mapping[i] + 1; - if (ctx->input_type(input_num) == DT_RESOURCE) { + if (has_token_input_output_ && i == num_inputs - 1) { + // Set token input for this "if" op. 
+ std::vector<xla::XlaOp> token_inputs; + for (const string& node_name : token_input_nodes_) { + auto token_or = compiler->GetNodeToken(node_name); + OP_REQUIRES_OK(ctx, token_or.status()); + token_inputs.push_back(token_or.ValueOrDie()); + } + inputs[i] = xla::AfterAll(b, token_inputs); + } else if (ctx->input_type(input_num) == DT_RESOURCE) { XlaResource* resource; OP_REQUIRES_OK(ctx, ctx->GetResourceInput(input_num, &resource)); OP_REQUIRES_OK(ctx, resource->Pack(&inputs[i], b)); @@ -219,6 +235,18 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { } ctx->SetOutput(i, output_handle); } + if (has_token_input_output_) { + // Set token output for this "if" op. + xla::XlaOp token_output = + xla::GetTupleElement(outputs, output_types_.size()); + auto shape_or = b->GetShape(token_output); + OP_REQUIRES_OK(ctx, shape_or.status()); + OP_REQUIRES(ctx, xla::ShapeUtil::IsToken(shape_or.ValueOrDie()), + errors::FailedPrecondition( + "Token output is not token type: ", + xla::ShapeUtil::HumanString(shape_or.ValueOrDie()))); + OP_REQUIRES_OK(ctx, compiler->SetNodeToken(name(), token_output)); + } // Updates the values of any resource variables modified by the conditional // bodies. diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.h b/tensorflow/compiler/tf2xla/kernels/if_op.h index f9bc98a198..7783e13a8a 100644 --- a/tensorflow/compiler/tf2xla/kernels/if_op.h +++ b/tensorflow/compiler/tf2xla/kernels/if_op.h @@ -52,6 +52,8 @@ class XlaIfOp : public XlaOpKernel { DataType cond_type_; DataTypeVector input_types_; DataTypeVector output_types_; + bool has_token_input_output_; + std::vector<string> token_input_nodes_; }; } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc index 296518229e..559414eeaa 100644 --- a/tensorflow/compiler/tf2xla/kernels/while_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc @@ -16,6 +16,7 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/kernels/while_op.h" #include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/side_effect_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" @@ -90,6 +91,11 @@ XlaWhileOp::XlaWhileOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { cond_name_attr_ = *name_attr; OP_REQUIRES_OK(ctx, ctx->GetAttr("body", &name_attr)); body_name_attr_ = *name_attr; + if (!ctx->GetAttr(kXlaTokenInputNodesAttrName, &token_input_nodes_).ok()) { + has_token_input_output_ = false; + } else { + has_token_input_output_ = !token_input_nodes_.empty(); + } } void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { @@ -120,6 +126,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { body_options.return_updated_values_for_all_resources = true; body_options.resolve_compile_time_constants = false; body_options.is_entry_computation = false; + body_options.add_token_input_output = has_token_input_output_; XlaCompiler::CompilationResult body; OP_REQUIRES_OK(ctx, compiler->CompileFunction(body_options, body_name_attr_, arguments, &body)); @@ -192,6 +199,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { cond_options.use_tuple_arg = true; cond_options.resolve_compile_time_constants = false; cond_options.is_entry_computation = false; + cond_options.add_token_input_output = has_token_input_output_; XlaCompiler::CompilationResult cond; OP_REQUIRES_OK(ctx, compiler->CompileFunction(cond_options, cond_name_attr_, arguments, &cond)); @@ -238,7 +246,16 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { std::vector<xla::XlaOp> inputs(num_inputs); for (int i = 0; i < num_inputs; ++i) { int input_num = body.input_mapping[i]; - if (ctx->input_type(input_num) == DT_RESOURCE) { + if (has_token_input_output_ && i == num_inputs - 1) { + // Set token input for this "while" op. 
+ std::vector<xla::XlaOp> token_inputs; + for (const string& node_name : token_input_nodes_) { + auto token_or = compiler->GetNodeToken(node_name); + OP_REQUIRES_OK(ctx, token_or.status()); + token_inputs.push_back(token_or.ValueOrDie()); + } + inputs[i] = xla::AfterAll(builder, token_inputs); + } else if (ctx->input_type(input_num) == DT_RESOURCE) { XlaResource* resource; OP_REQUIRES_OK(ctx, ctx->GetResourceInput(input_num, &resource)); OP_REQUIRES_OK(ctx, resource->Pack(&inputs[i], builder)); @@ -273,6 +290,18 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { xla::GetTupleElement(while_result, i)); } } + if (has_token_input_output_) { + // Set token output for this "while" op. + xla::XlaOp token_output = + xla::GetTupleElement(while_result, ctx->num_outputs()); + auto shape_or = builder->GetShape(token_output); + OP_REQUIRES_OK(ctx, shape_or.status()); + OP_REQUIRES(ctx, xla::ShapeUtil::IsToken(shape_or.ValueOrDie()), + errors::FailedPrecondition( + "Token output is not token type: ", + xla::ShapeUtil::HumanString(shape_or.ValueOrDie()))); + OP_REQUIRES_OK(ctx, compiler->SetNodeToken(name(), token_output)); + } // Updates the values of any resource variables modified by the loop. 
for (int i = 0; i < body.resource_updates.size(); ++i) { diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.h b/tensorflow/compiler/tf2xla/kernels/while_op.h index 67edebabf9..aeeff40e68 100644 --- a/tensorflow/compiler/tf2xla/kernels/while_op.h +++ b/tensorflow/compiler/tf2xla/kernels/while_op.h @@ -56,6 +56,8 @@ class XlaWhileOp : public XlaOpKernel { private: NameAttrList cond_name_attr_; NameAttrList body_name_attr_; + bool has_token_input_output_; + std::vector<string> token_input_nodes_; TF_DISALLOW_COPY_AND_ASSIGN(XlaWhileOp); }; diff --git a/tensorflow/compiler/tf2xla/resource_operation_table.cc b/tensorflow/compiler/tf2xla/resource_operation_table.cc index 20f2ce2919..92577b5bc8 100644 --- a/tensorflow/compiler/tf2xla/resource_operation_table.cc +++ b/tensorflow/compiler/tf2xla/resource_operation_table.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/resource_operation_table.h" #include "absl/algorithm/container.h" +#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/flatmap.h" namespace tensorflow { @@ -30,11 +31,10 @@ namespace tensorflow { } } -static gtl::FlatMap<absl::string_view, XlaResourceOpInfo>* -CreateResourceOpInfoMap() { - auto* result = new gtl::FlatMap<absl::string_view, XlaResourceOpInfo>; +static gtl::FlatMap<StringPiece, XlaResourceOpInfo>* CreateResourceOpInfoMap() { + auto* result = new gtl::FlatMap<StringPiece, XlaResourceOpInfo>; - auto add = [&](absl::string_view op, XlaResourceOpKind op_kind, + auto add = [&](StringPiece op, XlaResourceOpKind op_kind, XlaResourceKind resource_kind) { auto insert_result = result->insert({op, XlaResourceOpInfo(op_kind, resource_kind)}); @@ -103,17 +103,17 @@ CreateResourceOpInfoMap() { return result; } -static const gtl::FlatMap<absl::string_view, XlaResourceOpInfo>& +static const gtl::FlatMap<StringPiece, XlaResourceOpInfo>& GetStaticResourceOpInfoMap() { - static gtl::FlatMap<absl::string_view, XlaResourceOpInfo>* 
op_info_map = + static gtl::FlatMap<StringPiece, XlaResourceOpInfo>* op_info_map = CreateResourceOpInfoMap(); return *op_info_map; } const XlaResourceOpInfo* GetResourceOpInfoForOp(absl::string_view op) { - const gtl::FlatMap<absl::string_view, XlaResourceOpInfo>& op_infos = + const gtl::FlatMap<StringPiece, XlaResourceOpInfo>& op_infos = GetStaticResourceOpInfoMap(); - auto it = op_infos.find(op); + auto it = op_infos.find(StringPiece(op.data(), op.length())); return it == op_infos.end() ? nullptr : &it->second; } @@ -121,7 +121,7 @@ namespace resource_op_table_internal { std::vector<absl::string_view> GetKnownResourceOps() { std::vector<absl::string_view> result; for (const auto& p : GetStaticResourceOpInfoMap()) { - result.push_back(p.first); + result.push_back(absl::string_view(p.first)); } absl::c_sort(result); return result; diff --git a/tensorflow/compiler/tf2xla/side_effect_util.cc b/tensorflow/compiler/tf2xla/side_effect_util.cc new file mode 100644 index 0000000000..6cd7b24592 --- /dev/null +++ b/tensorflow/compiler/tf2xla/side_effect_util.cc @@ -0,0 +1,67 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/side_effect_util.h" + +#include "tensorflow/core/graph/algorithm.h" + +namespace tensorflow { + +const char kXlaTokenInputNodesAttrName[] = "_xla_token_input_nodes"; + +const char kXlaTokenArgNodeName[] = "_xla_token_arg_node"; + +std::set<std::string> CalculateTokenInputsForOutputToken(const Graph& g) { + std::set<std::string> results; + Node* first_side_effecting_node_on_path = nullptr; + ReverseDFS(g, + [&](Node* n) { + std::vector<string> token_input_nodes; + if (!GetNodeAttr(n->attrs(), kXlaTokenInputNodesAttrName, + &token_input_nodes) + .ok() || + token_input_nodes.empty()) { + return; + } + + if (first_side_effecting_node_on_path != nullptr) { + return; + } + + first_side_effecting_node_on_path = n; + results.insert(n->name()); + }, + [&](Node* n) { + if (first_side_effecting_node_on_path == n) { + first_side_effecting_node_on_path = nullptr; + } + }, + NodeComparatorName()); + return results; +} + +bool HasSideEffectingNodes(const Graph& g) { + for (Node* n : g.nodes()) { + std::vector<string> token_input_nodes; + if (GetNodeAttr(n->attrs(), kXlaTokenInputNodesAttrName, &token_input_nodes) + .ok() && + !token_input_nodes.empty()) { + return true; + } + } + return false; +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/side_effect_util.h b/tensorflow/compiler/tf2xla/side_effect_util.h new file mode 100644 index 0000000000..ad07624729 --- /dev/null +++ b/tensorflow/compiler/tf2xla/side_effect_util.h @@ -0,0 +1,47 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_SIDE_EFFECT_UTIL_H_ +#define TENSORFLOW_COMPILER_TF2XLA_SIDE_EFFECT_UTIL_H_ + +#include <vector> + +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { + +// Side-effecting nodes will have this attribute set. Its value is the list of +// node names which this node has side-effect dependencies on. +// +// Nodes like HostCompute, SendToHost, RecvFromHost always have this attribute, +// because they always have side effects. +// If and While nodes may or may not have this attribute, depending on whether +// their bodies have side-effecting nodes. +extern const char kXlaTokenInputNodesAttrName[]; + +// This node name is used in kXlaTokenInputNodesAttrName attr to signal that a +// node has a side-effect dependency on the current graph's token input. +extern const char kXlaTokenArgNodeName[]; + +// Calculates side-effect dependencies for the graph's token output. +// Returns a set of node names representing these dependencies. +std::set<std::string> CalculateTokenInputsForOutputToken(const Graph& g); + +// Returns whether a graph contains side-effecting nodes. 
+bool HasSideEffectingNodes(const Graph& g); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_SIDE_EFFECT_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.h b/tensorflow/compiler/tf2xla/tf2xla_util.h index a29e764466..dcddef8418 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util.h +++ b/tensorflow/compiler/tf2xla/tf2xla_util.h @@ -18,6 +18,7 @@ limitations under the License. #include <unordered_map> +#include "absl/strings/string_view.h" #include "tensorflow/compiler/tf2xla/tf2xla.pb.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/kernel_def.pb.h" diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 41d305d461..dcb455779d 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/graph_compiler.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/sharding_util.h" +#include "tensorflow/compiler/tf2xla/side_effect_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_compilation_device.h" @@ -291,6 +292,10 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg, "Invalid resource type in XLAShapeForArgument()"); } } + case XlaCompiler::Argument::kToken: { + *xla_shape = xla::ShapeUtil::MakeTokenShape(); + return Status::OK(); + } case XlaCompiler::Argument::kInvalid: return errors::Internal("Invalid argument type in XLAShapeForArgument()"); } @@ -489,7 +494,8 @@ Status XlaCompiler::BuildArguments( } break; - case XlaCompiler::Argument::kParameter: { + case XlaCompiler::Argument::kParameter: + case XlaCompiler::Argument::kToken: { input_mapping->push_back(i); break; } @@ -616,6 +622,10 @@ Status XlaCompiler::BuildArguments( 
arg_expression.set_handle(arg_handles[i]); } break; + case XlaCompiler::Argument::kToken: { + arg_expression.set_handle(arg_handles[i]); + break; + } case XlaCompiler::Argument::kConstant: case XlaCompiler::Argument::kInvalid: return errors::Internal( @@ -757,23 +767,71 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, &options_.shape_representation_fn); core::ScopedUnref context_unref(context); + std::vector<XlaCompiler::Argument> real_args(args); + int token_input_index = -1; + if (options.add_token_input_output) { + // Add extra token input. + token_input_index = real_args.size(); + + XlaCompiler::Argument token_arg; + token_arg.kind = XlaCompiler::Argument::kToken; + real_args.push_back(token_arg); + } + std::vector<XlaExpression> arg_expressions; std::vector<int> arg_cores; - TF_RETURN_IF_ERROR( - BuildArguments(*graph, args, options.use_tuple_arg, &builder, context, - &arg_cores, &arg_expressions, &result->input_mapping, - &result->xla_input_shapes, options.is_entry_computation)); + TF_RETURN_IF_ERROR(BuildArguments( + *graph, real_args, options.use_tuple_arg, &builder, context, &arg_cores, + &arg_expressions, &result->input_mapping, &result->xla_input_shapes, + options.is_entry_computation)); context->set_args(std::move(arg_expressions)); + PushNodeTokenMapping(); + // Use std::set instead of std::unordered_set to ensure determinism. + std::set<std::string> output_node_token_inputs; + if (token_input_index != -1) { + // Original token comes from input. + auto arg_expression = context->args()[token_input_index]; + TF_RETURN_IF_ERROR( + SetNodeToken(kXlaTokenArgNodeName, arg_expression.handle())); + + // Calculate token inputs for output token. + output_node_token_inputs = CalculateTokenInputsForOutputToken(*graph); + + // If there's no side-effecting op in the graph, use token input as token + // output. 
+ if (output_node_token_inputs.empty()) { + output_node_token_inputs.insert(kXlaTokenArgNodeName); + } + } else if (options.is_entry_computation) { + // Original token is manually created. + if (HasSideEffectingNodes(*graph)) { + TF_RETURN_IF_ERROR( + SetNodeToken(kXlaTokenArgNodeName, xla::CreateToken(&builder))); + } + } + TF_RETURN_IF_ERROR(ExecuteGraph(context, std::move(graph), device_, flib_runtime_, NextStepId())); + if (token_input_index != -1) { + // Add extra token output. + std::vector<xla::XlaOp> token_inputs; + for (const auto& node_name : output_node_token_inputs) { + auto token_or = GetNodeToken(node_name); + TF_RETURN_IF_ERROR(token_or.status()); + token_inputs.push_back(token_or.ValueOrDie()); + } + TF_RETURN_IF_ERROR( + context->AppendTokenRetval(xla::AfterAll(&builder, token_inputs))); + } + TF_RETURN_IF_ERROR(PopNodeTokenMapping()); int num_nonconst_outputs; int num_computation_outputs; result->computation = std::make_shared<xla::XlaComputation>(); result->outputs.resize(context->retvals().size()); TF_RETURN_IF_ERROR(BuildComputation( - args, arg_cores, context->retvals(), context->resources(), + real_args, arg_cores, context->retvals(), context->resources(), options.return_updated_values_for_all_resources, options.always_return_tuple, &builder, result->computation.get(), &num_computation_outputs, &num_nonconst_outputs, &result->outputs, @@ -912,4 +970,47 @@ Status XlaCompiler::SetHostComputeControlDependency( return Status::OK(); } +void XlaCompiler::PushNodeTokenMapping() { + node_token_mapping_stack_.emplace(std::map<string, xla::XlaOp>{}); +} + +Status XlaCompiler::PopNodeTokenMapping() { + if (node_token_mapping_stack_.empty()) { + return errors::FailedPrecondition( + "Calling PopNodeTokenMapping() when node_token_mapping_stack_ is " + "empty."); + } + node_token_mapping_stack_.pop(); + return Status::OK(); +} + +Status XlaCompiler::SetNodeToken(const string& node_name, + const xla::XlaOp& op) { + if (node_token_mapping_stack_.empty()) { + 
return errors::FailedPrecondition( + "Calling SetNodeToken() when node_token_mapping_stack_ is " + "empty."); + } + auto insert_result = node_token_mapping_stack_.top().insert({node_name, op}); + if (!insert_result.second) { + return errors::FailedPrecondition("Token mapping already exists for node ", + node_name); + } + return Status::OK(); +} + +xla::StatusOr<xla::XlaOp> XlaCompiler::GetNodeToken(const string& node_name) { + if (node_token_mapping_stack_.empty()) { + return errors::FailedPrecondition( + "Calling GetNodeToken() when node_token_mapping_stack_ is " + "empty."); + } + auto iter = node_token_mapping_stack_.top().find(node_name); + if (iter == node_token_mapping_stack_.top().end()) { + return errors::FailedPrecondition("Cannot find token mapping for node ", + node_name); + } + return iter->second; +} + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index 8f4a9858ed..2cc603a580 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_ #define TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_ +#include <stack> + #include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h" #include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -26,6 +28,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/function.h" +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/notification.h" @@ -106,6 +109,9 @@ class XlaCompiler { // Argument is a run-time parameter. kParameter, + + // Argument is an XLA token. 
+ kToken, }; Kind kind = kInvalid; @@ -179,6 +185,9 @@ class XlaCompiler { // True when compiling the entry computation, false for subcomputations // (while, call, etc.) bool is_entry_computation = true; + + // True when we should add XLA token input & output to the graph/function. + bool add_token_input_output = false; }; struct OutputDescription { @@ -384,6 +393,11 @@ class XlaCompiler { xla::Client* client() const { return options_.client; } FunctionLibraryRuntime* flib_runtime() const { return flib_runtime_; } + void PushNodeTokenMapping(); + Status PopNodeTokenMapping(); + Status SetNodeToken(const string& node_name, const xla::XlaOp& op); + xla::StatusOr<xla::XlaOp> GetNodeToken(const string& node_name); + private: // Sets the function body `fbody` to the one registered as `function`. Status FindFunctionBody(const NameAttrList& function, @@ -448,6 +462,15 @@ class XlaCompiler { std::unordered_map<string, xla::XlaOp> host_compute_control_output_; + // This is used to store <node name, token output> mapping. Side-effecting + // ops call SetNodeToken() to record their token output, so later side-effecting + // ops can use GetNodeToken() to get it and use it as token input. + // + // It's a stack because we need a mapping like this for each level of nested + // CompileGraph() call. In CompileGraph(), we will push a new mapping to the + // stack, and pop the mapping before returning. + std::stack<std::map<string, xla::XlaOp>> node_token_mapping_stack_; + TF_DISALLOW_COPY_AND_ASSIGN(XlaCompiler); }; diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index be3c93ae47..40ce9fb41c 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -20,10 +20,12 @@ limitations under the License. 
#include "tensorflow/cc/ops/function_ops.h" #include "tensorflow/cc/ops/resource_variable_ops.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/tf2xla/side_effect_util.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -32,6 +34,7 @@ limitations under the License. #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/function_testlib.h" +#include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_testutil.h" @@ -1274,5 +1277,70 @@ TEST_F(XlaCompilerTest, SingleOpWithoutInputs) { } } +class DummySideEffectingOp : public XlaOpKernel { + public: + explicit DummySideEffectingOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + void Compile(XlaOpKernelContext* ctx) override { + OP_REQUIRES_OK(ctx, ctx->compiler()->SetNodeToken( + name(), xla::CreateToken(ctx->builder()))); + } +}; + +REGISTER_OP("DummySideEffectingOp"); + +REGISTER_XLA_OP(Name("DummySideEffectingOp"), DummySideEffectingOp); + +TEST_F(XlaCompilerTest, TokenInputAndOutput) { + std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); + NodeDef side_effecting_op; + side_effecting_op.set_name("DummySideEffectingOp"); + side_effecting_op.set_op("DummySideEffectingOp"); + AddNodeAttr(kXlaTokenInputNodesAttrName, + std::vector<string>{kXlaTokenArgNodeName}, &side_effecting_op); + Status status; + graph->AddNode(side_effecting_op, &status); + TF_ASSERT_OK(status); + 
EXPECT_TRUE(FixupSourceAndSinkEdges(graph.get())); + + const std::vector<XlaCompiler::Argument> empty_args; + { + // The case for entry computation: we don't add token input/output. Instead, + // we use CreateToken HLO to create the entry token. + XlaCompiler::CompileOptions options; + options.is_entry_computation = true; + options.add_token_input_output = false; + XlaCompiler compiler(DefaultOptions()); + + std::unique_ptr<Graph> graph_copy(new Graph(OpRegistry::Global())); + CopyGraph(*graph, graph_copy.get()); + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(options, "NoOp", std::move(graph_copy), + empty_args, &result)); + EXPECT_EQ(result.xla_input_shapes.size(), 0); + EXPECT_TRUE(xla::ShapeUtil::IsTuple(result.xla_output_shape)); + EXPECT_EQ(xla::ShapeUtil::TupleElementCount(result.xla_output_shape), 0); + } + { + // The case for non-entry computation (e.g. while loop body). We add token + // input/output. + XlaCompiler::CompileOptions options; + options.is_entry_computation = false; + options.add_token_input_output = true; + XlaCompiler compiler(DefaultOptions()); + + std::unique_ptr<Graph> graph_copy(new Graph(OpRegistry::Global())); + CopyGraph(*graph, graph_copy.get()); + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(options, "NoOp", std::move(graph_copy), + empty_args, &result)); + EXPECT_EQ(result.xla_input_shapes.size(), 1); + EXPECT_TRUE(xla::ShapeUtil::IsToken(result.xla_input_shapes[0])); + EXPECT_TRUE(xla::ShapeUtil::IsTuple(result.xla_output_shape)); + EXPECT_EQ(xla::ShapeUtil::TupleElementCount(result.xla_output_shape), 1); + EXPECT_TRUE(xla::ShapeUtil::IsToken( + xla::ShapeUtil::GetTupleElementShape(result.xla_output_shape, 0))); + } +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc index e8b4b0eb36..f247570d72 100644 --- a/tensorflow/compiler/tf2xla/xla_context.cc +++ 
b/tensorflow/compiler/tf2xla/xla_context.cc @@ -119,6 +119,17 @@ Status XlaContext::AddResourceRetval(int retval_index, XlaResource* resource) { return Status::OK(); } +Status XlaContext::AppendTokenRetval(const xla::XlaOp& token) { + VLOG(1) << "Adding retval index " << retvals_.size() + << " with token to XLA computation"; + XlaExpression e; + e.set_handle(token); + // We use DT_INVALID because there is no TF DataType which corresponds to XLA + // token. XlaCompiler handles this case separately, so putting it here is OK. + retvals_.push_back(Retval{DT_INVALID, TensorShape(), e}); + return Status::OK(); +} + xla::XlaBuilder* XlaContext::builder() { return builder_; } Status XlaContext::CreateResource( diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h index 4da891634e..d7dbdc957f 100644 --- a/tensorflow/compiler/tf2xla/xla_context.h +++ b/tensorflow/compiler/tf2xla/xla_context.h @@ -89,6 +89,9 @@ class XlaContext : public ResourceBase { // As for Retval, but for return values that are resource handles. Status AddResourceRetval(int retval_index, XlaResource* resource); + // As for Retval, but for return values that are XLA tokens. + Status AppendTokenRetval(const xla::XlaOp& token); + // Creates a resource with resource `kind` and initial value `handle`. `name` // is a descriptive name for use in error messages. See the `XlaResource` // constructor for a description of the remaining arguments. 
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index d67e50375b..636cb71e21 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -102,7 +102,8 @@ Status XlaOpKernelContext::ConstantInput(int index, static xla::StatusOr<int> InputIndex(XlaOpKernelContext* context, absl::string_view name) { int start, stop; - TF_RETURN_IF_ERROR(context->op_kernel().InputRange(name, &start, &stop)); + TF_RETURN_IF_ERROR(context->op_kernel().InputRange( + StringPiece(name.data(), name.length()), &start, &stop)); if (stop != start + 1) { return errors::InvalidArgument("OpKernel used list-valued input name '", name, @@ -365,7 +366,8 @@ Status XlaOpKernelContext::InputList(absl::string_view name, std::vector<xla::XlaOp>* handles, std::vector<TensorShape>* shapes) { OpInputList inputs; - TF_RETURN_IF_ERROR(context_->input_list(name, &inputs)); + TF_RETURN_IF_ERROR( + context_->input_list(StringPiece(name.data(), name.size()), &inputs)); handles->clear(); shapes->clear(); for (const Tensor& input : inputs) { @@ -378,7 +380,8 @@ Status XlaOpKernelContext::InputList(absl::string_view name, Status XlaOpKernelContext::ConstantInputList( absl::string_view name, std::vector<xla::Literal>* outputs) { int start, stop; - TF_RETURN_IF_ERROR(op_kernel().InputRange(name, &start, &stop)); + TF_RETURN_IF_ERROR(op_kernel().InputRange( + StringPiece(name.data(), name.size()), &start, &stop)); outputs->resize(stop - start); for (int i = start; i < stop; ++i) { TF_RETURN_IF_ERROR(ConstantInput(i, &(*outputs)[i])); @@ -612,7 +615,7 @@ const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMul( const Tensor& XlaOpKernelContext::GetInputTensorByName(absl::string_view name) { const Tensor* tensor; - CHECK(context_->input(name, &tensor).ok()); + CHECK(context_->input(StringPiece(name.data(), name.length()), &tensor).ok()); return *tensor; } diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h 
b/tensorflow/compiler/tf2xla/xla_op_registry.h index 74a4885f1f..5d53169f68 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.h +++ b/tensorflow/compiler/tf2xla/xla_op_registry.h @@ -22,6 +22,7 @@ limitations under the License. #include <unordered_map> #include <vector> +#include "absl/strings/string_view.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/local_device.h" #include "tensorflow/core/framework/device_base.h" diff --git a/tensorflow/compiler/xla/packed_literal_reader.cc b/tensorflow/compiler/xla/packed_literal_reader.cc index f9473d372b..bddb664149 100644 --- a/tensorflow/compiler/xla/packed_literal_reader.cc +++ b/tensorflow/compiler/xla/packed_literal_reader.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/types.h" @@ -64,7 +65,7 @@ StatusOr<std::unique_ptr<Literal>> PackedLiteralReader::Read( absl::Span<const float> field = result->data<float>(); char* data = absl::bit_cast<char*>(field.data()); uint64 bytes = elements * sizeof(float); - absl::string_view sp; + tensorflow::StringPiece sp; auto s = file_->Read(offset_, bytes, &sp, data); offset_ += sp.size(); if (!s.ok()) { @@ -85,7 +86,7 @@ bool PackedLiteralReader::IsExhausted() const { // Try to read a single byte from offset_. If we can't, we've // exhausted the data. 
char single_byte[1]; - absl::string_view sp; + tensorflow::StringPiece sp; auto s = file_->Read(offset_, sizeof(single_byte), &sp, single_byte); return !s.ok(); } diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i index 76c09512d8..450d3fe5af 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.i +++ b/tensorflow/compiler/xla/python/local_computation_builder.i @@ -109,12 +109,12 @@ limitations under the License. // Must be included first #include "tensorflow/python/lib/core/numpy.h" -#include "third_party/absl/strings/str_cat.h" -#include "third_party/absl/strings/str_format.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "third_party/absl/types/span.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/python/numpy_bridge.h" #include "tensorflow/compiler/xla/python/local_computation_builder.h" diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index ab86dce510..e784663ff6 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -159,6 +159,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep ], @@ -291,6 +292,7 @@ cc_library( "hlo_instructions.cc", "hlo_module.cc", "hlo_opcode.cc", + "hlo_schedule.cc", "hlo_sharding.cc", ], hdrs = [ @@ -303,6 +305,7 @@ cc_library( "hlo_instructions.h", "hlo_module.h", "hlo_opcode.h", + "hlo_schedule.h", "hlo_sharding.h", ], deps = [ @@ -331,6 +334,8 @@ cc_library( 
"@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", ], ) @@ -1037,7 +1042,6 @@ tf_cc_test( ":flatten_call_graph", ":hlo", ":hlo_ordering", - ":hlo_schedule", ":hlo_scheduling", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", @@ -1065,7 +1069,6 @@ cc_library( ":hlo", ":hlo_dataflow_analysis", ":hlo_proto", - ":hlo_schedule", ":hlo_value", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -1086,7 +1089,6 @@ tf_cc_test( ":hlo", ":hlo_dataflow_analysis", ":hlo_ordering", - ":hlo_schedule", ":hlo_scheduling", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", @@ -1108,7 +1110,6 @@ cc_library( ":hlo", ":hlo_ordering", ":hlo_proto", - ":hlo_schedule", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", @@ -1177,22 +1178,6 @@ cc_library( ], ) -cc_library( - name = "hlo_schedule", - srcs = ["hlo_schedule.cc"], - hdrs = ["hlo_schedule.h"], - deps = [ - ":hlo", - "//tensorflow/compiler/xla:status", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:util", - "//tensorflow/core:lib_internal", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/types:span", - ], -) - tf_cc_test( name = "hlo_schedule_test", srcs = ["hlo_schedule_test.cc"], @@ -1202,7 +1187,6 @@ tf_cc_test( ":hlo_dce", ":hlo_ordering", ":hlo_parser", - ":hlo_schedule", ":hlo_scheduling", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", @@ -1222,7 +1206,6 @@ cc_library( ":heap_simulator", ":hlo", ":hlo_ordering", - ":hlo_schedule", ":logical_buffer", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:shape_util", @@ -1969,6 +1952,8 @@ tf_cc_test( srcs = ["hlo_module_test.cc"], deps 
= [ ":hlo", + ":hlo_matchers", + ":hlo_parser", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", @@ -1977,6 +1962,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", + "//tensorflow/core:test", "@com_google_absl//absl/memory", "@com_google_absl//absl/types:span", ], @@ -2413,7 +2399,6 @@ cc_library( ":hlo", ":hlo_dce", ":hlo_ordering", - ":hlo_schedule", ":hlo_scheduling", ":logical_buffer", ":tuple_points_to_analysis", @@ -2587,6 +2572,7 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index aa40fba9bb..a0db4563fb 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -2369,20 +2369,20 @@ TEST_P(ConvFilterPaddingTest, DoIt) { rhs_pad->shape().dimensions(3), testcase.orig_conv_window)) .ValueOrDie(); - auto* orig_conv = builder.AddInstruction(HloInstruction::CreateConvolve( - ShapeInference::InferConvolveShape(input->shape(), rhs_pad->shape(), - /*feature_group_count=*/1, window, - dnums) - .ValueOrDie(), - input, rhs_pad, /*feature_group_count=*/1, window, dnums, - DefaultPrecisionConfig(2))); // Add a PrecisionConfig and check that AlgebraicSimplifier keeps it in place // after the transformation. 
PrecisionConfig precision_config; precision_config.add_operand_precision(PrecisionConfig::HIGH); precision_config.add_operand_precision(PrecisionConfig::HIGHEST); - orig_conv->set_precision_config(precision_config); + + builder.AddInstruction(HloInstruction::CreateConvolve( + ShapeInference::InferConvolveShape(input->shape(), rhs_pad->shape(), + /*feature_group_count=*/1, window, + dnums) + .ValueOrDie(), + input, rhs_pad, /*feature_group_count=*/1, window, dnums, + precision_config)); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); @@ -2401,7 +2401,9 @@ TEST_P(ConvFilterPaddingTest, DoIt) { conv->operand(1)->shape().dimensions(2), conv->operand(1)->shape().dimensions(3), testcase.expected_conv_window)); - EXPECT_THAT(conv->precision_config().operand_precision(), + EXPECT_THAT(Cast<HloConvolutionInstruction>(conv) + ->precision_config() + .operand_precision(), ElementsAre(PrecisionConfig::HIGH, PrecisionConfig::HIGHEST)); } } diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc index 69b654d30e..388fd5df99 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc @@ -22,7 +22,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -55,8 +55,12 @@ class TestBFloat16Support : public BFloat16Support { } }; -class BFloat16PropagationTest : public HloTestBase { +class BFloat16PropagationTest : public HloVerifiedTestBase { protected: + BFloat16PropagationTest() + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/true) {} + // Runs the propagation pass on the given module, and returns whether the // module is changed after this pass. bool PropagatePrecision(HloModule* module) { @@ -77,6 +81,16 @@ class BFloat16PropagationTest : public HloTestBase { inst->users()[0]->opcode() == HloOpcode::kConvert && inst->users()[0]->shape().element_type() == BF16; } + + std::unique_ptr<HloInstruction> CreateDot(const Shape& shape, + HloInstruction* lhs, + HloInstruction* rhs) { + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + return HloInstruction::CreateDot(shape, lhs, rhs, dot_dnums, + DefaultPrecisionConfig(2)); + } }; // Tests that BF16 can propagate through select over non-tuple buffers, but not @@ -95,22 +109,22 @@ TEST_F(BFloat16PropagationTest, PropagateThroughSelectButNotAdd) { HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, b)); HloInstruction* add1 = builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, add0, b)); - HloInstruction* pred = builder.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kEq, a, b)); + HloInstruction* pred = builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(PRED, {2, 4}), HloOpcode::kEq, a, b)); HloInstruction* sel = 
builder.AddInstruction( HloInstruction::CreateTernary(shape, HloOpcode::kSelect, pred, c, add1)); HloInstruction* xpose = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(F32, {4, 2}), sel, {1, 0})); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, xpose, a)); - HloInstruction* root = builder.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kAdd, dot, dot)); + HloInstruction* dot = builder.AddInstruction( + CreateDot(ShapeUtil::MakeShape(F32, {4, 4}), xpose, a)); + HloInstruction* root = builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kAdd, dot, dot)); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module.get())); + EXPECT_TRUE(PropagatePrecision(module)); EXPECT_EQ(computation->root_instruction(), root); EXPECT_TRUE(OutputsBF16(xpose)); @@ -136,13 +150,12 @@ TEST_F(BFloat16PropagationTest, ConvertConstantLiteral) { HloInstruction::CreateConstant(LiteralUtil::CreateFromArray(array_a))); HloInstruction* b = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateFromArray(array_b))); - HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kDot, a, b)); + HloInstruction* dot = builder.AddInstruction(CreateDot(shape, a, b)); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module.get())); + EXPECT_TRUE(PropagatePrecision(module)); EXPECT_EQ(computation->root_instruction(), dot); EXPECT_TRUE(OutputsBF16(dot->operand(0))); @@ -189,8 +202,8 @@ TEST_F(BFloat16PropagationTest, PropagateThroughTuples) { builder.AddInstruction(HloInstruction::CreateGetTupleElement( tuple0->shape(), tuple1, 0)), 0)); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - 
ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, lhs, rhs)); + HloInstruction* dot = builder.AddInstruction( + CreateDot(ShapeUtil::MakeShape(F32, {4, 4}), lhs, rhs)); HloInstruction* output_tuple = builder.AddInstruction(HloInstruction::CreateTuple({dot, add2})); @@ -198,7 +211,7 @@ TEST_F(BFloat16PropagationTest, PropagateThroughTuples) { auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module.get())); + EXPECT_TRUE(PropagatePrecision(module)); EXPECT_EQ(computation->root_instruction(), output_tuple); EXPECT_TRUE(OutputsBF16(xpose)); @@ -231,13 +244,13 @@ TEST_F(BFloat16PropagationTest, SameValueReferencedTwice) { HloInstruction::CreateGetTupleElement(add1->shape(), tuple, 1)); // lhs is the transpose of add1, and rhs is a get-tuple-element aliasing add1. - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, lhs, rhs)); + HloInstruction* dot = builder.AddInstruction( + CreateDot(ShapeUtil::MakeShape(F32, {4, 4}), lhs, rhs)); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module.get())); + EXPECT_TRUE(PropagatePrecision(module)); EXPECT_EQ(computation->root_instruction(), dot); EXPECT_TRUE(OutputsBF16(add1)); @@ -249,7 +262,7 @@ TEST_F(BFloat16PropagationTest, SameValueReferencedTwice) { // Tests that a non-fusion computation's root should not be changed. 
TEST_F(BFloat16PropagationTest, DoNotChangeComputationRoot) { auto builder = HloComputation::Builder(TestName()); - Shape shape = ShapeUtil::MakeShape(F32, {2, 4}); + Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); HloInstruction* a = builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a")); @@ -258,8 +271,7 @@ TEST_F(BFloat16PropagationTest, DoNotChangeComputationRoot) { HloInstruction* add = builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, b)); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, add, add)); + HloInstruction* dot = builder.AddInstruction(CreateDot(shape, add, add)); HloInstruction* tuple = builder.AddInstruction(HloInstruction::CreateTuple({add, dot})); @@ -267,7 +279,7 @@ TEST_F(BFloat16PropagationTest, DoNotChangeComputationRoot) { auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(PropagatePrecision(module.get())); + EXPECT_FALSE(PropagatePrecision(module)); EXPECT_EQ(computation->root_instruction(), tuple); EXPECT_FALSE(OutputsBF16(add)); @@ -277,7 +289,7 @@ TEST_F(BFloat16PropagationTest, DoNotChangeComputationRoot) { TEST_F(BFloat16PropagationTest, PropagateThroughFusion) { auto module = CreateNewModule(); auto builder = HloComputation::Builder(TestName()); - Shape shape = ShapeUtil::MakeShape(F32, {2, 4}); + Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); HloInstruction* param = builder.AddInstruction( HloInstruction::CreateParameter(0, shape, "param")); @@ -303,15 +315,14 @@ TEST_F(BFloat16PropagationTest, PropagateThroughFusion) { HloInstruction::CreateGetTupleElement(shape, p_f1, 0)); HloInstruction* b_f1 = builder_f1.AddInstruction( HloInstruction::CreateGetTupleElement(shape, p_f1, 1)); - HloInstruction* dot = builder_f1.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, a_f1, b_f1)); + HloInstruction* 
dot = builder_f1.AddInstruction(CreateDot(shape, a_f1, b_f1)); auto comp_f1 = module->AddEmbeddedComputation(builder_f1.Build()); auto fusion1 = builder.AddInstruction(HloInstruction::CreateFusion( dot->shape(), HloInstruction::FusionKind::kCustom, {fusion0}, comp_f1)); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module.get())); + EXPECT_TRUE(PropagatePrecision(module)); EXPECT_EQ(computation->root_instruction(), fusion1); EXPECT_TRUE(OutputsBF16(add)); @@ -326,7 +337,7 @@ TEST_F(BFloat16PropagationTest, PropagateThroughFusion) { TEST_F(BFloat16PropagationTest, DiscardFusionInternalBF16Changes) { auto module = CreateNewModule(); auto builder = HloComputation::Builder(TestName()); - Shape shape = ShapeUtil::MakeShape(F32, {2, 4}); + Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); HloInstruction* param = builder.AddInstruction( HloInstruction::CreateParameter(0, shape, "param")); @@ -340,15 +351,15 @@ TEST_F(BFloat16PropagationTest, DiscardFusionInternalBF16Changes) { builder_f.AddInstruction(HloInstruction::CreateParameter(1, shape, "b")); HloInstruction* add_f = builder_f.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a_f, b_f)); - HloInstruction* dot_f = builder_f.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, add_f, add_f)); + HloInstruction* dot_f = + builder_f.AddInstruction(CreateDot(shape, add_f, add_f)); auto comp_f = module->AddEmbeddedComputation(builder_f.Build()); auto fusion = builder.AddInstruction(HloInstruction::CreateFusion( dot_f->shape(), HloInstruction::FusionKind::kCustom, {add, add}, comp_f)); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(PropagatePrecision(module.get())); + EXPECT_FALSE(PropagatePrecision(module)); EXPECT_EQ(computation->root_instruction(), fusion); } @@ -390,12 +401,11 @@ TEST_F(BFloat16PropagationTest, ConvertTupleFusionElementIfUsedByAdd) { 
HloInstruction::CreateGetTupleElement(shape, fusion, 0)); HloInstruction* gte1 = builder.AddInstruction( HloInstruction::CreateGetTupleElement(shape, fusion, 1)); - HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kDot, gte0, gte1)); + HloInstruction* dot = builder.AddInstruction(CreateDot(shape, gte0, gte1)); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module.get())); + EXPECT_TRUE(PropagatePrecision(module)); EXPECT_EQ(computation->root_instruction(), dot); EXPECT_TRUE(OutputsBF16(gte0)); @@ -440,12 +450,12 @@ TEST_F(BFloat16PropagationTest, SelectOverTuples) { HloInstruction* xpose = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(F32, {4, 2}), gte0, {1, 0})); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, xpose, gte1)); + HloInstruction* dot = builder.AddInstruction( + CreateDot(ShapeUtil::MakeShape(F32, {4, 4}), xpose, gte1)); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module.get())); + EXPECT_TRUE(PropagatePrecision(module)); EXPECT_EQ(computation->root_instruction(), dot); EXPECT_FALSE(OutputsBF16(add0)); @@ -472,31 +482,36 @@ TEST_F(BFloat16PropagationTest, PropagateThroughSimpleWhile) { auto builder_cond = HloComputation::Builder("cond"); auto cond_param = builder_cond.AddInstruction( HloInstruction::CreateParameter(0, shape, "cond_param")); - auto cond_dot = builder_cond.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kDot, cond_param, cond_param)); + auto cond_dot = + builder_cond.AddInstruction(CreateDot(shape, cond_param, cond_param)); auto cond_root = builder_cond.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, - builder_cond.AddInstruction(HloInstruction::CreateSlice( - ShapeUtil::MakeShape(F32, {}), cond_dot, {0, 0}, {1, 1}, 
{1, 1})), - builder_cond.AddInstruction(HloInstruction::CreateSlice( - ShapeUtil::MakeShape(F32, {}), cond_dot, {1, 1}, {2, 2}, {1, 1})))); + builder_cond.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {}), + builder_cond.AddInstruction( + HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}), + cond_dot, {0, 0}, {1, 1}, {1, 1})))), + builder_cond.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {}), + builder_cond.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {1, 1}), cond_dot, {1, 1}, {2, 2}, + {1, 1})))))); auto cond = module->AddEmbeddedComputation(builder_cond.Build()); auto builder_body = HloComputation::Builder("body"); auto body_param = builder_body.AddInstruction( HloInstruction::CreateParameter(0, shape, "body_param")); - auto body_dot = builder_body.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kDot, body_param, body_param)); + auto body_dot = + builder_body.AddInstruction(CreateDot(shape, body_param, body_param)); auto body = module->AddEmbeddedComputation(builder_body.Build()); auto while_hlo = builder.AddInstruction( HloInstruction::CreateWhile(shape, cond, body, add)); - auto dot = builder.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kDot, while_hlo, while_hlo)); + auto dot = builder.AddInstruction(CreateDot(shape, while_hlo, while_hlo)); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module.get())); + EXPECT_TRUE(PropagatePrecision(module)); EXPECT_EQ(computation->root_instruction(), dot); EXPECT_TRUE( @@ -528,10 +543,16 @@ TEST_F(BFloat16PropagationTest, HloInstruction::CreateParameter(0, shape, "cond_param")); builder_cond.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, - builder_cond.AddInstruction(HloInstruction::CreateSlice( - ShapeUtil::MakeShape(F32, {}), cond_param, {0, 0}, {1, 1}, {1, 1})), - 
builder_cond.AddInstruction(HloInstruction::CreateSlice( - ShapeUtil::MakeShape(F32, {}), cond_param, {1, 1}, {2, 2}, {1, 1})))); + builder_cond.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {}), + builder_cond.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {1, 1}), cond_param, {0, 0}, {1, 1}, + {1, 1})))), + builder_cond.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {}), + builder_cond.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {1, 1}), cond_param, {1, 1}, {2, 2}, + {1, 1})))))); auto cond = module->AddEmbeddedComputation(builder_cond.Build()); auto builder_body = HloComputation::Builder("body"); @@ -552,11 +573,10 @@ TEST_F(BFloat16PropagationTest, auto while_hlo = builder.AddInstruction( HloInstruction::CreateWhile(shape, cond, body, add)); - auto dot = builder.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kDot, while_hlo, while_hlo)); + auto dot = builder.AddInstruction(CreateDot(shape, while_hlo, while_hlo)); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(PropagatePrecision(module.get())); + EXPECT_FALSE(PropagatePrecision(module)); EXPECT_EQ(computation->root_instruction(), dot); EXPECT_FALSE(OutputsBF16(add)); EXPECT_FALSE(OutputsBF16(body_fusion)); @@ -593,14 +613,20 @@ TEST_F(BFloat16PropagationTest, PropagateThroughTupleWhile) { // This add should prevent RHS from using BF16 auto cond_add_rhs = builder_cond.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, cond_rhs, cond_rhs)); - auto cond_dot = builder_cond.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kDot, cond_lhs, cond_add_rhs)); + auto cond_dot = + builder_cond.AddInstruction(CreateDot(shape, cond_lhs, cond_add_rhs)); builder_cond.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, - builder_cond.AddInstruction(HloInstruction::CreateSlice( - 
ShapeUtil::MakeShape(F32, {}), cond_dot, {0, 0}, {1, 1}, {1, 1})), - builder_cond.AddInstruction(HloInstruction::CreateSlice( - ShapeUtil::MakeShape(F32, {}), cond_dot, {1, 1}, {2, 2}, {1, 1})))); + builder_cond.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {}), + builder_cond.AddInstruction( + HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}), + cond_dot, {0, 0}, {1, 1}, {1, 1})))), + builder_cond.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {}), + builder_cond.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {1, 1}), cond_dot, {1, 1}, {2, 2}, + {1, 1})))))); auto cond = module->AddEmbeddedComputation(builder_cond.Build()); auto builder_body = HloComputation::Builder("body"); @@ -610,10 +636,10 @@ TEST_F(BFloat16PropagationTest, PropagateThroughTupleWhile) { HloInstruction::CreateGetTupleElement(shape, body_param, 0)); auto body_rhs = builder_body.AddInstruction( HloInstruction::CreateGetTupleElement(shape, body_param, 1)); - auto body_dot1 = builder_body.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kDot, body_lhs, body_rhs)); - auto body_dot2 = builder_body.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kDot, body_rhs, body_lhs)); + auto body_dot1 = + builder_body.AddInstruction(CreateDot(shape, body_lhs, body_rhs)); + auto body_dot2 = + builder_body.AddInstruction(CreateDot(shape, body_rhs, body_lhs)); auto body_transpose = builder_body.AddInstruction( HloInstruction::CreateTranspose(shape, body_dot2, {0, 1})); builder_body.AddInstruction( @@ -627,11 +653,10 @@ TEST_F(BFloat16PropagationTest, PropagateThroughTupleWhile) { HloInstruction::CreateGetTupleElement(shape, while_hlo, 0)); auto rhs = builder.AddInstruction( HloInstruction::CreateGetTupleElement(shape, while_hlo, 1)); - auto dot = builder.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kDot, lhs, rhs)); + auto dot = builder.AddInstruction(CreateDot(shape, lhs, 
rhs)); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module.get())); + EXPECT_TRUE(PropagatePrecision(module)); EXPECT_EQ(computation->root_instruction(), dot); EXPECT_TRUE(OutputsBF16(lhs)); @@ -683,14 +708,20 @@ TEST_F(BFloat16PropagationTest, DoNotPropagateWhilesCallingSameComputation) { auto cond0_add_rhs = builder_cond0.AddInstruction(HloInstruction::CreateBinary( shape, HloOpcode::kAdd, cond0_rhs, cond0_rhs)); - auto cond0_dot = builder_cond0.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kDot, cond0_lhs, cond0_add_rhs)); + auto cond0_dot = + builder_cond0.AddInstruction(CreateDot(shape, cond0_lhs, cond0_add_rhs)); builder_cond0.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, - builder_cond0.AddInstruction(HloInstruction::CreateSlice( - ShapeUtil::MakeShape(F32, {}), cond0_dot, {0, 0}, {1, 1}, {1, 1})), - builder_cond0.AddInstruction(HloInstruction::CreateSlice( - ShapeUtil::MakeShape(F32, {}), cond0_dot, {1, 1}, {2, 2}, {1, 1})))); + builder_cond0.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {}), + builder_cond0.AddInstruction( + HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}), + cond0_dot, {0, 0}, {1, 1}, {1, 1})))), + builder_cond0.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {}), + builder_cond0.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {1, 1}), cond0_dot, {1, 1}, {2, 2}, + {1, 1})))))); auto cond0 = module->AddEmbeddedComputation(builder_cond0.Build()); // Condition computation for the second while. 
@@ -705,14 +736,20 @@ TEST_F(BFloat16PropagationTest, DoNotPropagateWhilesCallingSameComputation) { auto cond1_add_lhs = builder_cond1.AddInstruction(HloInstruction::CreateBinary( shape, HloOpcode::kAdd, cond1_lhs, cond1_lhs)); - auto cond1_dot = builder_cond1.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kDot, cond1_add_lhs, cond1_rhs)); + auto cond1_dot = + builder_cond1.AddInstruction(CreateDot(shape, cond1_add_lhs, cond1_rhs)); builder_cond1.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, - builder_cond1.AddInstruction(HloInstruction::CreateSlice( - ShapeUtil::MakeShape(F32, {}), cond1_dot, {0, 0}, {1, 1}, {1, 1})), - builder_cond1.AddInstruction(HloInstruction::CreateSlice( - ShapeUtil::MakeShape(F32, {}), cond1_dot, {1, 1}, {2, 2}, {1, 1})))); + builder_cond1.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {}), + builder_cond1.AddInstruction( + HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}), + cond1_dot, {0, 0}, {1, 1}, {1, 1})))), + builder_cond1.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {}), + builder_cond1.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {1, 1}), cond1_dot, {1, 1}, {2, 2}, + {1, 1})))))); auto cond1 = module->AddEmbeddedComputation(builder_cond1.Build()); // Body computation shared by both whiles. 
@@ -723,8 +760,8 @@ TEST_F(BFloat16PropagationTest, DoNotPropagateWhilesCallingSameComputation) { HloInstruction::CreateGetTupleElement(shape, body_param, 0)); auto body_rhs = builder_body.AddInstruction( HloInstruction::CreateGetTupleElement(shape, body_param, 1)); - auto body_dot = builder_body.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kDot, body_lhs, body_rhs)); + auto body_dot = + builder_body.AddInstruction(CreateDot(shape, body_lhs, body_rhs)); builder_body.AddInstruction( HloInstruction::CreateTuple({body_dot, body_rhs})); auto body = module->AddEmbeddedComputation(builder_body.Build()); @@ -734,23 +771,22 @@ TEST_F(BFloat16PropagationTest, DoNotPropagateWhilesCallingSameComputation) { auto while1 = builder.AddInstruction( HloInstruction::CreateWhile(tuple1->shape(), cond1, body, tuple1)); - auto lhs = builder.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kDot, - builder.AddInstruction( - HloInstruction::CreateGetTupleElement(shape, while0, 0)), - builder.AddInstruction( - HloInstruction::CreateGetTupleElement(shape, while0, 1)))); - auto rhs = builder.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kDot, - builder.AddInstruction( - HloInstruction::CreateGetTupleElement(shape, while1, 0)), - builder.AddInstruction( - HloInstruction::CreateGetTupleElement(shape, while1, 1)))); - auto dot = builder.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kDot, lhs, rhs)); + auto lhs = builder.AddInstruction( + CreateDot(shape, + builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, while0, 0)), + builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, while0, 1)))); + auto rhs = builder.AddInstruction( + CreateDot(shape, + builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, while1, 0)), + builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, while1, 1)))); + auto dot = builder.AddInstruction(CreateDot(shape, lhs, rhs)); auto 
computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module.get())); + EXPECT_TRUE(PropagatePrecision(module)); EXPECT_FALSE(OutputsBF16(body_dot)); EXPECT_FALSE(OutputsBF16(body_rhs)); EXPECT_FALSE(OutputsBF16(body_lhs)); @@ -792,7 +828,7 @@ TEST_F(BFloat16PropagationTest, NoopConversionRemoved) { auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module.get())); + EXPECT_TRUE(PropagatePrecision(module)); EXPECT_EQ(computation->root_instruction(), add2); EXPECT_EQ(add2->operand(0), add0); @@ -821,15 +857,14 @@ TEST_F(BFloat16PropagationTest, TupleDomain) { HloInstruction::CreateGetTupleElement(shape, domain, 0)); HloInstruction* b_gte = builder.AddInstruction( HloInstruction::CreateGetTupleElement(shape, domain, 1)); - HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kDot, a_gte, b_gte)); + HloInstruction* dot = builder.AddInstruction(CreateDot(shape, a_gte, b_gte)); HloInstruction* root = builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, dot, dot)); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module.get())); + EXPECT_TRUE(PropagatePrecision(module)); EXPECT_EQ(computation->root_instruction(), root); // test BF16 propagated through domain @@ -867,15 +902,15 @@ TEST_F(BFloat16PropagationTest, TupleDomainNoPropagation) { HloInstruction::CreateTranspose(shape, a_gte, {0, 1})); HloInstruction* b_trans = builder.AddInstruction( HloInstruction::CreateTranspose(shape, b_gte, {0, 1})); - HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kDot, a_trans, b_trans)); + HloInstruction* dot = + builder.AddInstruction(CreateDot(shape, a_trans, b_trans)); HloInstruction* root = builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, dot, 
dot)); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module.get())); + EXPECT_TRUE(PropagatePrecision(module)); EXPECT_EQ(computation->root_instruction(), root); EXPECT_TRUE(OutputsBF16(a_trans)); diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index d412578619..2368ac8c6a 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -670,6 +670,7 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:transpose_folding", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "@com_google_absl//absl/strings", diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc index 0fea462c85..7d99b914d4 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc @@ -24,6 +24,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/transpose_folding.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" namespace op = xla::testing::opcode_matchers; @@ -696,8 +697,8 @@ void CreateComputationForDotAddOutputFusionTest(const string& test_name, auto* addend = builder.AddInstruction( HloInstruction::CreateParameter(2, dot_shape, "param2")); - auto* dot = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(dot_shape, dot_lhs, dot_rhs)); + auto* dot = + builder.AddInstruction(CreateCanonicalDot(dot_shape, dot_lhs, dot_rhs)); builder.AddInstruction( HloInstruction::CreateBinary(dot_shape, HloOpcode::kAdd, dot, addend)); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc index 9363af3b89..4668f3872d 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc @@ -70,7 +70,7 @@ TEST_F(CpuLayoutAssignmentTest, DotWithConstantRhsTensor) { auto dot_rhs = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateFromShape(rhs_shape))); auto result = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(result_shape, dot_lhs, dot_rhs)); + CreateCanonicalDot(result_shape, dot_lhs, dot_rhs)); auto module = CreateNewModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); @@ -107,9 +107,9 @@ TEST_F(CpuLayoutAssignmentTest, MultipleDotsWithSameConstantRhsTensor0) { auto dot_rhs = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateFromShape(rhs_shape))); auto dot_a_result = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(result_shape, dot_a_lhs, dot_rhs)); + CreateCanonicalDot(result_shape, dot_a_lhs, dot_rhs)); auto dot_b_result = builder.AddInstruction( - 
HloInstruction::CreateCanonicalDot(result_shape, dot_b_lhs, dot_rhs)); + CreateCanonicalDot(result_shape, dot_b_lhs, dot_rhs)); builder.AddInstruction(HloInstruction::CreateBinary( result_shape, HloOpcode::kAdd, dot_a_result, dot_b_result)); @@ -151,9 +151,9 @@ TEST_F(CpuLayoutAssignmentTest, MultipleDotsWithSameConstantRhsTensor1) { auto dot_rhs = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateFromShape(rhs_shape))); auto dot_a_result = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(result_a_shape, dot_a_lhs, dot_rhs)); + CreateCanonicalDot(result_a_shape, dot_a_lhs, dot_rhs)); auto dot_b_result = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(result_b_shape, dot_b_lhs, dot_rhs)); + CreateCanonicalDot(result_b_shape, dot_b_lhs, dot_rhs)); auto tuple_result = builder.AddInstruction( HloInstruction::CreateTuple({dot_a_result, dot_b_result})); @@ -189,7 +189,7 @@ TEST_F(CpuLayoutAssignmentTest, DotWithConstantLhsTensor) { auto dot_rhs = builder.AddInstruction( HloInstruction::CreateParameter(0, rhs_shape, "param0")); auto dot_result = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(result_shape, dot_lhs, dot_rhs)); + CreateCanonicalDot(result_shape, dot_lhs, dot_rhs)); auto module = CreateNewModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); @@ -229,7 +229,7 @@ TEST_F(CpuLayoutAssignmentTest, DotWithConstantRhsTensorThroughGTE) { auto dot_rhs = builder.AddInstruction( HloInstruction::CreateGetTupleElement(rhs_shape, constant, 1)); auto dot_result = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(result_shape, dot_lhs, dot_rhs)); + CreateCanonicalDot(result_shape, dot_lhs, dot_rhs)); auto module = CreateNewModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); @@ -276,8 +276,8 @@ static StatusOr<DotOutputFusionLayoutAssignmentResult> RunDotOutputFusion( HloInstruction::CreateParameter(1, dot_shape, "param1")); HloInstruction* 
dot_rhs = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateFromShape(dot_rhs_shape))); - HloInstruction* dot_result = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(dot_shape, dot_lhs, dot_rhs)); + HloInstruction* dot_result = + builder.AddInstruction(CreateCanonicalDot(dot_shape, dot_lhs, dot_rhs)); HloInstruction* add_result; if (dot_operand_idx_in_add == 0) { add_result = builder.AddInstruction(HloInstruction::CreateBinary( diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc index a84ee78b19..fad76338a5 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc @@ -35,9 +35,7 @@ class ParallelTaskAssignmentTest : public HloVerifiedTestBase { cpu::TargetMachineFeaturesWithFakeAlignmentLogic target_machine_features_; ParallelTaskAssignmentTest() - : HloVerifiedTestBase(/*layout_sensitive=*/false, - /*allow_mixed_precision=*/false), - target_machine_features_([](int64 shape_size) { + : HloVerifiedTestBase(), target_machine_features_([](int64 shape_size) { return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment; }) {} diff --git a/tensorflow/compiler/xla/service/cpu/tests/BUILD b/tensorflow/compiler/xla/service/cpu/tests/BUILD index 2384166fd2..f11aff0573 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/cpu/tests/BUILD @@ -121,6 +121,7 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service/cpu:cpu_compiler", "//tensorflow/compiler/xla/service/cpu/tests:cpu_codegen_test", + "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc 
b/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc index fcd87b36b3..18ee25ba91 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" #include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/core/platform/test.h" namespace xla { @@ -69,8 +70,7 @@ TEST_P(CpuEigenDotOperationTest, SimpleDotOp) { HloInstruction* rhs = builder.AddInstruction( HloInstruction::CreateParameter(1, param_shape, "input")); - builder.AddInstruction( - HloInstruction::CreateCanonicalDot(param_shape, lhs, rhs)); + builder.AddInstruction(CreateCanonicalDot(param_shape, lhs, rhs)); CompileAndCheck(builder.Build(), spec.filecheck_lines); } @@ -87,8 +87,7 @@ TEST_P(CpuEigenDotOperationTest, DotTransposeOp) { HloInstruction* lhs_transposed = builder.AddInstruction( HloInstruction::CreateTranspose(param_shape, lhs, {1, 0})); - builder.AddInstruction( - HloInstruction::CreateCanonicalDot(param_shape, lhs_transposed, rhs)); + builder.AddInstruction(CreateCanonicalDot(param_shape, lhs_transposed, rhs)); CompileAndCheck(builder.Build(), spec.filecheck_lines); } diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 13ccff35f8..6791e15ee0 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -108,6 +108,7 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "@com_google_absl//absl/memory", @@ 
-480,6 +481,7 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo_matchers", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], ) @@ -813,7 +815,6 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_ordering", "//tensorflow/compiler/xla/service:hlo_reachability", - "//tensorflow/compiler/xla/service:hlo_schedule", "//tensorflow/compiler/xla/service:hlo_scheduling", "@com_google_absl//absl/memory", ], @@ -831,6 +832,7 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings:str_format", diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc index 0922e44a12..59ade96f7d 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc @@ -25,6 +25,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/types.h" namespace xla { @@ -73,10 +74,10 @@ TEST_F(GpuHloScheduleTest, SequentialMatMul) { /*parameter_number=*/1, f32_2x2_, /*name=*/"y")); HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/2, f32_2x2_, /*name=*/"z")); - HloInstruction* dot1 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, x, y)); - HloInstruction* dot2 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, dot1, z)); + HloInstruction* dot1 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, x, y)); + HloInstruction* dot2 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, dot1, z)); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build(dot2)); @@ -201,12 +202,12 @@ TEST_F(GpuHloScheduleTest, ConcurrentMatMul) { /*parameter_number=*/0, f32_2x2_, /*name=*/"x")); HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/1, f32_2x2_, /*name=*/"y")); - HloInstruction* dot1 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, x, y)); - HloInstruction* dot2 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, y, x)); - HloInstruction* add = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, dot1, dot2)); + HloInstruction* dot1 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, x, y)); + HloInstruction* dot2 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, y, x)); + HloInstruction* add = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, dot1, dot2)); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build(add)); @@ -269,23 +270,23 @@ TEST_F(GpuHloScheduleTest, LatticeMatMul) { i, f32_2x2_, 
/*name=*/absl::StrFormat("param%d", i)))); } HloInstruction* d00 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, params[2], params[3])); - HloInstruction* d10 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, params[1], d00)); - HloInstruction* d11 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, d00, params[4])); - HloInstruction* d20 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, params[0], d10)); - HloInstruction* d21 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, d10, d11)); - HloInstruction* d22 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, d11, params[5])); - HloInstruction* d30 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, d20, d21)); - HloInstruction* d31 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, d21, d22)); - HloInstruction* d40 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, d30, d31)); + CreateCanonicalDot(f32_2x2_, params[2], params[3])); + HloInstruction* d10 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, params[1], d00)); + HloInstruction* d11 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d00, params[4])); + HloInstruction* d20 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, params[0], d10)); + HloInstruction* d21 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d10, d11)); + HloInstruction* d22 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d11, params[5])); + HloInstruction* d30 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d20, d21)); + HloInstruction* d31 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d21, d22)); + HloInstruction* d40 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d30, d31)); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build(d40)); diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc 
b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc index bca775c475..96bfe0c12e 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/util.h" namespace op = xla::testing::opcode_matchers; @@ -111,8 +112,8 @@ TEST_F(InstructionFusionTest, PotentialBitcastReshapeOfDotUnfused) { HloComputation::Builder builder(TestName()); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(S32, {1, 1}), "0")); - auto dot1 = builder.AddInstruction(HloInstruction::CreateCanonicalDot( - ShapeUtil::MakeShape(S32, {1, 1}), param0, param0)); + auto dot1 = builder.AddInstruction( + CreateCanonicalDot(ShapeUtil::MakeShape(S32, {1, 1}), param0, param0)); auto reshape2 = builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(S32, {1, 1, 1}), dot1)); @@ -128,8 +129,8 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfDotUnfused) { HloComputation::Builder builder(TestName()); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(S32, {1, 1}), "0")); - auto dot1 = builder.AddInstruction(HloInstruction::CreateCanonicalDot( - ShapeUtil::MakeShape(S32, {1, 1}), param0, param0)); + auto dot1 = builder.AddInstruction( + CreateCanonicalDot(ShapeUtil::MakeShape(S32, {1, 1}), param0, param0)); auto transpose2 = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(S32, {1, 1}), dot1, {0, 1})); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index ffca5d6549..b7c37bcf3c 100644 --- 
a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -764,5 +764,20 @@ StatusOr<llvm::Value*> IrEmitter::ComputeNestedElement( return Load(return_buffer); } +std::vector<llvm_ir::IrArray> IrEmitter::ConstructIrArrayForOutputs( + const HloInstruction& hlo) { + std::vector<llvm_ir::IrArray> output_arrays; + if (ShapeUtil::IsTuple(hlo.shape())) { + int64 num_outputs = ShapeUtil::TupleElementCount(hlo.shape()); + output_arrays.reserve(num_outputs); + for (int64 i = 0; i < num_outputs; ++i) { + output_arrays.push_back(GetIrArray(hlo, hlo, {i})); + } + } else { + output_arrays.push_back(GetIrArray(hlo, hlo)); + } + return output_arrays; +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h index 579268f071..8805201480 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h @@ -124,6 +124,12 @@ class IrEmitter : public DfsHloVisitorWithDefault, llvm::Value* GetBasePointer(const HloInstruction& inst) const { return bindings_.GetBasePointer(inst); } + + // Generates the IrArray for each output of an hlo instruction and returns + // a vector containing such IrArrays. + std::vector<llvm_ir::IrArray> ConstructIrArrayForOutputs( + const HloInstruction& hlo); + // A convenient helper for calling BufferAssignment::GetUniqueSlice. 
BufferAllocation::Slice GetAllocationSlice( const HloInstruction& hlo, const ShapeIndex& index = {}) const { diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc index 5c827e5f9c..66c65f6975 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc @@ -119,21 +119,11 @@ Status IrEmitterNested::EmitTargetElementLoop( // For MOF we give the loop emitter an array for every output it should // generate. if (hlo.IsMultiOutputFusion()) { - const int64 num_elems = ShapeUtil::TupleElementCount(hlo.shape()); - std::vector<llvm_ir::IrArray> target_arrays; - target_arrays.reserve(num_elems); - for (int64 i = 0; i != num_elems; ++i) { - target_arrays.push_back(GetIrArray(hlo, hlo, {i})); - } + std::vector<llvm_ir::IrArray> target_arrays = + ConstructIrArrayForOutputs(hlo); TF_RETURN_IF_ERROR( llvm_ir::LoopEmitter(element_generator, target_arrays, &b_).EmitLoop()); - - std::vector<llvm::Value*> tuple_operand_ptrs; - tuple_operand_ptrs.reserve(num_elems); - for (const llvm_ir::IrArray& array : target_arrays) { - tuple_operand_ptrs.push_back(array.GetBasePointer()); - } - llvm_ir::EmitTuple(GetIrArray(hlo, hlo), tuple_operand_ptrs, &b_, module_); + llvm_ir::EmitTuple(GetIrArray(hlo, hlo), target_arrays, &b_, module_); return Status::OK(); } return llvm_ir::LoopEmitter(element_generator, GetIrArray(hlo, hlo), &b_) diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 389a98facb..f91cc00d71 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -2521,15 +2521,15 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildFftThunk( } StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk( - const HloInstruction* hlo, const ShapeIndex& index) { + 
HloInstruction* hlo, const ShapeIndex& index) { bool fused = HloOpcode::kFusion == hlo->opcode(); - const HloInstruction* inst = fused ? hlo->fused_expression_root() : hlo; - const HloInstruction* init_value_operand = [&] { + HloInstruction* inst = fused ? hlo->fused_expression_root() : hlo; + HloInstruction* init_value_operand = [&] { switch (inst->opcode()) { case HloOpcode::kSelectAndScatter: - return inst->operand(2); + return inst->mutable_operand(2); case HloOpcode::kReduce: - return inst->operand(1); + return inst->mutable_operand(1); case HloOpcode::kTuple: CHECK(hlo->IsMultiOutputFusion()) << ": " << hlo->ToString() << " is not a multi-output fusion."; @@ -2537,7 +2537,7 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk( << ": Found '" << inst->operand(index.back())->opcode() << "' in " << inst->ToString() << " but expected 'reduce'."; // For multi-output fusion look through the tuple. - return inst->operand(index.back())->operand(1); + return inst->mutable_operand(index.back())->mutable_operand(1); default: LOG(FATAL) << "Opcode " << inst->opcode() << " should not need an initializer."; @@ -2609,28 +2609,35 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk( ir_emitter_context_->device_description()); UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(), ir_emitter_context_->llvm_module()); - // If the init_value was fused into this reduce we have to generate it first. - if (fused && init_value_operand->opcode() != HloOpcode::kParameter) { - CHECK_EQ(HloOpcode::kConstant, init_value_operand->opcode()); - const Literal& literal = init_value_operand->literal(); - llvm::Constant* initializer = - llvm_ir::ConvertLiteralToIrConstant(literal, module_); + if (fused) { + // If init_value was fused into this reduce we have to generate it first. 
+ std::vector<IrArray> parameter_arrays; + for (HloInstruction* operand : hlo->operands()) { + parameter_arrays.push_back(GetIrArray(*operand, *hlo)); + } + GpuElementalIrEmitter elemental_emitter(hlo_module_config_, + ir_emitter_context_->llvm_module(), + &b_, GetNestedComputer()); - llvm::GlobalVariable* global_for_const = new llvm::GlobalVariable( - *module_, initializer->getType(), - /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, initializer, - /*Name=*/""); - global_for_const->setAlignment(kConstantBufferAlignBytes); - bindings_.BindHloToIrValue(*init_value_operand, global_for_const); + FusedIrEmitter fused_emitter(parameter_arrays, &elemental_emitter); + TF_RETURN_IF_ERROR(init_value_operand->Accept(&fused_emitter)); + TF_RETURN_IF_ERROR( + ParallelLoopEmitter(fused_emitter.GetGenerator(init_value_operand), + GetIrArray(*hlo, *hlo, index), launch_dimensions, + &b_) + .EmitLoop(IrName(hlo))); + } else { + // In the unfused case the element is already there, just read from it. + TF_RETURN_IF_ERROR(ParallelLoopEmitter( + [=](const IrArray::Index& index) { + return GetIrArray(*init_value, *hlo) + .EmitReadArrayElement(index, &b_); + }, + GetIrArray(*hlo, *hlo, index), launch_dimensions, + &b_) + .EmitLoop(IrName(hlo))); } - TF_RETURN_IF_ERROR(ParallelLoopEmitter( - [=](const IrArray::Index& index) { - return GetIrArray(*init_value, *hlo) - .EmitReadArrayElement(index, &b_); - }, - GetIrArray(*hlo, *hlo, index), launch_dimensions, &b_) - .EmitLoop(IrName(hlo))); // Clean up state left behind by emitting the loop above. (This is normally // done in IrEmitterUnnested::Postprocess().) @@ -2819,10 +2826,7 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk( } // For multioutput fusion, we need to emit each operand and the root. 
- std::vector<IrArray> output_arrays; - for (int64 i = 0; i < ShapeUtil::TupleElementCount(hlo.shape()); ++i) { - output_arrays.push_back(GetIrArray(hlo, hlo, {i})); - } + std::vector<IrArray> output_arrays = ConstructIrArrayForOutputs(hlo); TF_RETURN_IF_ERROR( ParallelLoopEmitter(element_generator, output_arrays, launch_dimensions, &b_, unroll_factor) @@ -2830,12 +2834,9 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk( GetIndexTypeForKernel( &hlo, launch_dimensions.launch_bound(), &b_))); - std::vector<llvm::Value*> tuple_operand_ptrs; - for (int64 i = 0; i < output_arrays.size(); ++i) { - tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer()); - } b_.SetInsertPoint(b_.GetInsertBlock()->getTerminator()); - llvm_ir::EmitTuple(GetIrArray(hlo, hlo), tuple_operand_ptrs, &b_, module_); + llvm_ir::EmitTuple(GetIrArray(hlo, hlo), output_arrays, &b_, module_); + return Status::OK(); } @@ -2847,29 +2848,14 @@ Status IrEmitterUnnested::EmitTargetElementLoop( static_cast<KernelThunk*>(LastThunk())); } -int IrEmitterUnnested::ConstructIrArrayForOutputs( - const HloInstruction& hlo, std::vector<IrArray>* output_arrays) { - int64 num_outputs = 1; - if (hlo.IsMultiOutputFusion()) { - num_outputs = ShapeUtil::TupleElementCount(hlo.shape()); - output_arrays->reserve(num_outputs); - for (int64 i = 0; i < num_outputs; ++i) { - output_arrays->push_back(GetIrArray(hlo, hlo, {i})); - } - } else { - output_arrays->push_back(GetIrArray(hlo, hlo)); - } - return num_outputs; -} - -int IrEmitterUnnested::ConstructIrArrayForInputs( - const HloInstruction& hlo, std::vector<IrArray>* param_arrays) { - int64 num_params = hlo.operands().size(); - param_arrays->reserve(num_params); +std::vector<IrArray> IrEmitterUnnested::ConstructIrArrayForInputs( + const HloInstruction& hlo) { + std::vector<IrArray> param_arrays; + param_arrays.reserve(hlo.operands().size()); for (const HloInstruction* param : hlo.operands()) { - param_arrays->push_back(GetIrArray(*param, hlo)); + 
param_arrays.push_back(GetIrArray(*param, hlo)); } - return num_params; + return param_arrays; } int IrEmitterUnnested::ConstructOutputReducedShapeAndCastOutputIrArrayToShape( @@ -3050,10 +3036,10 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile( constexpr int64 kThreadsPerTile = kTileSize * kNumRows; // Construct IrArrays for the inputs and outputs. - std::vector<IrArray> output_arrays; - int64 num_outputs = ConstructIrArrayForOutputs(*hlo, &output_arrays); - std::vector<IrArray> param_arrays; - int64 num_params = ConstructIrArrayForInputs(*hlo, &param_arrays); + std::vector<IrArray> output_arrays = ConstructIrArrayForOutputs(*hlo); + int64 num_outputs = output_arrays.size(); + std::vector<IrArray> param_arrays = ConstructIrArrayForInputs(*hlo); + int64 num_params = param_arrays.size(); // Allocate shared memory buffers to store the tiled inputs. std::vector<llvm::Value*> param_shmem_buffers(num_params, nullptr); @@ -3251,12 +3237,7 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile( // For multioutput fusion, emit a tuple with all the individual outputs. 
if (hlo->IsMultiOutputFusion()) { - std::vector<llvm::Value*> tuple_operand_ptrs; - for (int64 i = 0; i < output_arrays.size(); ++i) { - tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer()); - } - llvm_ir::EmitTuple(GetIrArray(*hlo, *hlo), tuple_operand_ptrs, &b_, - module_); + llvm_ir::EmitTuple(GetIrArray(*hlo, *hlo), output_arrays, &b_, module_); } return launch_dimensions; diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index 084462330e..bd5db72051 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -193,14 +193,12 @@ class IrEmitterUnnested : public IrEmitter { LaunchDimensions EmitHlo021Tile(HloInstruction* hlo, absl::Span<const int64> reduced_output_dims, absl::Span<const int64> tiled_param_ids); - // Generates the IrArray for each output of hlo and returns the number of - // outputs. - int ConstructIrArrayForOutputs(const HloInstruction& hlo, - std::vector<llvm_ir::IrArray>* output_arrays); - // Generates the IrArray for each input of hlo and returns the number of - // inputs. - int ConstructIrArrayForInputs(const HloInstruction& hlo, - std::vector<llvm_ir::IrArray>* param_arrays); + + // Generates the IrArray for each input of an hlo and returns a vector that + // constains such IrArrays. + std::vector<llvm_ir::IrArray> ConstructIrArrayForInputs( + const HloInstruction& hlo); + // For each output of the `hlo` instruction, constructs the reduced shape for // the output with the given `reduced_output_dims` and cast the original // output IrArray element in `output_arrays` to the reduced shape. Returns @@ -244,7 +242,7 @@ class IrEmitterUnnested : public IrEmitter { // Returns a thunk that, given a reduce or select-and-scatter op, initializes // its memory to the appropriate initial value. 
StatusOr<std::unique_ptr<Thunk>> BuildInitializerThunk( - const HloInstruction* hlo, const ShapeIndex& index = {}); + HloInstruction* hlo, const ShapeIndex& index = {}); // Returns a thunk that calls host-to-device cuMemcpy to implement `inst`. std::unique_ptr<Thunk> BuildHostToDeviceCopyThunk(const HloInstruction* inst); diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc index 091aca23e5..8f0dedfa40 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc +++ b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/types.h" namespace xla { @@ -49,10 +50,10 @@ TEST_F(StreamAssignmentTest, SequentialMatMul) { /*parameter_number=*/1, f32_2x2_, /*name=*/"y")); HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/2, f32_2x2_, /*name=*/"z")); - HloInstruction* dot1 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, x, y)); - HloInstruction* dot2 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, dot1, z)); + HloInstruction* dot1 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, x, y)); + HloInstruction* dot2 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, dot1, z)); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build(dot2)); @@ -68,10 +69,10 @@ TEST_F(StreamAssignmentTest, ConcurrentMatMul) { /*parameter_number=*/0, f32_2x2_, /*name=*/"x")); HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/1, f32_2x2_, /*name=*/"y")); - HloInstruction* dot1 = builder.AddInstruction( - 
HloInstruction::CreateCanonicalDot(f32_2x2_, x, y)); - HloInstruction* dot2 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, y, x)); + HloInstruction* dot1 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, x, y)); + HloInstruction* dot2 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, y, x)); HloInstruction* add = builder.AddInstruction( HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, dot1, dot2)); @@ -101,23 +102,23 @@ TEST_F(StreamAssignmentTest, LatticeMatMul) { i, f32_2x2_, /*name=*/absl::StrFormat("param%d", i)))); } HloInstruction* d00 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, params[2], params[3])); - HloInstruction* d10 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, params[1], d00)); - HloInstruction* d11 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, d00, params[4])); - HloInstruction* d20 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, params[0], d10)); - HloInstruction* d21 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, d10, d11)); - HloInstruction* d22 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, d11, params[5])); - HloInstruction* d30 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, d20, d21)); - HloInstruction* d31 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, d21, d22)); - HloInstruction* d40 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, d30, d31)); + CreateCanonicalDot(f32_2x2_, params[2], params[3])); + HloInstruction* d10 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, params[1], d00)); + HloInstruction* d11 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d00, params[4])); + HloInstruction* d20 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, params[0], d10)); + HloInstruction* d21 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d10, 
d11)); + HloInstruction* d22 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d11, params[5])); + HloInstruction* d30 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d20, d21)); + HloInstruction* d31 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d21, d22)); + HloInstruction* d40 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d30, d31)); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build(d40)); diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index 99d0cf50ca..93ec2c9438 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -199,6 +199,17 @@ message HloComputationProto { int64 root_id = 6; } +// Serialization of an HLO schedule. An HLO schedule contains a total order of +// instructions for each non-fusion computation in the module. +message HloScheduleProto { + message InstructionSequence { + repeated int64 instruction_ids = 1; + } + + // Map from computation id to sequence. + map<int64, InstructionSequence> sequences = 1; +} + // Serialization of HloModule. message HloModuleProto { string name = 1; @@ -214,16 +225,9 @@ message HloModuleProto { // The id of this module. int64 id = 5; -} -// Serialization of HloOrdering. -message HloOrderingProto { - // NOTE: currently only sequential orderings are serialized. - message SequentialComputation { - string computation_name = 1; - repeated string instruction_names = 2; - } - repeated SequentialComputation sequential_computations = 1; + // The schedule for this module. + HloScheduleProto schedule = 7; } // Serialization of LogicalBuffer. @@ -322,8 +326,10 @@ message BufferAssignmentProto { // Grouping message that contains all of the information above. 
message HloProto { + reserved 2; + reserved "hlo_ordering"; + HloModuleProto hlo_module = 1; - HloOrderingProto hlo_ordering = 2; BufferAssignmentProto buffer_assignment = 3; } diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index fe7f2be888..233d2199d1 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -464,6 +464,14 @@ std::vector<HloComputation*> HloComputation::MakeEmbeddedComputationsList() } string HloComputation::ToString(const HloPrintOptions& options) const { + return ToString(options, MakeInstructionPostOrder()); +} + +string HloComputation::ToString( + const HloPrintOptions& options, + absl::Span<const HloInstruction* const> instruction_order) const { + CHECK_EQ(instruction_order.size(), instruction_count()); + std::ostringstream s; for (int i = 0; i < options.indent_amount(); i++) { s << " "; @@ -486,7 +494,9 @@ string HloComputation::ToString(const HloPrintOptions& options) const { new_options.set_indent_amount(options.indent_amount() + 1) .set_is_in_nested_computation(true); CanonicalNameMap name_map; - for (const HloInstruction* instruction : MakeInstructionPostOrder()) { + for (const HloInstruction* instruction : instruction_order) { + CHECK_EQ(this, instruction->parent()); + for (int i = 0; i < new_options.indent_amount(); i++) { s << " "; } diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index fe2d3bbbe5..91c5234a6f 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -170,6 +170,11 @@ class HloComputation { string ToString() const { return ToString(HloPrintOptions()); } string ToString(const HloPrintOptions& options) const; + // Overload which accepts an order to emit the instructions in. 
+ string ToString( + const HloPrintOptions& options, + absl::Span<const HloInstruction* const> instruction_order) const; + // Returns a serialized representation of this computation. HloComputationProto ToProto() const; diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index 939b5114c3..a502fff9a0 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -227,6 +227,14 @@ Status HloCostAnalysis::HandleCopy(const HloInstruction*) { return Status::OK(); } +Status HloCostAnalysis::HandleDomain(const HloInstruction* domain) { + // Domain does not have any computation or data transfer. + current_should_compute_bottleneck_time_ = false; + current_properties_[kBytesAccessedKey] = 0; + current_properties_[kOptimalSecondsKey] = 0; + return Status::OK(); +} + Status HloCostAnalysis::HandleDot(const HloInstruction* dot) { const Shape& lhs_shape = dot->operand(0)->shape(); const Shape& rhs_shape = dot->operand(1)->shape(); @@ -507,8 +515,9 @@ Status HloCostAnalysis::HandleConvolution(const HloInstruction* convolution) { valid_position_counts.push_back(valid_position_count); } - const int64 fma_count = - input_feature * output_feature * batch * Product(valid_position_counts); + const int64 fma_count = (input_feature / convolution->feature_group_count()) * + output_feature * batch * + Product(valid_position_counts); current_properties_[kFlopsKey] = fma_count * kFmaFlops; return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h index 9bb3f12ee2..46b4bbeef2 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h @@ -67,6 +67,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor { Status HandleRecvDone(const HloInstruction* recv_done) override; Status HandleConvert(const HloInstruction* 
convert) override; Status HandleCopy(const HloInstruction* copy) override; + Status HandleDomain(const HloInstruction* domain) override; Status HandleDot(const HloInstruction* dot) override; Status HandleConvolution(const HloInstruction* convolution) override; Status HandleFft(const HloInstruction* fft) override; diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc index 2c854eea18..d76ce9ecbc 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc @@ -203,6 +203,35 @@ TEST_F(HloCostAnalysisTest, Convolution) { sizeof(float) * (10 * 20 + 3 * 3 + 8 * 18)); } +TEST_F(HloCostAnalysisTest, ConvolutionWithFeatureGroup) { + XlaBuilder builder("convolution"); + auto input = Parameter( + &builder, 0, + ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/120, /*y_dim=*/10, + /*x_dim=*/20}), + "input"); + auto kernel = Parameter( + &builder, 1, + ShapeUtil::MakeShape(F32, {/*p_dim=*/120, /*z_dim=*/1, /*y_dim=*/3, + /*x_dim=*/3}), + "kernel"); + Conv(input, kernel, {1, 1}, Padding::kValid, /*feature_group_count=*/120); + + // Run HLO cost analysis. + auto hlo_module = BuildHloGraph(&builder); + HloCostAnalysis analysis(ShapeSize); + ASSERT_IS_OK( + hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); + + // Output shape is [1x120x8x18] and each output element requires (3x3) + // FMAs and one FMA is 2 flops. + EXPECT_EQ(analysis.flop_count(), 120 * 8 * 18 * 2 * 3 * 3); + + // Bytes accessed is sum of inputs and output. 
+ EXPECT_EQ(analysis.bytes_accessed(), + sizeof(float) * (120 * 10 * 20 + 120 * 3 * 3 + 120 * 8 * 18)); +} + TEST_F(HloCostAnalysisTest, Reduce) { XlaBuilder builder("reduce"); auto input = @@ -415,7 +444,7 @@ TEST_F(FusionCostAnalysis, NoLayout) { TEST_F(HloCostAnalysisTest, TupleCost) { HloCostAnalysis analysis(ShapeSize); { - XlaBuilder builder("matmul"); + XlaBuilder builder("tuple"); auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {123}), "x"); auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {42}), "y"); Tuple(&builder, {x, y}); @@ -430,6 +459,30 @@ TEST_F(HloCostAnalysisTest, TupleCost) { EXPECT_EQ(analysis.bytes_accessed(), kPointerSize * 2); } +using DomainCostAnalysis = HloTestBase; +TEST_F(DomainCostAnalysis, DomainCost) { + HloCostAnalysis analysis(ShapeSize); + + HloComputation::Builder builder("domain"); + auto x = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {123}), "x")); + auto y = builder.AddInstruction( + HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(F32, {42}), "y")); + auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({x, y})); + auto domain = builder.AddInstruction( + HloInstruction::CreateDomain(tuple->shape(), tuple, nullptr, nullptr)); + + auto hlo_module = CreateNewModule(); + hlo_module->AddEntryComputation(builder.Build()); + + EXPECT_EQ(hlo_module->entry_computation()->root_instruction(), domain); + ASSERT_IS_OK(domain->Accept(&analysis)); + + EXPECT_EQ(analysis.flop_count(*domain), 0); + EXPECT_EQ(analysis.transcendental_count(*domain), 0); + EXPECT_EQ(analysis.bytes_accessed(*domain), 0); +} + TEST_F(HloCostAnalysisTest, BaseDilatedConvolution) { XlaBuilder builder("BaseDilatedConvolution"); auto input = Parameter( diff --git a/tensorflow/compiler/xla/service/hlo_cse_test.cc b/tensorflow/compiler/xla/service/hlo_cse_test.cc index 406d712ec6..e09d5868f2 100644 --- a/tensorflow/compiler/xla/service/hlo_cse_test.cc +++ 
b/tensorflow/compiler/xla/service/hlo_cse_test.cc @@ -29,7 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/util.h" @@ -44,7 +44,7 @@ namespace op = xla::testing::opcode_matchers; namespace xla { namespace { -class HloCseTest : public HloTestBase { +class HloCseTest : public HloVerifiedTestBase { protected: HloCseTest() {} }; @@ -65,13 +65,13 @@ TEST_F(HloCseTest, CombineTwoConstants) { EXPECT_EQ(3, computation->instruction_count()); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(cse.Run(module).ValueOrDie()); EXPECT_EQ(2, computation->instruction_count()); HloInstruction* constant = *computation->instructions().begin(); EXPECT_EQ(42.0f, constant->literal().Get<float>({})); - auto result = ExecuteAndTransfer(std::move(module), {}); + auto result = ExecuteAndTransfer(module->Clone(), {}); auto expected = LiteralUtil::CreateR0<float>(84.0); EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4))); } @@ -96,14 +96,14 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) { EXPECT_THAT(add, op::Add(constant1, constant2)); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(cse.Run(module).ValueOrDie()); EXPECT_EQ(2, computation->instruction_count()); auto first_operand = add->operand(0); EXPECT_THAT(first_operand, ::testing::AnyOf(constant1, constant2)); EXPECT_THAT(add, op::Add(first_operand, first_operand)); - auto result = ExecuteAndTransfer(std::move(module), {}); + auto result = 
ExecuteAndTransfer(module->Clone(), {}); auto expected = LiteralUtil::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}}); EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4))); } @@ -128,12 +128,12 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) { EXPECT_THAT(add, op::Add(constant1, constant2)); HloCSE cse(/*is_layout_sensitive=*/true); - EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(cse.Run(module).ValueOrDie()); EXPECT_EQ(3, computation->instruction_count()); EXPECT_THAT(add, op::Add(constant1, constant2)); - auto result = ExecuteAndTransfer(std::move(module), {}); + auto result = ExecuteAndTransfer(module->Clone(), {}); auto expected = LiteralUtil::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}}); EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4))); } @@ -177,7 +177,7 @@ TEST_F(HloCseTest, ConstantsSameValueDifferentType) { EXPECT_EQ(20, computation->instruction_count()); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(cse.Run(module).ValueOrDie()); // CSE will remove both the second float(42.0f) and the corresponding // convert/cast. 
@@ -209,7 +209,7 @@ TEST_F(HloCseTest, NonscalarConstants) { op::Tuple(common_constant1, common_constant2, uncommon_constant)); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(cse.Run(module).ValueOrDie()); EXPECT_EQ(3, computation->instruction_count()); auto first_operand = tuple->operand(0); @@ -240,7 +240,7 @@ TEST_F(HloCseTest, IdenticalInstructions) { EXPECT_THAT(tuple, op::Tuple(exp1, exp2, exp3)); HloCSE cse(/*is_layout_sensitive=*/true); - EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(cse.Run(module).ValueOrDie()); EXPECT_EQ(3, computation->instruction_count()); auto first_operand = tuple->operand(0); @@ -250,7 +250,7 @@ TEST_F(HloCseTest, IdenticalInstructions) { // Test two identical while loops with same inputs TEST_F(HloCseTest, WhileLoopsIdenticalConditionsAndBodiesSameInput) { - auto module = ParseHloString(R"( + ParseAndVerifyModule(R"( HloModule WhileLoopsIdenticalConditionsAndBodiesSameInput %body (param: (f32[], f32[])) -> (f32[], f32[]) { @@ -278,21 +278,20 @@ f32[]) while((f32[], f32[]) %tuple.1), condition=%condition, body=%body ROOT %while.1 = (f32[], f32[]) while((f32[], f32[]) %tuple.1), condition=%condition.1, body=%body } - )") - .ValueOrDie(); + )"); - auto computation = module->entry_computation(); + auto computation = module().entry_computation(); EXPECT_EQ(5, computation->instruction_count()); HloCSE cse(true); - EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(cse.Run(&module()).ValueOrDie()); EXPECT_EQ(4, computation->instruction_count()); } // Test two while loops with same conditions, same inputs, but different // bodies TEST_F(HloCseTest, WhileLoopsIdenticalConditionsSameInputAndDifferentBodies) { - auto module = ParseHloString(R"( + ParseAndVerifyModule(R"( HloModule WhileLoopsIdenticalConditionsSameInputAndDifferentBodies %body (param: (f32[], f32[])) -> (f32[], f32[]) { @@ -329,20 +328,19 @@ index=1 %sub = f32[] subtract(f32[] 
%get-tuple-element.2, f32[] condition=%condition, body=%body ROOT %while.1 = (f32[], f32[]) while((f32[], f32[]) %tuple.1), condition=%condition.1, body=%body2 } - )") - .ValueOrDie(); + )"); - auto computation = module->entry_computation(); + auto computation = module().entry_computation(); EXPECT_EQ(5, computation->instruction_count()); HloCSE cse(true); - EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(cse.Run(&module()).ValueOrDie()); EXPECT_EQ(5, computation->instruction_count()); } // Test two identical while loops with different inputs TEST_F(HloCseTest, WhileLoopsIdenticalConditionsAndBodiesDifferentInput) { - auto module = ParseHloString(R"( + ParseAndVerifyModule(R"( HloModule WhileLoopsIdenticalConditionsAndBodiesDifferentInput %body (param: (f32[], f32[])) -> (f32[], f32[]) { @@ -373,21 +371,20 @@ f32[] constant(2) %tuple.2 = (f32[], f32[]) tuple(f32[] %constant.4, f32[] condition=%condition.1, body=%body } - )") - .ValueOrDie(); + )"); - auto computation = module->entry_computation(); + auto computation = module().entry_computation(); EXPECT_EQ(8, computation->instruction_count()); HloCSE cse(true); - EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(cse.Run(&module()).ValueOrDie()); EXPECT_EQ(8, computation->instruction_count()); } // Test two while loops with identical bodies and same inputs, but different // conditions TEST_F(HloCseTest, WhileLoopsIdenticalBodiesAndInputDifferntConditions) { - auto module = ParseHloString(R"( + ParseAndVerifyModule(R"( HloModule WhileLoopsIdenticalBodiesAndInputDifferntConditions %body (param: (f32[], f32[])) -> (f32[], f32[]) { @@ -414,14 +411,13 @@ f32[]) { %constant.2 = f32[] constant(1) %constant.3 = f32[] constant(2) %while = (f32[], f32[]) while((f32[], f32[]) %tuple.1), condition=%condition, body=%body ROOT %while.1 = (f32[], f32[]) while((f32[], f32[]) %tuple.1), condition=%condition.1, body=%body - })") - .ValueOrDie(); + })"); - auto computation = 
module->entry_computation(); + auto computation = module().entry_computation(); EXPECT_EQ(5, computation->instruction_count()); HloCSE cse(true); - EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(cse.Run(&module()).ValueOrDie()); EXPECT_EQ(5, computation->instruction_count()); } @@ -450,7 +446,7 @@ TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsSensitive) { EXPECT_THAT(tuple, op::Tuple(exp1, exp2)); HloCSE cse(/*is_layout_sensitive=*/true); - EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(cse.Run(module).ValueOrDie()); EXPECT_EQ(4, computation->instruction_count()); EXPECT_THAT(tuple, op::Tuple(exp1, exp2)); @@ -481,7 +477,7 @@ TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsInsensitive) { EXPECT_THAT(tuple, op::Tuple(exp1, exp2)); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(cse.Run(module).ValueOrDie()); EXPECT_EQ(3, computation->instruction_count()); auto first_operand = tuple->operand(0); @@ -516,7 +512,7 @@ TEST_F(HloCseTest, FusionInternalCSE) { EXPECT_EQ(5, fused_computation->instruction_count()); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(cse.Run(module).ValueOrDie()); EXPECT_EQ(4, fused_computation->instruction_count()); auto root = fused_computation->root_instruction(); @@ -565,7 +561,7 @@ TEST_F(HloCseTest, IdenticalExpressions) { EXPECT_THAT(tuple, op::Tuple(op::Add(negate1, exp1), op::Add(negate2, exp2))); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(cse.Run(module).ValueOrDie()); EXPECT_EQ(5, computation->instruction_count()); auto operand = tuple->operand(0); @@ -599,7 +595,7 @@ TEST_F(HloCseTest, DoNotCombineRng) { uint32 count_before = computation->instruction_count(); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(cse.Run(module).ValueOrDie()); 
uint32 count_after = computation->instruction_count(); EXPECT_EQ(count_before, count_after); @@ -653,7 +649,7 @@ TEST_F(HloCseTest, DoNotCombineCallsToImpureFunctions) { VLOG(3) << "before: " << module->ToString(); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(cse.Run(module).ValueOrDie()); VLOG(3) << "after: " << module->ToString(); @@ -663,7 +659,7 @@ TEST_F(HloCseTest, DoNotCombineCallsToImpureFunctions) { } TEST_F(HloCseTest, CompareComputations) { - auto module = ParseHloString(R"( + ParseAndVerifyModule(R"( HloModule m add_computation { @@ -684,12 +680,11 @@ TEST_F(HloCseTest, CompareComputations) { r1 = f32[] reduce(p, c), dimensions={0}, to_apply=add_computation r2 = f32[] reduce(p, c), dimensions={0}, to_apply=add_computation2 ROOT f2 = (f32[],f32[]) tuple(r1, r2) - })") - .ValueOrDie(); + })"); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); - HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_TRUE(cse.Run(&module()).ValueOrDie()); + HloInstruction* root = module().entry_computation()->root_instruction(); EXPECT_EQ(root->operand(0), root->operand(1)); } @@ -708,13 +703,13 @@ TEST_F(HloCseTest, ConstantsSameValueInDifferentDomains) { EXPECT_EQ(2, computation->instruction_count()); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(cse.Run(module).ValueOrDie()); EXPECT_EQ(2, computation->instruction_count()); } TEST_F(HloCseTest, Domain) { - auto module = ParseHloString(R"( + ParseAndVerifyModule(R"( HloModule module ENTRY %entry { %param = f32[] parameter(0), sharding={maximal device=0} @@ -735,13 +730,11 @@ ENTRY %entry { domain={kind="sharding", entry={maximal device=2}, exit={maximal device=0}} %add = f32[] add(%domain.3, %domain.4) ROOT %sub = f32[] subtract(%add, %domain.5) -})") - .ValueOrDie(); +})"); HloCSE cse(/*is_layout_sensitive=*/false); - 
EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); - LOG(INFO) << "AAAAA " << module->ToString(); - const HloInstruction* sub = module->entry_computation()->root_instruction(); + EXPECT_TRUE(cse.Run(&module()).ValueOrDie()); + const HloInstruction* sub = module().entry_computation()->root_instruction(); const HloInstruction* add = sub->operand(0); EXPECT_EQ(add->operand(0), add->operand(1)); EXPECT_NE(add->operand(0), sub->operand(1)); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index abd4bb1f73..102ebb24ab 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -52,10 +52,7 @@ static std::array<bool, 2> use_bf16_params{true, false}; class HloEvaluatorTest : public ::testing::WithParamInterface<bool>, public HloVerifiedTestBase { protected: - HloEvaluatorTest() - : HloVerifiedTestBase(/*layout_sensitive=*/false, - /*allow_mixed_precision=*/false), - use_bfloat16_(GetParam()) { + HloEvaluatorTest() : HloVerifiedTestBase(), use_bfloat16_(GetParam()) { evaluator_ = absl::make_unique<HloEvaluator>(); } diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index 6a09bb08f4..63303aef1e 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -1052,7 +1052,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { auto func = [&window_shape, &dnums, &lhs_shape, &rhs_shape, &window, &lhs_dim_multipliers, &rhs_dim_multipliers, lhs_literal_data, rhs_literal_data, - feature_group_count](absl::Span<const int64> out_index) { + feature_group_count](const absl::Span<const int64> out_index) { // Dimension number applicable for input (lhs). 
const int64 input_batch_dim = dnums.input_batch_dimension(); const int64 input_z_dim = dnums.input_feature_dimension(); @@ -1063,9 +1063,22 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const int64 output_batch_dim = dnums.output_batch_dimension(); const int64 output_z_dim = dnums.output_feature_dimension(); - const int64 z_size = ShapeUtil::GetDimension(lhs_shape, input_z_dim); + const int64 input_z_size = + ShapeUtil::GetDimension(lhs_shape, input_z_dim); + // The size of an input feature group. + const int64 input_feature_group_size = input_z_size / feature_group_count; + const int64 output_z_size = ShapeUtil::GetDimension(rhs_shape, kernel_output_z_dim); + // The output feature dimension is a concatenation of convolution results + // from the different groups. + const int64 output_feature_group_size = + output_z_size / feature_group_count; + + // Calculate the group index to which the current output index + // belongs. + const int64 feature_group_index = + out_index[output_z_dim] / output_feature_group_size; ElementwiseT result_val = static_cast<ElementwiseT>(0); DimensionVector rhs_spatial_index(dnums.kernel_spatial_dimensions_size(), @@ -1073,33 +1086,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // Convolve input feature with kernel. do { - for (int64 iz = 0; iz < z_size; ++iz) { - int64 rhs_iz = iz; - // Handle grouped convolutions. - if (feature_group_count > 1) { - // The size of a feature group. - int64 feature_group_size = z_size / feature_group_count; - rhs_iz = iz % feature_group_size; - - // The output feature dimension is a concatenation of convolution - // results from the different groups. - int64 output_feature_group_size = - output_z_size / feature_group_count; - - // Calculate the group index to which the current input feature - // index belongs. - int64 input_group_index = iz / feature_group_size; - - // Calculate the group index to which the current output index - // belongs. 
- int64 output_group_index = - out_index[output_z_dim] / output_feature_group_size; - if (input_group_index != output_group_index) { - // If the current output index does not belong to the current - // feature group, skip it. - continue; - } - } + for (int64 rhs_iz = 0; rhs_iz < input_feature_group_size; ++rhs_iz) { + const int64 iz = + feature_group_index * input_feature_group_size + rhs_iz; int64 lhs_linear_index = 0; lhs_linear_index += out_index[output_batch_dim] * diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 471a12d6aa..25ae344ea5 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -451,6 +451,28 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( << proto.dimensions_size(); instruction = CreateIota(proto.shape(), proto.dimensions(0)); break; + case HloOpcode::kDot: { + TF_RET_CHECK(proto.has_dot_dimension_numbers()) + << "Dot instruction should have dot_dimension_numbers."; + TF_RET_CHECK(proto.operand_ids_size() == 2) + << "Dot instruction should have 2 operands but sees " + << proto.operand_ids_size(); + PrecisionConfig precision_config = proto.precision_config(); + precision_config.mutable_operand_precision()->Resize( + proto.operand_ids_size(), PrecisionConfig::DEFAULT); + instruction = absl::make_unique<HloDotInstruction>( + proto.shape(), operands(0), operands(1), + proto.dot_dimension_numbers(), precision_config); + break; + } + case HloOpcode::kDomain: + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "Domain instruction should have 1 operands but sees " + << proto.operand_ids_size(); + instruction = absl::make_unique<HloDomainInstruction>( + proto.shape(), operands(0), /*operand_side_metadata=*/nullptr, + /*user_side_metadata=*/nullptr); + break; default: { instruction = absl::WrapUnique(new HloInstruction(opcode, proto.shape())); for (const int64 operand_id : 
proto.operand_ids()) { @@ -472,20 +494,9 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( computation_map.at(computation_id)); } } - if (instruction->opcode() == HloOpcode::kDot) { - instruction->precision_config_ = proto.precision_config(); - instruction->precision_config_.mutable_operand_precision()->Resize( - instruction->operand_count(), PrecisionConfig::DEFAULT); - TF_RET_CHECK(proto.has_dot_dimension_numbers()); - instruction->dot_dimension_numbers_ = - absl::make_unique<DotDimensionNumbers>( - proto.dot_dimension_numbers()); - } else { - TF_RET_CHECK(!proto.has_precision_config()) - << instruction->opcode() << proto.DebugString(); - TF_RET_CHECK(!proto.has_dot_dimension_numbers()) - << instruction->opcode(); - } + TF_RET_CHECK(!proto.has_precision_config()) + << instruction->opcode() << proto.DebugString(); + TF_RET_CHECK(!proto.has_dot_dimension_numbers()) << instruction->opcode(); break; } } @@ -564,7 +575,6 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, case HloOpcode::kCopy: case HloOpcode::kCos: case HloOpcode::kClz: - case HloOpcode::kDomain: case HloOpcode::kExp: case HloOpcode::kExpm1: case HloOpcode::kFloor: @@ -596,7 +606,6 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, case HloOpcode::kAtan2: case HloOpcode::kDivide: case HloOpcode::kComplex: - case HloOpcode::kDot: case HloOpcode::kEq: case HloOpcode::kGe: case HloOpcode::kGt: @@ -674,30 +683,8 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, const DotDimensionNumbers& dimension_numbers, const PrecisionConfig& precision_config) { - auto instruction = - absl::WrapUnique(new HloInstruction(HloOpcode::kDot, shape)); - instruction->AppendOperand(lhs); - instruction->AppendOperand(rhs); - instruction->dot_dimension_numbers_ = - absl::make_unique<DotDimensionNumbers>(dimension_numbers); - instruction->set_precision_config(precision_config); - return instruction; -} - -/* 
static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCanonicalDot( - const Shape& shape, HloInstruction* lhs, HloInstruction* rhs) { - CHECK_EQ(ShapeUtil::Rank(lhs->shape()), 2); - CHECK_EQ(ShapeUtil::Rank(rhs->shape()), 2); - - auto instruction = - absl::WrapUnique(new HloInstruction(HloOpcode::kDot, shape)); - instruction->AppendOperand(lhs); - instruction->AppendOperand(rhs); - instruction->dot_dimension_numbers_ = - absl::make_unique<DotDimensionNumbers>(); - instruction->dot_dimension_numbers_->add_lhs_contracting_dimensions(1); - instruction->dot_dimension_numbers_->add_rhs_contracting_dimensions(0); - return instruction; + return absl::make_unique<HloDotInstruction>( + shape, lhs, rhs, dimension_numbers, precision_config); } /* static */ std::unique_ptr<HloInstruction> @@ -1157,12 +1144,9 @@ bool HloInstruction::HasSideEffect() const { const Shape& shape, HloInstruction* operand, std::unique_ptr<DomainMetadata> operand_side_metadata, std::unique_ptr<DomainMetadata> user_side_metadata) { - auto instruction = - absl::WrapUnique(new HloInstruction(HloOpcode::kDomain, shape)); - instruction->operand_side_metadata_ = std::move(operand_side_metadata); - instruction->user_side_metadata_ = std::move(user_side_metadata); - instruction->AppendOperand(operand); - return instruction; + return absl::make_unique<HloDomainInstruction>( + shape, operand, std::move(operand_side_metadata), + std::move(user_side_metadata)); } std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands( @@ -1218,6 +1202,8 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands( case HloOpcode::kGather: case HloOpcode::kScatter: case HloOpcode::kIota: + case HloOpcode::kDot: + case HloOpcode::kDomain: clone = CloneWithNewOperandsImpl(shape, new_operands, context); break; // Unary ops. 
@@ -1290,11 +1276,6 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands( CHECK_EQ(new_operands.size(), 1); clone = CreateBitcastConvert(shape, new_operands[0]); break; - case HloOpcode::kDot: - CHECK_EQ(new_operands.size(), 2); - clone = CreateDot(shape, new_operands[0], new_operands[1], - *dot_dimension_numbers_, precision_config()); - break; case HloOpcode::kReshape: CHECK_EQ(new_operands.size(), 1); clone = CreateReshape(shape, new_operands[0]); @@ -1319,12 +1300,6 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands( true_computation(), new_operands[2], false_computation()); break; - case HloOpcode::kDomain: - CHECK_EQ(new_operands.size(), 1); - clone = - CreateDomain(shape, new_operands[0], operand_side_metadata_->Clone(), - user_side_metadata_->Clone()); - break; case HloOpcode::kAfterAll: if (new_operands.empty()) { clone = CreateToken(); @@ -1620,11 +1595,6 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kAfterAll: return false; - // Check dot dimension numbers. - case HloOpcode::kDot: - return protobuf_util::ProtobufEquals(dot_dimension_numbers(), - other.dot_dimension_numbers()); - // Remaining instructions with special values. case HloOpcode::kCall: return eq_computations(to_apply(), other.to_apply()); @@ -1640,10 +1610,6 @@ bool HloInstruction::IdenticalSlowPath( return false; } - case HloOpcode::kDomain: - return operand_side_metadata().Matches(other.operand_side_metadata()) && - user_side_metadata().Matches(other.user_side_metadata()); - // Ops migrated to subclasses should never come to this line. // TODO(b/80131774): Remove this switch when migration is complete. 
case HloOpcode::kBatchNormTraining: @@ -1683,6 +1649,8 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kDynamicSlice: case HloOpcode::kGather: case HloOpcode::kScatter: + case HloOpcode::kDot: + case HloOpcode::kDomain: LOG(FATAL) << "Base class impl called for opcode with subclass: " << opcode(); } @@ -2052,15 +2020,6 @@ std::vector<string> HloInstruction::ExtraAttributesToString( const HloPrintOptions& options) const { std::vector<string> extra = ExtraAttributesToStringImpl(options); - if (dot_dimension_numbers_ != nullptr) { - extra.push_back(DotDimensionNumbersToString()); - } - - string precision_config_string = PrecisionConfigToString(); - if (!precision_config_string.empty()) { - extra.push_back(precision_config_string); - } - if (options.print_subcomputation_mode() == HloPrintOptions::PrintSubcomputationMode::kNameOnly) { if (opcode() == HloOpcode::kWhile) { @@ -2146,11 +2105,6 @@ std::vector<string> HloInstruction::ExtraAttributesToString( }), "}")); } - if (operand_side_metadata_ != nullptr && user_side_metadata_ != nullptr) { - extra.push_back(StrCat("domain={kind=\"", operand_side_metadata_->Kind(), - "\", entry=", user_side_metadata_->ToString(), - ", exit=", operand_side_metadata_->ToString(), "}")); - } return extra; } @@ -2182,19 +2136,12 @@ HloInstructionProto HloInstruction::ToProto() const { *proto.mutable_metadata() = metadata_; proto.set_backend_config(backend_config_); - if (opcode() == HloOpcode::kConvolution || opcode() == HloOpcode::kDot) { - *proto.mutable_precision_config() = precision_config_; - } if (opcode() != HloOpcode::kFusion) { for (const HloComputation* computation : called_computations_) { proto.add_called_computation_ids(computation->unique_id()); } } - if (dot_dimension_numbers_ != nullptr) { - *proto.mutable_dot_dimension_numbers() = *dot_dimension_numbers_; - } - if (has_sharding()) { *proto.mutable_sharding() = sharding().ToProto(); } @@ -2921,31 +2868,6 @@ string ConvolutionDimensionNumbersToString( 
StrJoin(output_dims, "")); } -string HloInstruction::DotDimensionNumbersToString() const { - std::vector<string> result; - if (dot_dimension_numbers_ == nullptr) { - return ""; - } - const DotDimensionNumbers& dnums = *dot_dimension_numbers_; - if (!dnums.lhs_batch_dimensions().empty()) { - result.push_back(StrCat("lhs_batch_dims={", - StrJoin(dnums.lhs_batch_dimensions(), ","), "}")); - } - result.push_back(StrCat("lhs_contracting_dims={", - StrJoin(dnums.lhs_contracting_dimensions(), ","), - "}")); - - if (!dnums.rhs_batch_dimensions().empty()) { - result.push_back(StrCat("rhs_batch_dims={", - StrJoin(dnums.rhs_batch_dimensions(), ","), "}")); - } - result.push_back(StrCat("rhs_contracting_dims={", - StrJoin(dnums.rhs_contracting_dimensions(), ","), - "}")); - - return StrJoin(result, ", "); -} - StatusOr<RandomDistribution> StringToRandomDistribution(const string& name) { static std::unordered_map<string, RandomDistribution>* map = [] { static auto* map = new std::unordered_map<string, RandomDistribution>; @@ -2964,27 +2886,6 @@ StatusOr<RandomDistribution> StringToRandomDistribution(const string& name) { return found->second; } -string HloInstruction::PrecisionConfigToString() const { - if (absl::c_all_of( - precision_config_.operand_precision(), [](int32 precision) { - return static_cast<PrecisionConfig::Precision>(precision) == - PrecisionConfig::DEFAULT; - })) { - return ""; - } - return StrCat( - "operand_precision={", - StrJoin( - precision_config_.operand_precision(), ",", - [](string* out, int32 precision) { - CHECK(PrecisionConfig::Precision_IsValid(precision)) << precision; - StrAppend(out, - PrecisionToString( - static_cast<PrecisionConfig::Precision>(precision))); - }), - "}"); -} - StatusOr<PrecisionConfig::Precision> StringToPrecision(const string& name) { static std::unordered_map<string, PrecisionConfig::Precision>* map = [] { static auto* map = @@ -3044,6 +2945,16 @@ Status HloInstruction::set_backend_config( return ret; } +const 
PrecisionConfig& HloInstruction::precision_config() const { + if (auto* convolution = DynCast<HloConvolutionInstruction>(this)) { + return convolution->precision_config(); + } + if (auto* dot = DynCast<HloDotInstruction>(this)) { + return dot->precision_config(); + } + LOG(FATAL) << "Unimplemented method."; +} + HloModule* HloInstruction::GetModule() const { if (parent_) { return parent_->parent(); @@ -3348,4 +3259,15 @@ const ScatterDimensionNumbers& HloInstruction::scatter_dimension_numbers() return Cast<HloScatterInstruction>(this)->scatter_dimension_numbers(); } +const DotDimensionNumbers& HloInstruction::dot_dimension_numbers() const { + return Cast<HloDotInstruction>(this)->dot_dimension_numbers(); +} + +const DomainMetadata& HloInstruction::operand_side_metadata() const { + return Cast<HloDomainInstruction>(this)->operand_side_metadata(); +} + +const DomainMetadata& HloInstruction::user_side_metadata() const { + return Cast<HloDomainInstruction>(this)->user_side_metadata(); +} } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 691f8155f9..5581c17c2d 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -421,12 +421,6 @@ class HloInstruction { const DotDimensionNumbers& dimension_numbers, const PrecisionConfig& precision_config); - // Creates a dot op with operands 'lhs' and 'rhs' that contracts dimension 1 - // of the LHS with dimension 0 of the RHS with no batch dimensions. Both LHS - // and the RHS must be of rank 2. - static std::unique_ptr<HloInstruction> CreateCanonicalDot( - const Shape& shape, HloInstruction* lhs, HloInstruction* rhs); - // Creates a reduce-precision op, where operand is the data to reduce in // precision, and exponent_bits and mantissa_bits describe the precision to // reduce it to. 
@@ -866,11 +860,6 @@ class HloInstruction { return false; } - if (!absl::c_equal(precision_config_.operand_precision(), - other.precision_config_.operand_precision())) { - return false; - } - return IdenticalSlowPath(other, eq_computations); } @@ -1085,15 +1074,6 @@ class HloInstruction { return other->has_sharding() ? sharding() == other->sharding() : false; } - // Retrieves the operand side metadata of a kDomain instruction. - const DomainMetadata& operand_side_metadata() const { - return *operand_side_metadata_; - } - // Retrieves the user side metadata of a kDomain instruction. - const DomainMetadata& user_side_metadata() const { - return *user_side_metadata_; - } - // When creating a new instruction which either replaces, or shifts up (kCopy // insertion case), another instruction, we need to make sure the certain // properties of the new instruction are copied into the derived one. As of @@ -1101,18 +1081,6 @@ class HloInstruction { // instruction. void SetupDerivedInstruction(HloInstruction* derived_instruction) const; - // Returns data on the dimension numbers used for a dot operation. - const DotDimensionNumbers& dot_dimension_numbers() const { - CHECK(dot_dimension_numbers_ != nullptr); - return *dot_dimension_numbers_; - } - - // Returns the dump string of the dot dimension numbers. - string DotDimensionNumbersToString() const; - - // Returns the dump string of the precision configuration. - string PrecisionConfigToString() const; - // Clones the HLO instruction. The clone will have the same opcode, shape, and // operands. After creation the clone has no uses. "this" (the instruction // cloned from) is not changed. Suffix is the string to append to the name of @@ -1262,10 +1230,8 @@ class HloInstruction { // information. Transformations to other HLOs will not preserve this // information but it is presumed that the alternate lowering is strictly // superior. 
- const PrecisionConfig& precision_config() const { return precision_config_; } - void set_precision_config(const PrecisionConfig& precision_config) { - precision_config_ = precision_config; - } + // Precondition: opcode must be kConvolution or kDot. + const PrecisionConfig& precision_config() const; // Sets the debug metadata for this instruction. void set_metadata(const OpMetadata& metadata) { metadata_ = metadata; } @@ -1508,6 +1474,15 @@ class HloInstruction { // Delegates to HloScatterInstruction::scatter_dimension_numbers(). const ScatterDimensionNumbers& scatter_dimension_numbers() const; + // Delegates to HloDotInstruction::dot_dimension_numbers(). + const DotDimensionNumbers& dot_dimension_numbers() const; + + // Delegates to HloDomainInstruction::operand_side_metadata(). + const DomainMetadata& operand_side_metadata() const; + + // Delegates to HloDomainInstruction::user_side_metadata(). + const DomainMetadata& user_side_metadata() const; + // Old methods kept for smooth subclassing transition END. protected: @@ -1647,22 +1622,12 @@ class HloInstruction { // Result shape of this instruction. Shape shape_; - // Describes the dimension numbers used for a dot. - std::unique_ptr<DotDimensionNumbers> dot_dimension_numbers_; - - // Used to tag kCopy instructions that are eligible for copy elision. - bool copy_elision_allowed_ = true; - // The sharding, if one exists. // Uses std::shared_ptr to allow reuse of the same sharding object between // HloInstructions and other components as HloSharding can be very large for // many element tuples. std::shared_ptr<const HloSharding> sharding_; - // Fields used by the kDomain instruction. - std::unique_ptr<DomainMetadata> operand_side_metadata_; - std::unique_ptr<DomainMetadata> user_side_metadata_; - // Computations called by this instruction. std::vector<HloComputation*> called_computations_; @@ -1676,10 +1641,6 @@ class HloInstruction { // HLO. See the documentation on backend_config(). 
string backend_config_; - // Information used to communicate to the implementation about the algorithm - // used to produce results. See the documentation on precision_config(). - PrecisionConfig precision_config_; - // String identifier for instruction. string name_; diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index ad87aa1123..fb7345a2ad 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -47,6 +47,27 @@ bool IsInstructionElementwiseOnOperand(const HloInstruction* instruction, return instruction->IsElementwiseOnOperand(operand_index); }); } + +string PrecisionConfigToString(const PrecisionConfig& precision_config) { + if (absl::c_all_of(precision_config.operand_precision(), [](int32 precision) { + return static_cast<PrecisionConfig::Precision>(precision) == + PrecisionConfig::DEFAULT; + })) { + return ""; + } + + return StrCat( + "operand_precision={", + StrJoin( + precision_config.operand_precision(), ",", + [](string* out, int32 precision) { + CHECK(PrecisionConfig::Precision_IsValid(precision)) << precision; + StrAppend(out, + PrecisionToString( + static_cast<PrecisionConfig::Precision>(precision))); + }), + "}"); +} } // namespace HloBatchNormInstruction::HloBatchNormInstruction( @@ -1634,7 +1655,8 @@ HloConvolutionInstruction::HloConvolutionInstruction( : HloInstruction(HloOpcode::kConvolution, shape), feature_group_count_(feature_group_count), window_(window), - convolution_dimension_numbers_(dimension_numbers) { + convolution_dimension_numbers_(dimension_numbers), + precision_config_(precision_config) { if (window_util::HasBaseDilation(window)) { SetAndSanitizeName(StrCat(name(), "-base-dilated")); } @@ -1643,7 +1665,6 @@ HloConvolutionInstruction::HloConvolutionInstruction( } AppendOperand(lhs); AppendOperand(rhs); - set_precision_config(precision_config); } string HloConvolutionInstruction::ToCategory() 
const { @@ -1663,6 +1684,7 @@ HloInstructionProto HloConvolutionInstruction::ToProto() const { *proto.mutable_convolution_dimension_numbers() = convolution_dimension_numbers_; proto.set_feature_group_count(feature_group_count_); + *proto.mutable_precision_config() = precision_config_; return proto; } @@ -1677,6 +1699,12 @@ std::vector<string> HloConvolutionInstruction::ExtraAttributesToStringImpl( if (feature_group_count_ != 1) { extra.push_back(StrCat("feature_group_count=", feature_group_count_)); } + + string precision_config_string = PrecisionConfigToString(precision_config_); + if (!precision_config_string.empty()) { + extra.push_back(precision_config_string); + } + return extra; } @@ -1692,7 +1720,9 @@ bool HloConvolutionInstruction::IdenticalSlowPath( return protobuf_util::ProtobufEquals(window(), casted_other.window()) && protobuf_util::ProtobufEquals( convolution_dimension_numbers(), - casted_other.convolution_dimension_numbers()); + casted_other.convolution_dimension_numbers()) && + protobuf_util::ProtobufEquals(precision_config(), + casted_other.precision_config()); } std::unique_ptr<HloInstruction> @@ -1702,7 +1732,7 @@ HloConvolutionInstruction::CloneWithNewOperandsImpl( CHECK_EQ(new_operands.size(), 2); return absl::make_unique<HloConvolutionInstruction>( shape, new_operands[0], new_operands[1], feature_group_count_, window(), - convolution_dimension_numbers_, precision_config()); + convolution_dimension_numbers_, precision_config_); } HloReduceWindowInstruction::HloReduceWindowInstruction( @@ -2161,4 +2191,113 @@ std::unique_ptr<HloInstruction> HloIotaInstruction::CloneWithNewOperandsImpl( return absl::make_unique<HloIotaInstruction>(shape, iota_dimension()); } +HloDotInstruction::HloDotInstruction( + const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, + const DotDimensionNumbers& dimension_numbers, + const PrecisionConfig& precision_config) + : HloInstruction(HloOpcode::kDot, shape), + dot_dimension_numbers_(dimension_numbers), + 
precision_config_(precision_config) { + AppendOperand(lhs); + AppendOperand(rhs); +} + +HloInstructionProto HloDotInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + *proto.mutable_dot_dimension_numbers() = dot_dimension_numbers_; + *proto.mutable_precision_config() = precision_config_; + return proto; +} + +std::vector<string> HloDotInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + std::vector<string> extra = {DotDimensionNumbersToString()}; + + string precision_config_string = PrecisionConfigToString(precision_config_); + if (!precision_config_string.empty()) { + extra.push_back(precision_config_string); + } + return extra; +} + +bool HloDotInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function<bool(const HloComputation*, const HloComputation*)>& + eq_computations) const { + const auto& casted_other = static_cast<const HloDotInstruction&>(other); + return protobuf_util::ProtobufEquals(dot_dimension_numbers(), + casted_other.dot_dimension_numbers()) && + protobuf_util::ProtobufEquals(precision_config(), + casted_other.precision_config()); +} + +std::unique_ptr<HloInstruction> HloDotInstruction::CloneWithNewOperandsImpl( + const Shape& shape, absl::Span<HloInstruction* const> new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 2); + return absl::make_unique<HloDotInstruction>( + shape, new_operands[0], new_operands[1], dot_dimension_numbers_, + precision_config_); +} + +string HloDotInstruction::DotDimensionNumbersToString() const { + std::vector<string> result; + const DotDimensionNumbers& dnums = dot_dimension_numbers_; + if (!dnums.lhs_batch_dimensions().empty()) { + result.push_back(StrCat("lhs_batch_dims={", + StrJoin(dnums.lhs_batch_dimensions(), ","), "}")); + } + result.push_back(StrCat("lhs_contracting_dims={", + StrJoin(dnums.lhs_contracting_dimensions(), ","), + "}")); + + if (!dnums.rhs_batch_dimensions().empty()) { + 
result.push_back(StrCat("rhs_batch_dims={", + StrJoin(dnums.rhs_batch_dimensions(), ","), "}")); + } + result.push_back(StrCat("rhs_contracting_dims={", + StrJoin(dnums.rhs_contracting_dimensions(), ","), + "}")); + + return StrJoin(result, ", "); +} + +HloDomainInstruction::HloDomainInstruction( + const Shape& shape, HloInstruction* operand, + std::unique_ptr<DomainMetadata> operand_side_metadata, + std::unique_ptr<DomainMetadata> user_side_metadata) + : HloInstruction(HloOpcode::kDomain, shape), + operand_side_metadata_(std::move(operand_side_metadata)), + user_side_metadata_(std::move(user_side_metadata)) { + AppendOperand(operand); +} + +std::vector<string> HloDomainInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + if (operand_side_metadata_ != nullptr && user_side_metadata_ != nullptr) { + return {StrCat("domain={kind=\"", operand_side_metadata_->Kind(), + "\", entry=", user_side_metadata_->ToString(), + ", exit=", operand_side_metadata_->ToString(), "}")}; + } + return {}; +} + +bool HloDomainInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function<bool(const HloComputation*, const HloComputation*)>& + eq_computations) const { + const auto& casted_other = static_cast<const HloDomainInstruction&>(other); + return operand_side_metadata().Matches( + casted_other.operand_side_metadata()) && + user_side_metadata().Matches(casted_other.user_side_metadata()); +} + +std::unique_ptr<HloInstruction> HloDomainInstruction::CloneWithNewOperandsImpl( + const Shape& shape, absl::Span<HloInstruction* const> new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 1); + return absl::make_unique<HloDomainInstruction>( + shape, new_operands[0], operand_side_metadata_->Clone(), + user_side_metadata_->Clone()); +} } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index e1215a7566..c3a7801164 100644 --- 
a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -957,6 +957,16 @@ class HloConvolutionInstruction : public HloInstruction { // The number of feature groups. Must be a divisor of the input feature // dimension and output feature dimension. int64 feature_group_count() const { return feature_group_count_; } + + // Returns the information used to tell the implementation information about + // what sort of precision is requested. The meaning of the field is backend + // specific. At the moment, it is only supported for kConvolution and kDot. + // Transformations on one kDot or kConvolution to another will preserve this + // information. Transformations to other HLOs will not preserve this + // information but it is presumed that the alternate lowering is strictly + // superior. + const PrecisionConfig& precision_config() const { return precision_config_; } + string ToCategory() const override; // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; @@ -979,6 +989,9 @@ class HloConvolutionInstruction : public HloInstruction { Window window_; // Describes the dimension numbers used for a convolution. ConvolutionDimensionNumbers convolution_dimension_numbers_; + // Information used to communicate to the implementation about the algorithm + // used to produce results. See the documentation on precision_config(). + PrecisionConfig precision_config_; }; class HloReduceWindowInstruction : public HloInstruction { @@ -1271,6 +1284,85 @@ class HloIotaInstruction : public HloInstruction { const int64 iota_dimension_; }; +class HloDotInstruction : public HloInstruction { + public: + // Creates a dot op with operands 'lhs' and 'rhs' with contracting and batch + // dimensions specified in 'dimension_numbers'. 
+ explicit HloDotInstruction(const Shape& shape, HloInstruction* lhs, + HloInstruction* rhs, + const DotDimensionNumbers& dimension_numbers, + const PrecisionConfig& precision_config); + + // Returns data on the dimension numbers used for a dot operation. + const DotDimensionNumbers& dot_dimension_numbers() const { + return dot_dimension_numbers_; + } + + // Returns the information used to tell the implementation information about + // what sort of precision is requested. The meaning of the field is backend + // specific. At the moment, it is only supported for kConvolution and kDot. + // Transformations on one kDot or kConvolution to another will preserve this + // information. Transformations to other HLOs will not preserve this + // information but it is presumed that the alternate lowering is strictly + // superior. + const PrecisionConfig& precision_config() const { return precision_config_; } + + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector<string> ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function<bool(const HloComputation*, const HloComputation*)>& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( + const Shape& shape, absl::Span<HloInstruction* const> new_operands, + HloCloneContext* context) const override; + // Returns the dump string of the dot dimension numbers. + string DotDimensionNumbersToString() const; + + // Describes the dimension numbers used for a dot. + DotDimensionNumbers dot_dimension_numbers_; + + // Information used to communicate to the implementation about the algorithm + // used to produce results. See the documentation on precision_config(). 
+ PrecisionConfig precision_config_; +}; + +class HloDomainInstruction : public HloInstruction { + public: + explicit HloDomainInstruction( + const Shape& shape, HloInstruction* operand, + std::unique_ptr<DomainMetadata> operand_side_metadata, + std::unique_ptr<DomainMetadata> user_side_metadata); + + // Retrieves the operand side metadata of a kDomain instruction. + const DomainMetadata& operand_side_metadata() const { + return *operand_side_metadata_; + } + // Retrieves the user side metadata of a kDomain instruction. + const DomainMetadata& user_side_metadata() const { + return *user_side_metadata_; + } + + private: + std::vector<string> ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function<bool(const HloComputation*, const HloComputation*)>& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( + const Shape& shape, absl::Span<HloInstruction* const> new_operands, + HloCloneContext* context) const override; + + std::unique_ptr<DomainMetadata> operand_side_metadata_; + std::unique_ptr<DomainMetadata> user_side_metadata_; +}; } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_ diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index 3a1bc4e328..cfe906d9c5 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -26,6 +26,7 @@ limitations under the License. 
#include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/gtl/map_util.h" @@ -50,6 +51,13 @@ StatusOr<HloInstruction*> HloModule::LaunderConstInstructionFromModule( return const_cast<HloInstruction*>(hlo); } +Status HloModule::set_schedule(HloSchedule schedule) { + TF_RET_CHECK(schedule.module() == this); + TF_RETURN_IF_ERROR(schedule.Verify()); + schedule_ = std::move(schedule); + return Status::OK(); +} + HloComputation* HloModule::AddComputationInternal( std::unique_ptr<HloComputation> computation, bool is_entry, bool uniquify_names) { @@ -198,12 +206,23 @@ void HloModule::ReplaceComputations( string HloModule::ToString(const HloPrintOptions& options) const { std::ostringstream s; - s << "HloModule " << name() << "\n\n"; + s << "HloModule " << name(); + if (has_schedule()) { + TF_CHECK_OK(schedule().Verify()); + s << ", is_scheduled=true"; + } + s << "\n\n"; for (const HloComputation* computation : MakeComputationPostOrder()) { if (computation == entry_computation()) { s << "ENTRY "; } - s << computation->ToString(options) << "\n\n"; + if (has_schedule() && schedule().is_computation_scheduled(computation)) { + s << computation->ToString( + options, schedule().sequence(computation).instructions()) + << "\n\n"; + } else { + s << computation->ToString(options) << "\n\n"; + } } return s.str(); } @@ -221,6 +240,9 @@ HloModuleProto HloModule::ToProto() const { } proto.add_computations()->Swap(&computation_proto); } + if (has_schedule()) { + *proto.mutable_schedule() = schedule().ToProto().ValueOrDie(); + } return proto; } @@ -309,6 +331,13 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto( } } + if (proto.has_schedule()) { + TF_ASSIGN_OR_RETURN( + HloSchedule schedule, + HloSchedule::CreateFromProto(module.get(), 
proto.schedule())); + TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule))); + } + return std::move(module); } diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h index 3c3371426b..26fd1b2438 100644 --- a/tensorflow/compiler/xla/service/hlo_module.h +++ b/tensorflow/compiler/xla/service/hlo_module.h @@ -25,6 +25,7 @@ limitations under the License. #include <vector> #include "absl/strings/string_view.h" +#include "absl/types/optional.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/iterator_util.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" @@ -32,6 +33,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/gtl/iterator_range.h" @@ -235,6 +237,19 @@ class HloModule { StatusOr<HloInstruction*> LaunderConstInstructionFromModule( const HloInstruction* hlo); + // Sets the schedule of the module to the given schedule. + Status set_schedule(HloSchedule schedule); + + // Clears the schedule of the module. + void clear_schedule() { schedule_.reset(); } + + // Returns true if the module has a schedule set. + bool has_schedule() const { return schedule_.has_value(); } + + // Returns the schedue of the module. CHECK fails if no schedule is set. + const HloSchedule& schedule() const { return *schedule_; } + HloSchedule& schedule() { return *schedule_; } + private: HloComputation* AddComputationInternal( std::unique_ptr<HloComputation> computation, bool is_entry, @@ -262,6 +277,11 @@ class HloModule { static std::atomic<int> next_unique_module_id_; // A unique id to label modules with. int unique_id_; + + // The HloSchedule of the module. 
The schedule if it exists contains a + // sequential order of instructions for each non-fusion computation in the + // module. + absl::optional<HloSchedule> schedule_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module_config.h b/tensorflow/compiler/xla/service/hlo_module_config.h index 3f1e1cc73e..68c18836eb 100644 --- a/tensorflow/compiler/xla/service/hlo_module_config.h +++ b/tensorflow/compiler/xla/service/hlo_module_config.h @@ -106,9 +106,6 @@ class HloModuleConfig { absl::optional<ComputationLayout> entry_computation_layout_; - // Whether this is a 'host module'. - bool is_host_module_ = false; - // Module/graph-level seed handle. uint64 seed_ = 0; diff --git a/tensorflow/compiler/xla/service/hlo_module_test.cc b/tensorflow/compiler/xla/service/hlo_module_test.cc index 4bc1bacd7d..400bd4d947 100644 --- a/tensorflow/compiler/xla/service/hlo_module_test.cc +++ b/tensorflow/compiler/xla/service/hlo_module_test.cc @@ -19,9 +19,12 @@ limitations under the License. 
#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/test.h" @@ -30,6 +33,8 @@ namespace xla { namespace { +namespace op = ::xla::testing::opcode_matchers; + class HloModuleTest : public HloTestBase { protected: HloModuleTest() {} @@ -194,6 +199,60 @@ TEST_F(HloModuleTest, UniqueModuleId) { EXPECT_NE(module_a->unique_id(), module_b->unique_id()); } +TEST_F(HloModuleTest, ProtoSerializationWithoutSchedule) { + const string text = R"( +HloModule axpy_module + +ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { + %alpha = f32[] parameter(0) + %x = f32[2,4]{1,0} parameter(1) + %y = f32[2,4]{1,0} parameter(2) + %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={} + %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x) + ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, + ParseHloString(text)); + ASSERT_FALSE(module->has_schedule()); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<HloModule> module_copy, + HloModule::CreateFromProto(module->ToProto(), module->config())); + ASSERT_FALSE(module_copy->has_schedule()); +} + +TEST_F(HloModuleTest, ProtoSerializationWithSchedule) { + const string text = R"( +HloModule axpy_module, is_scheduled=true + +ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { + %alpha = f32[] parameter(0) + %x = f32[2,4]{1,0} parameter(1) + %y = f32[2,4]{1,0} parameter(2) + %broadcast = 
f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={} + %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x) + ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, + ParseHloString(text)); + ASSERT_TRUE(module->has_schedule()); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<HloModule> module_copy, + HloModule::CreateFromProto(module->ToProto(), module->config())); + ASSERT_TRUE(module_copy->has_schedule()); + TF_ASSERT_OK(module_copy->schedule().Verify()); + EXPECT_EQ(module_copy->schedule().sequences().size(), 1); + ASSERT_TRUE(module_copy->schedule().is_computation_scheduled( + module_copy->entry_computation())); + EXPECT_THAT( + module_copy->schedule() + .sequence(module_copy->entry_computation()) + .instructions(), + ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Parameter(), + op::Broadcast(), op::Multiply(), op::Add())); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc index 2105f7a349..f1dc08bafa 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering.cc @@ -293,23 +293,6 @@ bool HloOrdering::MayInterfere(const HloValue& a, const HloValue& b, !LiveRangeStrictlyBefore(b, a, dataflow); } -HloOrderingProto HloOrdering::ToProto() const { - HloOrderingProto proto; - for (const auto& computation : module_->computations()) { - const std::vector<const HloInstruction*>* sequence = - SequentialOrder(*computation); - if (sequence != nullptr) { - HloOrderingProto::SequentialComputation* proto_computation = - proto.add_sequential_computations(); - proto_computation->set_computation_name(computation->name()); - for (const HloInstruction* instruction : *sequence) { - *proto_computation->add_instruction_names() = instruction->name(); - } - } - } - return proto; -} - 
PredecessorHloOrdering::PredecessorHloOrdering(const HloModule* module) : HloOrdering(module) {} diff --git a/tensorflow/compiler/xla/service/hlo_ordering.h b/tensorflow/compiler/xla/service/hlo_ordering.h index b21071c4b2..b0361c3f02 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.h +++ b/tensorflow/compiler/xla/service/hlo_ordering.h @@ -72,10 +72,6 @@ class HloOrdering { virtual string ToString() const = 0; - // Returns the serialized representation of this ordering. - // Only sequential computation orders are represented. - HloOrderingProto ToProto() const; - protected: // Returns true if instruction 'a' executes before instruction 'b'. // Precondition: 'a' and 'b' are in the same computation. diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index 0f26ed4235..c54360b063 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" @@ -44,6 +45,20 @@ using absl::StrJoin; const double kF16max = 65504; +// Creates and returns a schedule created using the order of the instructions in +// the HloComputation::instructions() vectors in the module. 
+HloSchedule ScheduleFromInstructionOrder(const HloModule* module) { + HloSchedule schedule(module); + for (const HloComputation* computation : module->computations()) { + if (!computation->IsFusionComputation()) { + for (const HloInstruction* instruction : computation->instructions()) { + schedule.GetOrCreateSequence(computation).push_back(instruction); + } + } + } + return schedule; +} + // Parser for the HloModule::ToString() format text. class HloParser { public: @@ -366,9 +381,25 @@ bool HloParser::ParseHloModule() { return false; } + absl::optional<bool> is_scheduled; + std::unordered_map<string, AttrConfig> attrs; + attrs["is_scheduled"] = {/*required=*/false, AttrTy::kBool, &is_scheduled}; + if (!ParseAttributes(attrs)) { + return false; + } + module_ = absl::make_unique<HloModule>(name, config_); - return ParseComputations(); + if (!ParseComputations()) { + return false; + } + + if (is_scheduled.has_value() && *is_scheduled) { + TF_CHECK_OK( + module_->set_schedule(ScheduleFromInstructionOrder(module_.get()))); + } + + return true; } // computations ::= (computation)+ @@ -1248,11 +1279,14 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, optional<string> custom_call_target; optional<Window> window; optional<ConvolutionDimensionNumbers> dnums; + optional<int64> feature_group_count; attrs["custom_call_target"] = {/*required=*/true, AttrTy::kString, &custom_call_target}; attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window}; attrs["dim_labels"] = {/*required=*/false, AttrTy::kConvolutionDimensionNumbers, &dnums}; + attrs["feature_group_count"] = {/*required=*/false, AttrTy::kInt64, + &feature_group_count}; if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } @@ -1264,6 +1298,9 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, if (dnums.has_value()) { instruction->set_convolution_dimension_numbers(*dnums); } + if (feature_group_count.has_value()) { + 
instruction->set_feature_group_count(*feature_group_count); + } break; } case HloOpcode::kDot: { diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index 0dfc0a4d1c..cca50fab54 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -1123,18 +1123,31 @@ ENTRY Iota { )" }, -// custom-call with window and dim_labels +// custom-call with window, dim_labels and feature_group_count { -"CustomCallWithWindowAndDimLabels", -R"(HloModule CustomCallWithWindowAndDimLabels +"CustomCallWithWindowAndDimLabelsAndFeatureGroupCount", +R"(HloModule CustomCallWithWindowAndDimLabelsAndFeatureGroupCount ENTRY Computation { - ROOT r = f32[100]{0} custom-call(), window={size=2x2}, dim_labels=b01f_01io->b01f, custom_call_target="target" + ROOT r = f32[100]{0} custom-call(), window={size=2x2}, dim_labels=b01f_01io->b01f, feature_group_count=2, custom_call_target="target" } )" + }, +// is_scheduled=true attribute +{ +"ScheduledModule", +R"(HloModule scheduled_module, is_scheduled=true + +ENTRY Sort { + keys = f32[1024]{0} parameter(0) + values = s32[1024]{0} parameter(1) + ROOT sorted = (f32[1024]{0}, s32[1024]{0}) sort(keys, values), dimensions={0} } - }); + +)" +} +}); // clang-format on } @@ -1790,5 +1803,94 @@ TEST(HloParserSingleOpTest, ConvolutionTrivialFeatureGroupCount) { EXPECT_EQ(convolution->feature_group_count(), 1); } +TEST_F(HloParserTest, IsScheduledIsFalse) { + const string text = R"( +HloModule axpy_module, is_scheduled=false + +ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { + %alpha = f32[] parameter(0) + %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={} + %x = f32[2,4]{1,0} parameter(1) + %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x) + %y = f32[2,4]{1,0} parameter(2) + ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y) +} +)"; + 
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, + ParseHloString(text)); + ASSERT_FALSE(module->has_schedule()); +} + +TEST_F(HloParserTest, IsScheduledNotPresent) { + const string text = R"( +HloModule axpy_module + +ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { + %alpha = f32[] parameter(0) + %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={} + %x = f32[2,4]{1,0} parameter(1) + %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x) + %y = f32[2,4]{1,0} parameter(2) + ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, + ParseHloString(text)); + ASSERT_FALSE(module->has_schedule()); +} + +TEST_F(HloParserTest, IsScheduledIsTrue) { + const string text = R"( +HloModule axpy_module, is_scheduled=true + +ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { + %alpha = f32[] parameter(0) + %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={} + %x = f32[2,4]{1,0} parameter(1) + %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x) + %y = f32[2,4]{1,0} parameter(2) + ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, + ParseHloString(text)); + ASSERT_TRUE(module->has_schedule()); + TF_ASSERT_OK(module->schedule().Verify()); + EXPECT_EQ(module->schedule().sequences().size(), 1); + ASSERT_TRUE( + module->schedule().is_computation_scheduled(module->entry_computation())); + EXPECT_THAT( + module->schedule().sequence(module->entry_computation()).instructions(), + ::testing::ElementsAre(op::Parameter(), op::Broadcast(), op::Parameter(), + op::Multiply(), op::Parameter(), op::Add())); +} + +TEST_F(HloParserTest, IsScheduledIsTrueDifferentOrder) { + // As above but in with a different schedule order. 
+ const string text = R"( +HloModule axpy_module, is_scheduled=true + +ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { + %alpha = f32[] parameter(0) + %x = f32[2,4]{1,0} parameter(1) + %y = f32[2,4]{1,0} parameter(2) + %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={} + %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x) + ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, + ParseHloString(text)); + ASSERT_TRUE(module->has_schedule()); + TF_ASSERT_OK(module->schedule().Verify()); + EXPECT_EQ(module->schedule().sequences().size(), 1); + ASSERT_TRUE( + module->schedule().is_computation_scheduled(module->entry_computation())); + EXPECT_THAT( + module->schedule().sequence(module->entry_computation()).instructions(), + ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Parameter(), + op::Broadcast(), op::Multiply(), op::Add())); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_proto_util.cc b/tensorflow/compiler/xla/service/hlo_proto_util.cc index 3460679558..b9c0b0c4ee 100644 --- a/tensorflow/compiler/xla/service/hlo_proto_util.cc +++ b/tensorflow/compiler/xla/service/hlo_proto_util.cc @@ -23,11 +23,8 @@ namespace xla { HloProto MakeHloProto(const HloModule& module, const BufferAssignment& assignment) { - HloOrderingProto proto_ordering = - assignment.liveness().hlo_ordering().ToProto(); BufferAssignmentProto proto_assignment = assignment.ToProto(); HloProto proto = MakeHloProto(module); - proto.mutable_hlo_ordering()->Swap(&proto_ordering); proto.mutable_buffer_assignment()->Swap(&proto_assignment); return proto; } diff --git a/tensorflow/compiler/xla/service/hlo_schedule.cc b/tensorflow/compiler/xla/service/hlo_schedule.cc index a65b33bf40..3fc5dbeb02 100644 --- a/tensorflow/compiler/xla/service/hlo_schedule.cc +++ b/tensorflow/compiler/xla/service/hlo_schedule.cc 
@@ -21,12 +21,64 @@ limitations under the License. #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/gtl/map_util.h" namespace xla { +/* static */ StatusOr<HloSchedule> HloSchedule::CreateFromProto( + const HloModule* module, const HloScheduleProto& proto) { + tensorflow::gtl::FlatMap<int64, const HloComputation*> id_to_computation; + for (const HloComputation* computation : module->computations()) { + id_to_computation[computation->unique_id()] = computation; + } + + HloSchedule schedule(module); + for (const auto& id_sequence : proto.sequences()) { + int64 computation_id = id_sequence.first; + + auto comp_it = id_to_computation.find(computation_id); + TF_RET_CHECK(comp_it != id_to_computation.end()) + << "No computation exists in HLO module with id " << computation_id; + const HloComputation* computation = comp_it->second; + + tensorflow::gtl::FlatMap<int64, const HloInstruction*> id_to_instruction; + for (const HloInstruction* instruction : computation->instructions()) { + id_to_instruction[instruction->unique_id()] = instruction; + } + + HloInstructionSequence& sequence = + schedule.GetOrCreateSequence(computation); + for (const int64 instruction_id : id_sequence.second.instruction_ids()) { + auto instr_it = id_to_instruction.find(instruction_id); + TF_RET_CHECK(instr_it != id_to_instruction.end()) + << "No instruction exists in HLO computation " << computation->name() + << " with id " << instruction_id; + sequence.push_back(instr_it->second); + } + } + TF_RETURN_IF_ERROR(schedule.Verify()); + return std::move(schedule); +} + +StatusOr<HloScheduleProto> HloSchedule::ToProto() const { + TF_RETURN_IF_ERROR(Verify()); + HloScheduleProto proto; + for (const auto& id_sequence : sequences_) { + int64 computation_id = 
id_sequence.first; + const HloInstructionSequence& sequence = id_sequence.second; + HloScheduleProto::InstructionSequence& proto_sequence = + (*proto.mutable_sequences())[computation_id]; + proto_sequence.mutable_instruction_ids()->Reserve(sequence.size()); + for (const int64 id : sequence.ids()) { + proto_sequence.add_instruction_ids(id); + } + } + return std::move(proto); +} + void HloSchedule::set_sequence( const HloComputation* computation, absl::Span<const HloInstruction* const> sequence) { diff --git a/tensorflow/compiler/xla/service/hlo_schedule.h b/tensorflow/compiler/xla/service/hlo_schedule.h index 21c6988638..270fe6039f 100644 --- a/tensorflow/compiler/xla/service/hlo_schedule.h +++ b/tensorflow/compiler/xla/service/hlo_schedule.h @@ -21,18 +21,20 @@ limitations under the License. #include "absl/types/span.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/status.h" namespace xla { +class HloModule; + // Class representing a sequence of HLO instructions such as the sequential // execution order of an HLO computation. class HloInstructionSequence { public: HloInstructionSequence() = default; - HloInstructionSequence(absl::Span<const HloInstruction* const> instructions) { + explicit HloInstructionSequence( + absl::Span<const HloInstruction* const> instructions) { for (const HloInstruction* instruction : instructions) { push_back(instruction); } @@ -77,7 +79,12 @@ class HloInstructionSequence { // non-fusion computation in the HLO module. class HloSchedule { public: - HloSchedule(const HloModule* module) : module_(module) {} + explicit HloSchedule(const HloModule* module) : module_(module) {} + + // (De)Serialize an HloSchedule to/from a HloScheduleProto. 
+ static StatusOr<HloSchedule> CreateFromProto(const HloModule* module, + const HloScheduleProto& proto); + StatusOr<HloScheduleProto> ToProto() const; // Returns a reference to the sequence for the given computation. const HloInstructionSequence& sequence( diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index 6e17711f57..082bf8bffe 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -855,8 +855,7 @@ void LayoutAssignment::SetupCopiedInstruction(const HloInstruction& instruction, ? instruction.sharding().GetSubSharding(instruction.shape(), index) : instruction.sharding(); // We propagate the sharding to the copied instruction only if it is a - // special sharding, like tiled ones, or special devices like the - // HostCompute module. + // special sharding, like tiled ones. // Otherwise it is preferable to leave the new instruction without device, // and let the automatic device placer to choose the best location. 
auto device = sharding.UniqueDevice(); diff --git a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc index 7d49b8d6c2..a60643bc75 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc @@ -75,6 +75,16 @@ void EmitTuple(const IrArray& tuple, absl::Span<llvm::Value* const> operands, } } +void EmitTuple(const IrArray& tuple, absl::Span<const IrArray> buffers, + llvm::IRBuilder<>* b, llvm::Module* module) { + std::vector<llvm::Value*> buffer_ptrs; + buffer_ptrs.reserve(buffers.size()); + absl::c_transform( + buffers, std::back_inserter(buffer_ptrs), + [](const llvm_ir::IrArray& buffer) { return buffer.GetBasePointer(); }); + llvm_ir::EmitTuple(tuple, buffer_ptrs, b, module); +} + llvm::Value* EmitGetTupleElement(const Shape& target_shape, int64 index, int alignment, llvm::Value* operand, llvm::IRBuilder<>* b, llvm::Module* module) { diff --git a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h index 887fb61371..94340b91d8 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h +++ b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h @@ -68,6 +68,11 @@ void EmitTupleSelect(const IrArray& select, const IrArray& pred, void EmitTuple(const IrArray& tuple, absl::Span<llvm::Value* const> operands, llvm::IRBuilder<>* b, llvm::Module* module); +// Similar to EmitTuple above, except that the output buffers are provided in +// the form of IrArray. +void EmitTuple(const IrArray& tuple, absl::Span<const IrArray> buffers, + llvm::IRBuilder<>* b, llvm::Module* module); + // A tuple is an array of pointers, one for each operand. Each pointer points to // the output buffer of its corresponding operand. A GetTupleElement instruction // forwards the pointer to underlying tuple element buffer at the given index. 
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 36b8fb2644..d0bda45cf8 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -75,7 +75,6 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_verifier", "//tensorflow/compiler/xla/service:transfer_manager", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_headers_lib", "@com_google_absl//absl/memory", "@com_google_absl//absl/types:span", ], diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc index 8c62adea23..57f7fed61f 100644 --- a/tensorflow/compiler/xla/tests/reduce_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_test.cc @@ -866,10 +866,7 @@ INSTANTIATE_TEST_CASE_P( BoundsLayout{{2, 300, 784}, {2, 1, 0}, {1}}, BoundsLayout{{2, 300, 784}, {2, 1, 0}, {0}})); -// TODO(b/64093391) Disabled on GPU due to an assertion failure when running -// IrEmitterUnnested::EmitInitializer() for the Reduce operator. Failed on -// 2017-07-26. 
-XLA_TEST_F(ReduceTest, DISABLED_ON_GPU(OperationOnConstantAsInitValue)) { +XLA_TEST_F(ReduceTest, OperationOnConstantAsInitValue) { XlaBuilder builder(TestName()); XlaComputation max_f32 = CreateScalarMaxComputation(F32, &builder); diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc index c20a7c8fe4..3ae31191a0 100644 --- a/tensorflow/compiler/xla/tests/test_utils.cc +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -417,4 +417,18 @@ Status VerifyHloModule(HloModule* const module, bool layout_sensitive, .status(); } +std::unique_ptr<HloDotInstruction> CreateCanonicalDot(const Shape& shape, + HloInstruction* lhs, + HloInstruction* rhs) { + CHECK_EQ(ShapeUtil::Rank(lhs->shape()), 2); + CHECK_EQ(ShapeUtil::Rank(rhs->shape()), 2); + PrecisionConfig precision_config; + precision_config.mutable_operand_precision()->Resize( + 2, PrecisionConfig::DEFAULT); + DotDimensionNumbers dot_dimension_numbers; + dot_dimension_numbers.add_lhs_contracting_dimensions(1); + dot_dimension_numbers.add_rhs_contracting_dimensions(0); + return absl::make_unique<HloDotInstruction>( + shape, lhs, rhs, dot_dimension_numbers, precision_config); +} } // namespace xla diff --git a/tensorflow/compiler/xla/tests/test_utils.h b/tensorflow/compiler/xla/tests/test_utils.h index 7790737c09..a260271b1b 100644 --- a/tensorflow/compiler/xla/tests/test_utils.h +++ b/tensorflow/compiler/xla/tests/test_utils.h @@ -24,10 +24,10 @@ limitations under the License. 
#include "absl/types/span.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/types.h" -#include "tensorflow/stream_executor/platform.h" namespace xla { @@ -98,6 +98,12 @@ StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments( Status VerifyHloModule(HloModule* const module, bool layout_sensitive, bool allow_mixed_precision); +// Creates a dot op with operands 'lhs' and 'rhs' that contracts dimension 1 of +// the LHS with dimension 0 of the RHS with no batch dimensions. +// Both LHS and the RHS must be of rank 2. +std::unique_ptr<HloDotInstruction> CreateCanonicalDot(const Shape& shape, + HloInstruction* lhs, + HloInstruction* rhs); } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_TESTS_TEST_UTILS_H_ diff --git a/tensorflow/compiler/xla/tools/hex_floats_to_packed_literal.cc b/tensorflow/compiler/xla/tools/hex_floats_to_packed_literal.cc index 23ce1d235b..0c3ec5934e 100644 --- a/tensorflow/compiler/xla/tools/hex_floats_to_packed_literal.cc +++ b/tensorflow/compiler/xla/tools/hex_floats_to_packed_literal.cc @@ -67,8 +67,8 @@ int main(int argc, char** argv) { floats.push_back(value); } - absl::string_view content(absl::bit_cast<const char*>(floats.data()), - floats.size() * sizeof(float)); + tensorflow::StringPiece content(absl::bit_cast<const char*>(floats.data()), + floats.size() * sizeof(float)); TF_CHECK_OK(tensorflow::WriteStringToFile(tensorflow::Env::Default(), output_file, content)); return 0; |