about summary refs log tree commit diff homepage
path: root/tensorflow/compiler
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r-- tensorflow/compiler/aot/embedded_protocol_buffers.h 1
-rw-r--r-- tensorflow/compiler/aot/tfcompile_main.cc 6
-rw-r--r-- tensorflow/compiler/jit/legacy_flags/BUILD 12
-rw-r--r-- tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.cc 68
-rw-r--r-- tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.h 52
-rw-r--r-- tensorflow/compiler/jit/mark_for_compilation_pass_test.cc 2
-rw-r--r-- tensorflow/compiler/jit/xla_cluster_util.h 1
-rw-r--r-- tensorflow/compiler/jit/xla_device_context.cc 6
-rw-r--r-- tensorflow/compiler/jit/xla_device_context.h 8
-rw-r--r-- tensorflow/compiler/tf2xla/BUILD 13
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/BUILD 18
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/if_op.cc 30
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/if_op.h 2
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/while_op.cc 31
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/while_op.h 2
-rw-r--r-- tensorflow/compiler/tf2xla/resource_operation_table.cc 18
-rw-r--r-- tensorflow/compiler/tf2xla/side_effect_util.cc 67
-rw-r--r-- tensorflow/compiler/tf2xla/side_effect_util.h 47
-rw-r--r-- tensorflow/compiler/tf2xla/tf2xla_util.h 1
-rw-r--r-- tensorflow/compiler/tf2xla/xla_compiler.cc 113
-rw-r--r-- tensorflow/compiler/tf2xla/xla_compiler.h 23
-rw-r--r-- tensorflow/compiler/tf2xla/xla_compiler_test.cc 68
-rw-r--r-- tensorflow/compiler/tf2xla/xla_context.cc 11
-rw-r--r-- tensorflow/compiler/tf2xla/xla_context.h 3
-rw-r--r-- tensorflow/compiler/tf2xla/xla_op_kernel.cc 11
-rw-r--r-- tensorflow/compiler/tf2xla/xla_op_registry.h 1
-rw-r--r-- tensorflow/compiler/xla/packed_literal_reader.cc 5
-rw-r--r-- tensorflow/compiler/xla/python/local_computation_builder.i 6
-rw-r--r-- tensorflow/compiler/xla/service/BUILD 32
-rw-r--r-- tensorflow/compiler/xla/service/algebraic_simplifier_test.cc 20
-rw-r--r-- tensorflow/compiler/xla/service/bfloat16_propagation_test.cc 241
-rw-r--r-- tensorflow/compiler/xla/service/cpu/BUILD 1
-rw-r--r-- tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc 5
-rw-r--r-- tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc 18
-rw-r--r-- tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc 4
-rw-r--r-- tensorflow/compiler/xla/service/cpu/tests/BUILD 1
-rw-r--r-- tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc 7
-rw-r--r-- tensorflow/compiler/xla/service/gpu/BUILD 4
-rw-r--r-- tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc 55
-rw-r--r-- tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc 9
-rw-r--r-- tensorflow/compiler/xla/service/gpu/ir_emitter.cc 15
-rw-r--r-- tensorflow/compiler/xla/service/gpu/ir_emitter.h 6
-rw-r--r-- tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc 16
-rw-r--r-- tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc 111
-rw-r--r-- tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h 16
-rw-r--r-- tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc 51
-rw-r--r-- tensorflow/compiler/xla/service/hlo.proto 26
-rw-r--r-- tensorflow/compiler/xla/service/hlo_computation.cc 12
-rw-r--r-- tensorflow/compiler/xla/service/hlo_computation.h 5
-rw-r--r-- tensorflow/compiler/xla/service/hlo_cost_analysis.cc 13
-rw-r--r-- tensorflow/compiler/xla/service/hlo_cost_analysis.h 1
-rw-r--r-- tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc 55
-rw-r--r-- tensorflow/compiler/xla/service/hlo_cse_test.cc 91
-rw-r--r-- tensorflow/compiler/xla/service/hlo_evaluator_test.cc 5
-rw-r--r-- tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h 47
-rw-r--r-- tensorflow/compiler/xla/service/hlo_instruction.cc 188
-rw-r--r-- tensorflow/compiler/xla/service/hlo_instruction.h 61
-rw-r--r-- tensorflow/compiler/xla/service/hlo_instructions.cc 147
-rw-r--r-- tensorflow/compiler/xla/service/hlo_instructions.h 92
-rw-r--r-- tensorflow/compiler/xla/service/hlo_module.cc 33
-rw-r--r-- tensorflow/compiler/xla/service/hlo_module.h 20
-rw-r--r-- tensorflow/compiler/xla/service/hlo_module_config.h 3
-rw-r--r-- tensorflow/compiler/xla/service/hlo_module_test.cc 59
-rw-r--r-- tensorflow/compiler/xla/service/hlo_ordering.cc 17
-rw-r--r-- tensorflow/compiler/xla/service/hlo_ordering.h 4
-rw-r--r-- tensorflow/compiler/xla/service/hlo_parser.cc 39
-rw-r--r-- tensorflow/compiler/xla/service/hlo_parser_test.cc 112
-rw-r--r-- tensorflow/compiler/xla/service/hlo_proto_util.cc 3
-rw-r--r-- tensorflow/compiler/xla/service/hlo_schedule.cc 52
-rw-r--r-- tensorflow/compiler/xla/service/hlo_schedule.h 13
-rw-r--r-- tensorflow/compiler/xla/service/layout_assignment.cc 3
-rw-r--r-- tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc 10
-rw-r--r-- tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h 5
-rw-r--r-- tensorflow/compiler/xla/tests/BUILD 1
-rw-r--r-- tensorflow/compiler/xla/tests/reduce_test.cc 5
-rw-r--r-- tensorflow/compiler/xla/tests/test_utils.cc 14
-rw-r--r-- tensorflow/compiler/xla/tests/test_utils.h 8
-rw-r--r-- tensorflow/compiler/xla/tools/hex_floats_to_packed_literal.cc 4
78 files changed, 1580 insertions, 806 deletions
diff --git a/tensorflow/compiler/aot/embedded_protocol_buffers.h b/tensorflow/compiler/aot/embedded_protocol_buffers.h
index cf5c04ac4b..bd270045e3 100644
--- a/tensorflow/compiler/aot/embedded_protocol_buffers.h
+++ b/tensorflow/compiler/aot/embedded_protocol_buffers.h
@@ -20,6 +20,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_AOT_EMBEDDED_PROTOCOL_BUFFERS_H_
#define TENSORFLOW_COMPILER_AOT_EMBEDDED_PROTOCOL_BUFFERS_H_
+#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/platform/protobuf.h"
diff --git a/tensorflow/compiler/aot/tfcompile_main.cc b/tensorflow/compiler/aot/tfcompile_main.cc
index b95b063348..1c9d30d7b0 100644
--- a/tensorflow/compiler/aot/tfcompile_main.cc
+++ b/tensorflow/compiler/aot/tfcompile_main.cc
@@ -35,6 +35,7 @@ limitations under the License.
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
@@ -92,9 +93,8 @@ Status Main(const MainFlags& flags) {
// Write output files.
Env* env = Env::Default();
const std::vector<char>& obj = compile_result.aot->object_file_data();
- TF_RETURN_IF_ERROR(
- WriteStringToFile(env, flags.out_function_object,
- absl::string_view(obj.data(), obj.size())));
+ TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_function_object,
+ StringPiece(obj.data(), obj.size())));
CodegenOpts codegen_opts;
codegen_opts.gen_name_to_index = flags.gen_name_to_index;
codegen_opts.gen_program_shape = flags.gen_program_shape;
diff --git a/tensorflow/compiler/jit/legacy_flags/BUILD b/tensorflow/compiler/jit/legacy_flags/BUILD
index 5b6692f523..07c5b23188 100644
--- a/tensorflow/compiler/jit/legacy_flags/BUILD
+++ b/tensorflow/compiler/jit/legacy_flags/BUILD
@@ -29,18 +29,6 @@ cc_library(
)
cc_library(
- name = "parallel_check_op_flags",
- srcs = ["parallel_check_op_flags.cc"],
- hdrs = ["parallel_check_op_flags.h"],
- deps =
- [
- "//tensorflow/compiler/xla/legacy_flags:parse_flags_from_env",
- "//tensorflow/core:framework_internal",
- "//tensorflow/core:lib",
- ],
-)
-
-cc_library(
name = "xla_device_flags",
srcs = ["xla_device_flags.cc"],
hdrs = ["xla_device_flags.h"],
diff --git a/tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.cc b/tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.cc
deleted file mode 100644
index a61694b494..0000000000
--- a/tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.cc
+++ /dev/null
@@ -1,68 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-// Legacy flags for the XLA bridge's parallel_check_op module.
-
-#include <mutex>
-#include <vector>
-
-#include "tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.h"
-#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h"
-#include "tensorflow/core/platform/types.h"
-#include "tensorflow/core/util/command_line_flags.h"
-
-namespace tensorflow {
-namespace legacy_flags {
-
-// Pointers to the parsed value of the flags and flag descriptors, initialized
-// via flags_init.
-static ParallelCheckOpFlags* flags;
-static std::vector<Flag>* flag_list;
-static std::once_flag flags_init;
-
-// Allocate *flags. Called via call_once(&flags_init,...).
-static void AllocateFlags() {
- flags = new ParallelCheckOpFlags;
- flags->parallel_check_failfast = true;
- flags->parallel_check_atol = "1e-5";
- flags->parallel_check_rtol = "1e-5";
- flag_list = new std::vector<Flag>({
- Flag("parallel_check_failfast", &flags->parallel_check_failfast,
- "Fail immediately on first parallel-check comparison error."),
- Flag("parallel_check_atol", &flags->parallel_check_atol,
- "Absolute error tolerance for parallel-check comparison."),
- Flag("parallel_check_rtol", &flags->parallel_check_rtol,
- "Relative error tolerance for parallel-check comparison."),
- });
- xla::legacy_flags::ParseFlagsFromEnv(*flag_list);
-}
-
-// Append to *append_to flag definitions associated with the XLA bridge's
-// parallel_check_op module.
-void AppendParallelCheckOpFlags(std::vector<Flag>* append_to) {
- std::call_once(flags_init, &AllocateFlags);
- append_to->insert(append_to->end(), flag_list->begin(), flag_list->end());
-}
-
-// Return a pointer to the ParallelCheckOpFlags struct;
-// repeated calls return the same pointer.
-// This should be called only after Flags::Parse() has returned.
-ParallelCheckOpFlags* GetParallelCheckOpFlags() {
- std::call_once(flags_init, &AllocateFlags);
- return flags;
-}
-
-} // namespace legacy_flags
-} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.h b/tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.h
deleted file mode 100644
index 156a2a2a71..0000000000
--- a/tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.h
+++ /dev/null
@@ -1,52 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_PARALLEL_CHECK_OP_FLAGS_H_
-#define TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_PARALLEL_CHECK_OP_FLAGS_H_
-
-// Legacy flags for the XLA bridge's parallel_check_op module.
-
-#include <vector>
-
-#include "tensorflow/core/platform/types.h"
-#include "tensorflow/core/util/command_line_flags.h"
-
-namespace tensorflow {
-namespace legacy_flags {
-
-// Append to *flag_list flag definitions associated with the XLA bridge's
-// parallel_check_op module.
-void AppendParallelCheckOpFlags(std::vector<tensorflow::Flag>* flag_list);
-
-// The values of flags associated with the XLA bridge's
-// parallel_check_op module.
-typedef struct {
- bool parallel_check_failfast; // Fail immediately on first parallel-check
- // comparison error.
- string parallel_check_atol; // Absolute error tolerance for parallel-check
- // comparison.
- string parallel_check_rtol; // Relative error tolerance for parallel-check
- // comparison.
-} ParallelCheckOpFlags;
-
-// Return a pointer to the ParallelCheckOpFlags struct;
-// repeated calls return the same pointer.
-// This should be called only after Flags::Parse() has returned.
-ParallelCheckOpFlags* GetParallelCheckOpFlags();
-
-} // namespace legacy_flags
-} // namespace tensorflow
-
-#endif // TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_PARALLEL_CHECK_OP_FLAGS_H_
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
index 9473ac0a4c..807ab51fd3 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
@@ -633,7 +633,7 @@ TEST(XlaCompilationTest, IllegalCycle_UsefulErrorMessage) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
Scope root = Scope::NewRootScope().ExitOnError();
{
- auto BuildNoopNode = [](absl::string_view name, Graph* graph) {
+ auto BuildNoopNode = [](StringPiece name, Graph* graph) {
NodeDefBuilder builder(name, "NoOp");
NodeDef def;
TF_CHECK_OK(builder.Finalize(&def));
diff --git a/tensorflow/compiler/jit/xla_cluster_util.h b/tensorflow/compiler/jit/xla_cluster_util.h
index 17ae510a0e..debd9038c7 100644
--- a/tensorflow/compiler/jit/xla_cluster_util.h
+++ b/tensorflow/compiler/jit/xla_cluster_util.h
@@ -18,6 +18,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_
#define TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_
+#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include "tensorflow/core/graph/algorithm.h"
diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc
index af83c792e5..6d4160a968 100644
--- a/tensorflow/compiler/jit/xla_device_context.cc
+++ b/tensorflow/compiler/jit/xla_device_context.cc
@@ -339,11 +339,11 @@ void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
}
void XlaDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor,
- absl::string_view tensor_name,
+ StringPiece tensor_name,
Device* device, Tensor* cpu_tensor,
StatusCallback done) {
- manager_.CopyDeviceTensorToCPU(device_tensor, tensor_name, device, cpu_tensor,
- done);
+ manager_.CopyDeviceTensorToCPU(device_tensor, absl::string_view(tensor_name),
+ device, cpu_tensor, done);
}
void XlaDeviceContext::CopyDeviceTensorToDevice(const Tensor& src_tensor,
diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h
index df82421294..1effd6628f 100644
--- a/tensorflow/compiler/jit/xla_device_context.h
+++ b/tensorflow/compiler/jit/xla_device_context.h
@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
namespace tensorflow {
@@ -110,9 +111,12 @@ class XlaDeviceContext : public DeviceContext {
void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
Tensor* device_tensor,
StatusCallback done) const override;
+ // TODO(rlahaye): Replace StringPiece with absl::string_view when the
+ // StringPiece->absl::string_view change is rolled forward.
void CopyDeviceTensorToCPU(const Tensor* device_tensor,
- absl::string_view tensor_name, Device* device,
- Tensor* cpu_tensor, StatusCallback done) override;
+ StringPiece tensor_name, // non-ABSL OK
+ Device* device, Tensor* cpu_tensor,
+ StatusCallback done) override;
void CopyDeviceTensorToDevice(const Tensor& src_tensor, Tensor* dst_tensor,
const StatusCallback& done);
diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD
index 22be7f048f..3821dced63 100644
--- a/tensorflow/compiler/tf2xla/BUILD
+++ b/tensorflow/compiler/tf2xla/BUILD
@@ -191,6 +191,7 @@ cc_library(
":functionalize_control_flow",
":host_compute_metadata_proto",
":sharding_util",
+ ":side_effect_util",
":tf2xla_util",
"//tensorflow/compiler/tf2xla/lib:util",
"//tensorflow/compiler/xla:literal",
@@ -214,6 +215,7 @@ cc_library(
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:stream_executor_no_cuda",
"@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
alwayslink = 1,
@@ -359,6 +361,7 @@ tf_cc_test(
name = "xla_compiler_test",
srcs = ["xla_compiler_test.cc"],
deps = [
+ ":side_effect_util",
":xla_compiler",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:function_ops",
@@ -370,6 +373,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/service:cpu_plugin",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/core:core_cpu_internal",
@@ -631,3 +635,12 @@ tf_cc_test(
"@com_google_absl//absl/strings",
],
)
+
+cc_library(
+ name = "side_effect_util",
+ srcs = ["side_effect_util.cc"],
+ hdrs = ["side_effect_util.h"],
+ deps = [
+ "//tensorflow/core:core_cpu",
+ ],
+)
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD
index 4c776fb178..46794f7b50 100644
--- a/tensorflow/compiler/tf2xla/kernels/BUILD
+++ b/tensorflow/compiler/tf2xla/kernels/BUILD
@@ -115,9 +115,6 @@ tf_kernel_library(
deps = [
":if_op",
":while_op",
- "@com_google_absl//absl/algorithm:container",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/types:span",
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/lib:batch_dot",
@@ -168,14 +165,11 @@ tf_kernel_library(
"//tensorflow/core/kernels:sparse_to_dense_op",
"//tensorflow/core/kernels:stack_ops",
"//tensorflow/core/kernels:training_ops",
- ] + if_mkl(
- [
- "//tensorflow/core/kernels:mkl_transpose_op",
- ],
- [
- "//tensorflow/core/kernels:transpose_op",
- ],
- ),
+ "//tensorflow/core/kernels:transpose_op",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
+ ],
)
tf_kernel_library(
@@ -184,6 +178,7 @@ tf_kernel_library(
hdrs = ["while_op.h"],
deps = [
"//tensorflow/compiler/tf2xla:common",
+ "//tensorflow/compiler/tf2xla:side_effect_util",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/ops:xla_ops",
"//tensorflow/compiler/xla:literal",
@@ -201,6 +196,7 @@ tf_kernel_library(
hdrs = ["if_op.h"],
deps = [
"//tensorflow/compiler/tf2xla:common",
+ "//tensorflow/compiler/tf2xla:side_effect_util",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/ops:xla_ops",
"//tensorflow/compiler/xla:literal",
diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.cc b/tensorflow/compiler/tf2xla/kernels/if_op.cc
index 6e1dbf5472..56da50f140 100644
--- a/tensorflow/compiler/tf2xla/kernels/if_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/if_op.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/kernels/if_op.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/side_effect_util.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@@ -33,6 +34,11 @@ XlaIfOp::XlaIfOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("Tcond", &cond_type_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("Tin", &input_types_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("Tout", &output_types_));
+ if (!ctx->GetAttr(kXlaTokenInputNodesAttrName, &token_input_nodes_).ok()) {
+ has_token_input_output_ = false;
+ } else {
+ has_token_input_output_ = !token_input_nodes_.empty();
+ }
}
// TODO(b/35949885): There is duplication here with the handling of the
@@ -90,6 +96,7 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
options.resolve_compile_time_constants = false;
options.return_updated_values_for_all_resources = true;
options.is_entry_computation = false;
+ options.add_token_input_output = has_token_input_output_;
XlaCompiler* compiler = ctx->compiler();
XlaCompiler::CompilationResult then_result;
@@ -191,7 +198,16 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
std::vector<xla::XlaOp> inputs(num_inputs);
for (int i = 0; i < num_inputs; ++i) {
int input_num = then_result.input_mapping[i] + 1;
- if (ctx->input_type(input_num) == DT_RESOURCE) {
+ if (has_token_input_output_ && i == num_inputs - 1) {
+ // Set token input for this "if" op.
+ std::vector<xla::XlaOp> token_inputs;
+ for (const string& node_name : token_input_nodes_) {
+ auto token_or = compiler->GetNodeToken(node_name);
+ OP_REQUIRES_OK(ctx, token_or.status());
+ token_inputs.push_back(token_or.ValueOrDie());
+ }
+ inputs[i] = xla::AfterAll(b, token_inputs);
+ } else if (ctx->input_type(input_num) == DT_RESOURCE) {
XlaResource* resource;
OP_REQUIRES_OK(ctx, ctx->GetResourceInput(input_num, &resource));
OP_REQUIRES_OK(ctx, resource->Pack(&inputs[i], b));
@@ -219,6 +235,18 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
}
ctx->SetOutput(i, output_handle);
}
+ if (has_token_input_output_) {
+ // Set token output for this "if" op.
+ xla::XlaOp token_output =
+ xla::GetTupleElement(outputs, output_types_.size());
+ auto shape_or = b->GetShape(token_output);
+ OP_REQUIRES_OK(ctx, shape_or.status());
+ OP_REQUIRES(ctx, xla::ShapeUtil::IsToken(shape_or.ValueOrDie()),
+ errors::FailedPrecondition(
+ "Token output is not token type: ",
+ xla::ShapeUtil::HumanString(shape_or.ValueOrDie())));
+ OP_REQUIRES_OK(ctx, compiler->SetNodeToken(name(), token_output));
+ }
// Updates the values of any resource variables modified by the conditional
// bodies.
diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.h b/tensorflow/compiler/tf2xla/kernels/if_op.h
index f9bc98a198..7783e13a8a 100644
--- a/tensorflow/compiler/tf2xla/kernels/if_op.h
+++ b/tensorflow/compiler/tf2xla/kernels/if_op.h
@@ -52,6 +52,8 @@ class XlaIfOp : public XlaOpKernel {
DataType cond_type_;
DataTypeVector input_types_;
DataTypeVector output_types_;
+ bool has_token_input_output_;
+ std::vector<string> token_input_nodes_;
};
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc
index 296518229e..559414eeaa 100644
--- a/tensorflow/compiler/tf2xla/kernels/while_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/kernels/while_op.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/side_effect_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
@@ -90,6 +91,11 @@ XlaWhileOp::XlaWhileOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
cond_name_attr_ = *name_attr;
OP_REQUIRES_OK(ctx, ctx->GetAttr("body", &name_attr));
body_name_attr_ = *name_attr;
+ if (!ctx->GetAttr(kXlaTokenInputNodesAttrName, &token_input_nodes_).ok()) {
+ has_token_input_output_ = false;
+ } else {
+ has_token_input_output_ = !token_input_nodes_.empty();
+ }
}
void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
@@ -120,6 +126,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
body_options.return_updated_values_for_all_resources = true;
body_options.resolve_compile_time_constants = false;
body_options.is_entry_computation = false;
+ body_options.add_token_input_output = has_token_input_output_;
XlaCompiler::CompilationResult body;
OP_REQUIRES_OK(ctx, compiler->CompileFunction(body_options, body_name_attr_,
arguments, &body));
@@ -192,6 +199,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
cond_options.use_tuple_arg = true;
cond_options.resolve_compile_time_constants = false;
cond_options.is_entry_computation = false;
+ cond_options.add_token_input_output = has_token_input_output_;
XlaCompiler::CompilationResult cond;
OP_REQUIRES_OK(ctx, compiler->CompileFunction(cond_options, cond_name_attr_,
arguments, &cond));
@@ -238,7 +246,16 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
std::vector<xla::XlaOp> inputs(num_inputs);
for (int i = 0; i < num_inputs; ++i) {
int input_num = body.input_mapping[i];
- if (ctx->input_type(input_num) == DT_RESOURCE) {
+ if (has_token_input_output_ && i == num_inputs - 1) {
+ // Set token input for this "while" op.
+ std::vector<xla::XlaOp> token_inputs;
+ for (const string& node_name : token_input_nodes_) {
+ auto token_or = compiler->GetNodeToken(node_name);
+ OP_REQUIRES_OK(ctx, token_or.status());
+ token_inputs.push_back(token_or.ValueOrDie());
+ }
+ inputs[i] = xla::AfterAll(builder, token_inputs);
+ } else if (ctx->input_type(input_num) == DT_RESOURCE) {
XlaResource* resource;
OP_REQUIRES_OK(ctx, ctx->GetResourceInput(input_num, &resource));
OP_REQUIRES_OK(ctx, resource->Pack(&inputs[i], builder));
@@ -273,6 +290,18 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
xla::GetTupleElement(while_result, i));
}
}
+ if (has_token_input_output_) {
+ // Set token output for this "while" op.
+ xla::XlaOp token_output =
+ xla::GetTupleElement(while_result, ctx->num_outputs());
+ auto shape_or = builder->GetShape(token_output);
+ OP_REQUIRES_OK(ctx, shape_or.status());
+ OP_REQUIRES(ctx, xla::ShapeUtil::IsToken(shape_or.ValueOrDie()),
+ errors::FailedPrecondition(
+ "Token output is not token type: ",
+ xla::ShapeUtil::HumanString(shape_or.ValueOrDie())));
+ OP_REQUIRES_OK(ctx, compiler->SetNodeToken(name(), token_output));
+ }
// Updates the values of any resource variables modified by the loop.
for (int i = 0; i < body.resource_updates.size(); ++i) {
diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.h b/tensorflow/compiler/tf2xla/kernels/while_op.h
index 67edebabf9..aeeff40e68 100644
--- a/tensorflow/compiler/tf2xla/kernels/while_op.h
+++ b/tensorflow/compiler/tf2xla/kernels/while_op.h
@@ -56,6 +56,8 @@ class XlaWhileOp : public XlaOpKernel {
private:
NameAttrList cond_name_attr_;
NameAttrList body_name_attr_;
+ bool has_token_input_output_;
+ std::vector<string> token_input_nodes_;
TF_DISALLOW_COPY_AND_ASSIGN(XlaWhileOp);
};
diff --git a/tensorflow/compiler/tf2xla/resource_operation_table.cc b/tensorflow/compiler/tf2xla/resource_operation_table.cc
index 20f2ce2919..92577b5bc8 100644
--- a/tensorflow/compiler/tf2xla/resource_operation_table.cc
+++ b/tensorflow/compiler/tf2xla/resource_operation_table.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/resource_operation_table.h"
#include "absl/algorithm/container.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
namespace tensorflow {
@@ -30,11 +31,10 @@ namespace tensorflow {
}
}
-static gtl::FlatMap<absl::string_view, XlaResourceOpInfo>*
-CreateResourceOpInfoMap() {
- auto* result = new gtl::FlatMap<absl::string_view, XlaResourceOpInfo>;
+static gtl::FlatMap<StringPiece, XlaResourceOpInfo>* CreateResourceOpInfoMap() {
+ auto* result = new gtl::FlatMap<StringPiece, XlaResourceOpInfo>;
- auto add = [&](absl::string_view op, XlaResourceOpKind op_kind,
+ auto add = [&](StringPiece op, XlaResourceOpKind op_kind,
XlaResourceKind resource_kind) {
auto insert_result =
result->insert({op, XlaResourceOpInfo(op_kind, resource_kind)});
@@ -103,17 +103,17 @@ CreateResourceOpInfoMap() {
return result;
}
-static const gtl::FlatMap<absl::string_view, XlaResourceOpInfo>&
+static const gtl::FlatMap<StringPiece, XlaResourceOpInfo>&
GetStaticResourceOpInfoMap() {
- static gtl::FlatMap<absl::string_view, XlaResourceOpInfo>* op_info_map =
+ static gtl::FlatMap<StringPiece, XlaResourceOpInfo>* op_info_map =
CreateResourceOpInfoMap();
return *op_info_map;
}
const XlaResourceOpInfo* GetResourceOpInfoForOp(absl::string_view op) {
- const gtl::FlatMap<absl::string_view, XlaResourceOpInfo>& op_infos =
+ const gtl::FlatMap<StringPiece, XlaResourceOpInfo>& op_infos =
GetStaticResourceOpInfoMap();
- auto it = op_infos.find(op);
+ auto it = op_infos.find(StringPiece(op.data(), op.length()));
return it == op_infos.end() ? nullptr : &it->second;
}
@@ -121,7 +121,7 @@ namespace resource_op_table_internal {
std::vector<absl::string_view> GetKnownResourceOps() {
std::vector<absl::string_view> result;
for (const auto& p : GetStaticResourceOpInfoMap()) {
- result.push_back(p.first);
+ result.push_back(absl::string_view(p.first));
}
absl::c_sort(result);
return result;
diff --git a/tensorflow/compiler/tf2xla/side_effect_util.cc b/tensorflow/compiler/tf2xla/side_effect_util.cc
new file mode 100644
index 0000000000..6cd7b24592
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/side_effect_util.cc
@@ -0,0 +1,67 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/side_effect_util.h"
+
+#include "tensorflow/core/graph/algorithm.h"
+
+namespace tensorflow {
+
+const char kXlaTokenInputNodesAttrName[] = "_xla_token_input_nodes";
+
+const char kXlaTokenArgNodeName[] = "_xla_token_arg_node";
+
+std::set<std::string> CalculateTokenInputsForOutputToken(const Graph& g) {
+ std::set<std::string> results;
+ Node* first_side_effecting_node_on_path = nullptr;
+ ReverseDFS(g,
+ [&](Node* n) {
+ std::vector<string> token_input_nodes;
+ if (!GetNodeAttr(n->attrs(), kXlaTokenInputNodesAttrName,
+ &token_input_nodes)
+ .ok() ||
+ token_input_nodes.empty()) {
+ return;
+ }
+
+ if (first_side_effecting_node_on_path != nullptr) {
+ return;
+ }
+
+ first_side_effecting_node_on_path = n;
+ results.insert(n->name());
+ },
+ [&](Node* n) {
+ if (first_side_effecting_node_on_path == n) {
+ first_side_effecting_node_on_path = nullptr;
+ }
+ },
+ NodeComparatorName());
+ return results;
+}
+
+bool HasSideEffectingNodes(const Graph& g) {
+ for (Node* n : g.nodes()) {
+ std::vector<string> token_input_nodes;
+ if (GetNodeAttr(n->attrs(), kXlaTokenInputNodesAttrName, &token_input_nodes)
+ .ok() &&
+ !token_input_nodes.empty()) {
+ return true;
+ }
+ }
+ return false;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/side_effect_util.h b/tensorflow/compiler/tf2xla/side_effect_util.h
new file mode 100644
index 0000000000..ad07624729
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/side_effect_util.h
@@ -0,0 +1,47 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_TF2XLA_SIDE_EFFECT_UTIL_H_
+#define TENSORFLOW_COMPILER_TF2XLA_SIDE_EFFECT_UTIL_H_
+
+#include <set>
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/graph/graph.h"
+
+namespace tensorflow {
+
+// Side-effecting nodes will have this attribute set. Its value is the list of
+// node names which this node has side-effect dependencies on.
+//
+// Nodes like HostCompute, SendToHost, RecvFromHost always have this attribute,
+// because they always have side effects.
+// If and While nodes may or may not have this attribute, depending on whether
+// their bodies have side-effecting nodes.
+extern const char kXlaTokenInputNodesAttrName[];
+
+// This node name is used in kXlaTokenInputNodesAttrName attr to signal that a
+// node has side-effect dependency on current graph's token input.
+extern const char kXlaTokenArgNodeName[];
+
+// Calculates side-effect dependencies for the graph's token output.
+// Returns a set of node names representing these dependencies.
+std::set<std::string> CalculateTokenInputsForOutputToken(const Graph& g);
+
+// Returns whether a graph contains side-effecting nodes.
+bool HasSideEffectingNodes(const Graph& g);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_TF2XLA_SIDE_EFFECT_UTIL_H_
diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.h b/tensorflow/compiler/tf2xla/tf2xla_util.h
index a29e764466..dcddef8418 100644
--- a/tensorflow/compiler/tf2xla/tf2xla_util.h
+++ b/tensorflow/compiler/tf2xla/tf2xla_util.h
@@ -18,6 +18,7 @@ limitations under the License.
#include <unordered_map>
+#include "absl/strings/string_view.h"
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/kernel_def.pb.h"
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
index 41d305d461..dcb455779d 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/graph_compiler.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/sharding_util.h"
+#include "tensorflow/compiler/tf2xla/side_effect_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
@@ -291,6 +292,10 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg,
"Invalid resource type in XLAShapeForArgument()");
}
}
+ case XlaCompiler::Argument::kToken: {
+ *xla_shape = xla::ShapeUtil::MakeTokenShape();
+ return Status::OK();
+ }
case XlaCompiler::Argument::kInvalid:
return errors::Internal("Invalid argument type in XLAShapeForArgument()");
}
@@ -489,7 +494,8 @@ Status XlaCompiler::BuildArguments(
}
break;
- case XlaCompiler::Argument::kParameter: {
+ case XlaCompiler::Argument::kParameter:
+ case XlaCompiler::Argument::kToken: {
input_mapping->push_back(i);
break;
}
@@ -616,6 +622,10 @@ Status XlaCompiler::BuildArguments(
arg_expression.set_handle(arg_handles[i]);
}
break;
+ case XlaCompiler::Argument::kToken: {
+ arg_expression.set_handle(arg_handles[i]);
+ break;
+ }
case XlaCompiler::Argument::kConstant:
case XlaCompiler::Argument::kInvalid:
return errors::Internal(
@@ -757,23 +767,71 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
&options_.shape_representation_fn);
core::ScopedUnref context_unref(context);
+ std::vector<XlaCompiler::Argument> real_args(args);
+ int token_input_index = -1;
+ if (options.add_token_input_output) {
+ // Add extra token input.
+ token_input_index = real_args.size();
+
+ XlaCompiler::Argument token_arg;
+ token_arg.kind = XlaCompiler::Argument::kToken;
+ real_args.push_back(token_arg);
+ }
+
std::vector<XlaExpression> arg_expressions;
std::vector<int> arg_cores;
- TF_RETURN_IF_ERROR(
- BuildArguments(*graph, args, options.use_tuple_arg, &builder, context,
- &arg_cores, &arg_expressions, &result->input_mapping,
- &result->xla_input_shapes, options.is_entry_computation));
+ TF_RETURN_IF_ERROR(BuildArguments(
+ *graph, real_args, options.use_tuple_arg, &builder, context, &arg_cores,
+ &arg_expressions, &result->input_mapping, &result->xla_input_shapes,
+ options.is_entry_computation));
context->set_args(std::move(arg_expressions));
+ PushNodeTokenMapping();
+ // Use std::set instead of std::unordered_set to ensure determinism.
+ std::set<std::string> output_node_token_inputs;
+ if (token_input_index != -1) {
+ // Original token comes from input.
+ auto arg_expression = context->args()[token_input_index];
+ TF_RETURN_IF_ERROR(
+ SetNodeToken(kXlaTokenArgNodeName, arg_expression.handle()));
+
+ // Calculate token inputs for output token.
+ output_node_token_inputs = CalculateTokenInputsForOutputToken(*graph);
+
+ // If there's no side-effecting op in the graph, use token input as token
+ // output.
+ if (output_node_token_inputs.empty()) {
+ output_node_token_inputs.insert(kXlaTokenArgNodeName);
+ }
+ } else if (options.is_entry_computation) {
+ // Original token is manually created.
+ if (HasSideEffectingNodes(*graph)) {
+ TF_RETURN_IF_ERROR(
+ SetNodeToken(kXlaTokenArgNodeName, xla::CreateToken(&builder)));
+ }
+ }
+
TF_RETURN_IF_ERROR(ExecuteGraph(context, std::move(graph), device_,
flib_runtime_, NextStepId()));
+ if (token_input_index != -1) {
+ // Add extra token output.
+ std::vector<xla::XlaOp> token_inputs;
+ for (const auto& node_name : output_node_token_inputs) {
+ auto token_or = GetNodeToken(node_name);
+ TF_RETURN_IF_ERROR(token_or.status());
+ token_inputs.push_back(token_or.ValueOrDie());
+ }
+ TF_RETURN_IF_ERROR(
+ context->AppendTokenRetval(xla::AfterAll(&builder, token_inputs)));
+ }
+ TF_RETURN_IF_ERROR(PopNodeTokenMapping());
int num_nonconst_outputs;
int num_computation_outputs;
result->computation = std::make_shared<xla::XlaComputation>();
result->outputs.resize(context->retvals().size());
TF_RETURN_IF_ERROR(BuildComputation(
- args, arg_cores, context->retvals(), context->resources(),
+ real_args, arg_cores, context->retvals(), context->resources(),
options.return_updated_values_for_all_resources,
options.always_return_tuple, &builder, result->computation.get(),
&num_computation_outputs, &num_nonconst_outputs, &result->outputs,
@@ -912,4 +970,47 @@ Status XlaCompiler::SetHostComputeControlDependency(
return Status::OK();
}
+// Starts a fresh, empty <node name, token> mapping on top of the stack.
+void XlaCompiler::PushNodeTokenMapping() {
+  node_token_mapping_stack_.emplace();
+}
+
+// Discards the innermost node-token mapping. Returns FailedPrecondition when
+// there is no mapping on the stack to discard.
+Status XlaCompiler::PopNodeTokenMapping() {
+  if (!node_token_mapping_stack_.empty()) {
+    node_token_mapping_stack_.pop();
+    return Status::OK();
+  }
+  return errors::FailedPrecondition(
+      "Calling PopNodeTokenMapping() when node_token_mapping_stack_ is "
+      "empty.");
+}
+
+// Records `op` as the token for `node_name` in the innermost mapping.
+// Returns FailedPrecondition when no mapping has been pushed, or when
+// `node_name` already has an entry (the existing entry is left unchanged).
+Status XlaCompiler::SetNodeToken(const string& node_name,
+                                 const xla::XlaOp& op) {
+  if (node_token_mapping_stack_.empty()) {
+    return errors::FailedPrecondition(
+        "Calling SetNodeToken() when node_token_mapping_stack_ is "
+        "empty.");
+  }
+  std::map<string, xla::XlaOp>& current = node_token_mapping_stack_.top();
+  const bool inserted = current.emplace(node_name, op).second;
+  if (inserted) {
+    return Status::OK();
+  }
+  return errors::FailedPrecondition("Token mapping already exists for node ",
+                                    node_name);
+}
+
+// Looks up the token recorded for `node_name` in the innermost mapping.
+// Returns FailedPrecondition when no mapping has been pushed, or when
+// `node_name` has no entry in it.
+xla::StatusOr<xla::XlaOp> XlaCompiler::GetNodeToken(const string& node_name) {
+  if (node_token_mapping_stack_.empty()) {
+    return errors::FailedPrecondition(
+        "Calling GetNodeToken() when node_token_mapping_stack_ is "
+        "empty.");
+  }
+  const std::map<string, xla::XlaOp>& current =
+      node_token_mapping_stack_.top();
+  const auto entry = current.find(node_name);
+  if (entry != current.end()) {
+    return entry->second;
+  }
+  return errors::FailedPrecondition("Cannot find token mapping for node ",
+                                    node_name);
+}
+
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h
index 8f4a9858ed..2cc603a580 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.h
+++ b/tensorflow/compiler/tf2xla/xla_compiler.h
@@ -16,6 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_
#define TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_
+#include <stack>
+
#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h"
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@@ -26,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/notification.h"
@@ -106,6 +109,9 @@ class XlaCompiler {
// Argument is a run-time parameter.
kParameter,
+
+ // Argument is an XLA token.
+ kToken,
};
Kind kind = kInvalid;
@@ -179,6 +185,9 @@ class XlaCompiler {
// True when compiling the entry computation, false for subcomputations
// (while, call, etc.)
bool is_entry_computation = true;
+
+ // True when we should add XLA token input & output to the graph/function.
+ bool add_token_input_output = false;
};
struct OutputDescription {
@@ -384,6 +393,11 @@ class XlaCompiler {
xla::Client* client() const { return options_.client; }
FunctionLibraryRuntime* flib_runtime() const { return flib_runtime_; }
+ void PushNodeTokenMapping();
+ Status PopNodeTokenMapping();
+ Status SetNodeToken(const string& node_name, const xla::XlaOp& op);
+ xla::StatusOr<xla::XlaOp> GetNodeToken(const string& node_name);
+
private:
// Sets the function body `fbody` to the one registered as `function`.
Status FindFunctionBody(const NameAttrList& function,
@@ -448,6 +462,15 @@ class XlaCompiler {
std::unordered_map<string, xla::XlaOp> host_compute_control_output_;
+ // This is used to store <node name, token output> mapping. Side-effecting
+ // ops call SetNodeToken() to record its token output, so later side-effecting
+ // ops can use GetNodeToken() to get it and use it as token input.
+ //
+ // It's a stack because we need a mapping like this for each level of nested
+ // CompileGraph() call. In CompileGraph(), we will push a new mapping to the
+ // stack, and pop the mapping before returning.
+ std::stack<std::map<string, xla::XlaOp>> node_token_mapping_stack_;
+
TF_DISALLOW_COPY_AND_ASSIGN(XlaCompiler);
};
diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
index be3c93ae47..40ce9fb41c 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
@@ -20,10 +20,12 @@ limitations under the License.
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/resource_variable_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/compiler/tf2xla/side_effect_util.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -32,6 +34,7 @@ limitations under the License.
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function_testlib.h"
+#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_testutil.h"
@@ -1274,5 +1277,70 @@ TEST_F(XlaCompilerTest, SingleOpWithoutInputs) {
}
}
+// Test-only XLA kernel standing in for a side-effecting op: compiling it
+// creates a new token and records it in the compiler under this node's name.
+class DummySideEffectingOp : public XlaOpKernel {
+ public:
+  explicit DummySideEffectingOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+
+  void Compile(XlaOpKernelContext* ctx) override {
+    const xla::XlaOp token = xla::CreateToken(ctx->builder());
+    OP_REQUIRES_OK(ctx, ctx->compiler()->SetNodeToken(name(), token));
+  }
+};
+
+REGISTER_OP("DummySideEffectingOp");
+
+REGISTER_XLA_OP(Name("DummySideEffectingOp"), DummySideEffectingOp);
+
+// Checks token handling in CompileGraph for a graph holding a single
+// side-effecting node: an entry computation gets no token parameter or
+// result, while a computation compiled with add_token_input_output gets
+// exactly one token parameter and one token result.
+TEST_F(XlaCompilerTest, TokenInputAndOutput) {
+  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+  NodeDef side_effecting_op;
+  side_effecting_op.set_name("DummySideEffectingOp");
+  side_effecting_op.set_op("DummySideEffectingOp");
+  AddNodeAttr(kXlaTokenInputNodesAttrName,
+              std::vector<string>{kXlaTokenArgNodeName}, &side_effecting_op);
+  Status status;
+  graph->AddNode(side_effecting_op, &status);
+  TF_ASSERT_OK(status);
+  EXPECT_TRUE(FixupSourceAndSinkEdges(graph.get()));
+
+  const std::vector<XlaCompiler::Argument> empty_args;
+  {
+    // The case for entry computation: we don't add token input/output. Instead,
+    // we use CreateToken HLO to create the entry token.
+    XlaCompiler::CompileOptions options;
+    options.is_entry_computation = true;
+    options.add_token_input_output = false;
+    XlaCompiler compiler(DefaultOptions());
+
+    std::unique_ptr<Graph> graph_copy(new Graph(OpRegistry::Global()));
+    CopyGraph(*graph, graph_copy.get());
+    XlaCompiler::CompilationResult result;
+    TF_ASSERT_OK(compiler.CompileGraph(options, "NoOp", std::move(graph_copy),
+                                       empty_args, &result));
+    EXPECT_EQ(result.xla_input_shapes.size(), 0);
+    // The element count is only queried once the output is known to be a
+    // tuple, so a non-tuple output stops the test here.
+    ASSERT_TRUE(xla::ShapeUtil::IsTuple(result.xla_output_shape));
+    EXPECT_EQ(xla::ShapeUtil::TupleElementCount(result.xla_output_shape), 0);
+  }
+  {
+    // The case for non-entry computation (e.g. while loop body). We add token
+    // input/output.
+    XlaCompiler::CompileOptions options;
+    options.is_entry_computation = false;
+    options.add_token_input_output = true;
+    XlaCompiler compiler(DefaultOptions());
+
+    std::unique_ptr<Graph> graph_copy(new Graph(OpRegistry::Global()));
+    CopyGraph(*graph, graph_copy.get());
+    XlaCompiler::CompilationResult result;
+    TF_ASSERT_OK(compiler.CompileGraph(options, "NoOp", std::move(graph_copy),
+                                       empty_args, &result));
+    // ASSERT (not EXPECT) on each precondition: the following lines index
+    // into xla_input_shapes and into the output tuple, which must not happen
+    // if the size or shape is wrong.
+    ASSERT_EQ(result.xla_input_shapes.size(), 1);
+    EXPECT_TRUE(xla::ShapeUtil::IsToken(result.xla_input_shapes[0]));
+    ASSERT_TRUE(xla::ShapeUtil::IsTuple(result.xla_output_shape));
+    ASSERT_EQ(xla::ShapeUtil::TupleElementCount(result.xla_output_shape), 1);
+    EXPECT_TRUE(xla::ShapeUtil::IsToken(
+        xla::ShapeUtil::GetTupleElementShape(result.xla_output_shape, 0)));
+  }
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc
index e8b4b0eb36..f247570d72 100644
--- a/tensorflow/compiler/tf2xla/xla_context.cc
+++ b/tensorflow/compiler/tf2xla/xla_context.cc
@@ -119,6 +119,17 @@ Status XlaContext::AddResourceRetval(int retval_index, XlaResource* resource) {
return Status::OK();
}
+// Appends `token` as a new return value after the retvals already recorded.
+// Always returns OK.
+Status XlaContext::AppendTokenRetval(const xla::XlaOp& token) {
+  VLOG(1) << "Adding retval index " << retvals_.size()
+          << " with token to XLA computation";
+  XlaExpression token_expression;
+  token_expression.set_handle(token);
+  // No TF DataType corresponds to an XLA token, so the retval is tagged
+  // DT_INVALID. XlaCompiler handles this case separately, so that is OK.
+  retvals_.emplace_back(Retval{DT_INVALID, TensorShape(), token_expression});
+  return Status::OK();
+}
+
xla::XlaBuilder* XlaContext::builder() { return builder_; }
Status XlaContext::CreateResource(
diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h
index 4da891634e..d7dbdc957f 100644
--- a/tensorflow/compiler/tf2xla/xla_context.h
+++ b/tensorflow/compiler/tf2xla/xla_context.h
@@ -89,6 +89,9 @@ class XlaContext : public ResourceBase {
// As for Retval, but for return values that are resource handles.
Status AddResourceRetval(int retval_index, XlaResource* resource);
+ // As for Retval, but for return values that are XLA tokens.
+ Status AppendTokenRetval(const xla::XlaOp& token);
+
// Creates a resource with resource `kind` and initial value `handle`. `name`
// is a descriptive name for use in error messages. See the `XlaResource`
// constructor for a description of the remaining arguments.
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
index d67e50375b..636cb71e21 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
@@ -102,7 +102,8 @@ Status XlaOpKernelContext::ConstantInput(int index,
static xla::StatusOr<int> InputIndex(XlaOpKernelContext* context,
absl::string_view name) {
int start, stop;
- TF_RETURN_IF_ERROR(context->op_kernel().InputRange(name, &start, &stop));
+ TF_RETURN_IF_ERROR(context->op_kernel().InputRange(
+ StringPiece(name.data(), name.length()), &start, &stop));
if (stop != start + 1) {
return errors::InvalidArgument("OpKernel used list-valued input name '",
name,
@@ -365,7 +366,8 @@ Status XlaOpKernelContext::InputList(absl::string_view name,
std::vector<xla::XlaOp>* handles,
std::vector<TensorShape>* shapes) {
OpInputList inputs;
- TF_RETURN_IF_ERROR(context_->input_list(name, &inputs));
+ TF_RETURN_IF_ERROR(
+ context_->input_list(StringPiece(name.data(), name.size()), &inputs));
handles->clear();
shapes->clear();
for (const Tensor& input : inputs) {
@@ -378,7 +380,8 @@ Status XlaOpKernelContext::InputList(absl::string_view name,
Status XlaOpKernelContext::ConstantInputList(
absl::string_view name, std::vector<xla::Literal>* outputs) {
int start, stop;
- TF_RETURN_IF_ERROR(op_kernel().InputRange(name, &start, &stop));
+ TF_RETURN_IF_ERROR(op_kernel().InputRange(
+ StringPiece(name.data(), name.size()), &start, &stop));
outputs->resize(stop - start);
for (int i = start; i < stop; ++i) {
TF_RETURN_IF_ERROR(ConstantInput(i, &(*outputs)[i]));
@@ -612,7 +615,7 @@ const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMul(
const Tensor& XlaOpKernelContext::GetInputTensorByName(absl::string_view name) {
const Tensor* tensor;
- CHECK(context_->input(name, &tensor).ok());
+ CHECK(context_->input(StringPiece(name.data(), name.length()), &tensor).ok());
return *tensor;
}
diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h
index 74a4885f1f..5d53169f68 100644
--- a/tensorflow/compiler/tf2xla/xla_op_registry.h
+++ b/tensorflow/compiler/tf2xla/xla_op_registry.h
@@ -22,6 +22,7 @@ limitations under the License.
#include <unordered_map>
#include <vector>
+#include "absl/strings/string_view.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/local_device.h"
#include "tensorflow/core/framework/device_base.h"
diff --git a/tensorflow/compiler/xla/packed_literal_reader.cc b/tensorflow/compiler/xla/packed_literal_reader.cc
index f9473d372b..bddb664149 100644
--- a/tensorflow/compiler/xla/packed_literal_reader.cc
+++ b/tensorflow/compiler/xla/packed_literal_reader.cc
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/types.h"
@@ -64,7 +65,7 @@ StatusOr<std::unique_ptr<Literal>> PackedLiteralReader::Read(
absl::Span<const float> field = result->data<float>();
char* data = absl::bit_cast<char*>(field.data());
uint64 bytes = elements * sizeof(float);
- absl::string_view sp;
+ tensorflow::StringPiece sp;
auto s = file_->Read(offset_, bytes, &sp, data);
offset_ += sp.size();
if (!s.ok()) {
@@ -85,7 +86,7 @@ bool PackedLiteralReader::IsExhausted() const {
// Try to read a single byte from offset_. If we can't, we've
// exhausted the data.
char single_byte[1];
- absl::string_view sp;
+ tensorflow::StringPiece sp;
auto s = file_->Read(offset_, sizeof(single_byte), &sp, single_byte);
return !s.ok();
}
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i
index 76c09512d8..450d3fe5af 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.i
+++ b/tensorflow/compiler/xla/python/local_computation_builder.i
@@ -109,12 +109,12 @@ limitations under the License.
// Must be included first
#include "tensorflow/python/lib/core/numpy.h"
-#include "third_party/absl/strings/str_cat.h"
-#include "third_party/absl/strings/str_format.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "third_party/absl/types/span.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/python/numpy_bridge.h"
#include "tensorflow/compiler/xla/python/local_computation_builder.h"
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index ab86dce510..e784663ff6 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -159,6 +159,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
],
@@ -291,6 +292,7 @@ cc_library(
"hlo_instructions.cc",
"hlo_module.cc",
"hlo_opcode.cc",
+ "hlo_schedule.cc",
"hlo_sharding.cc",
],
hdrs = [
@@ -303,6 +305,7 @@ cc_library(
"hlo_instructions.h",
"hlo_module.h",
"hlo_opcode.h",
+ "hlo_schedule.h",
"hlo_sharding.h",
],
deps = [
@@ -331,6 +334,8 @@ cc_library(
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span",
],
)
@@ -1037,7 +1042,6 @@ tf_cc_test(
":flatten_call_graph",
":hlo",
":hlo_ordering",
- ":hlo_schedule",
":hlo_scheduling",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
@@ -1065,7 +1069,6 @@ cc_library(
":hlo",
":hlo_dataflow_analysis",
":hlo_proto",
- ":hlo_schedule",
":hlo_value",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
@@ -1086,7 +1089,6 @@ tf_cc_test(
":hlo",
":hlo_dataflow_analysis",
":hlo_ordering",
- ":hlo_schedule",
":hlo_scheduling",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
@@ -1108,7 +1110,6 @@ cc_library(
":hlo",
":hlo_ordering",
":hlo_proto",
- ":hlo_schedule",
":tuple_points_to_analysis",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
@@ -1177,22 +1178,6 @@ cc_library(
],
)
-cc_library(
- name = "hlo_schedule",
- srcs = ["hlo_schedule.cc"],
- hdrs = ["hlo_schedule.h"],
- deps = [
- ":hlo",
- "//tensorflow/compiler/xla:status",
- "//tensorflow/compiler/xla:status_macros",
- "//tensorflow/compiler/xla:util",
- "//tensorflow/core:lib_internal",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/strings:str_format",
- "@com_google_absl//absl/types:span",
- ],
-)
-
tf_cc_test(
name = "hlo_schedule_test",
srcs = ["hlo_schedule_test.cc"],
@@ -1202,7 +1187,6 @@ tf_cc_test(
":hlo_dce",
":hlo_ordering",
":hlo_parser",
- ":hlo_schedule",
":hlo_scheduling",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
@@ -1222,7 +1206,6 @@ cc_library(
":heap_simulator",
":hlo",
":hlo_ordering",
- ":hlo_schedule",
":logical_buffer",
":tuple_points_to_analysis",
"//tensorflow/compiler/xla:shape_util",
@@ -1969,6 +1952,8 @@ tf_cc_test(
srcs = ["hlo_module_test.cc"],
deps = [
":hlo",
+ ":hlo_matchers",
+ ":hlo_parser",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
@@ -1977,6 +1962,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
+ "//tensorflow/core:test",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/types:span",
],
@@ -2413,7 +2399,6 @@ cc_library(
":hlo",
":hlo_dce",
":hlo_ordering",
- ":hlo_schedule",
":hlo_scheduling",
":logical_buffer",
":tuple_points_to_analysis",
@@ -2587,6 +2572,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/core:lib",
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index aa40fba9bb..a0db4563fb 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -2369,20 +2369,20 @@ TEST_P(ConvFilterPaddingTest, DoIt) {
rhs_pad->shape().dimensions(3),
testcase.orig_conv_window))
.ValueOrDie();
- auto* orig_conv = builder.AddInstruction(HloInstruction::CreateConvolve(
- ShapeInference::InferConvolveShape(input->shape(), rhs_pad->shape(),
- /*feature_group_count=*/1, window,
- dnums)
- .ValueOrDie(),
- input, rhs_pad, /*feature_group_count=*/1, window, dnums,
- DefaultPrecisionConfig(2)));
// Add a PrecisionConfig and check that AlgebraicSimplifier keeps it in place
// after the transformation.
PrecisionConfig precision_config;
precision_config.add_operand_precision(PrecisionConfig::HIGH);
precision_config.add_operand_precision(PrecisionConfig::HIGHEST);
- orig_conv->set_precision_config(precision_config);
+
+ builder.AddInstruction(HloInstruction::CreateConvolve(
+ ShapeInference::InferConvolveShape(input->shape(), rhs_pad->shape(),
+ /*feature_group_count=*/1, window,
+ dnums)
+ .ValueOrDie(),
+ input, rhs_pad, /*feature_group_count=*/1, window, dnums,
+ precision_config));
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
@@ -2401,7 +2401,9 @@ TEST_P(ConvFilterPaddingTest, DoIt) {
conv->operand(1)->shape().dimensions(2),
conv->operand(1)->shape().dimensions(3),
testcase.expected_conv_window));
- EXPECT_THAT(conv->precision_config().operand_precision(),
+ EXPECT_THAT(Cast<HloConvolutionInstruction>(conv)
+ ->precision_config()
+ .operand_precision(),
ElementsAre(PrecisionConfig::HIGH, PrecisionConfig::HIGHEST));
}
}
diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc
index 69b654d30e..388fd5df99 100644
--- a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc
@@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -55,8 +55,12 @@ class TestBFloat16Support : public BFloat16Support {
}
};
-class BFloat16PropagationTest : public HloTestBase {
+class BFloat16PropagationTest : public HloVerifiedTestBase {
protected:
+ BFloat16PropagationTest()
+ : HloVerifiedTestBase(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/true) {}
+
// Runs the propagation pass on the given module, and returns whether the
// module is changed after this pass.
bool PropagatePrecision(HloModule* module) {
@@ -77,6 +81,16 @@ class BFloat16PropagationTest : public HloTestBase {
inst->users()[0]->opcode() == HloOpcode::kConvert &&
inst->users()[0]->shape().element_type() == BF16;
}
+
+ // Builds a dot instruction of `lhs` and `rhs` with result `shape`,
+ // contracting dimension 1 of `lhs` with dimension 0 of `rhs` and using
+ // DefaultPrecisionConfig(2).
+ std::unique_ptr<HloInstruction> CreateDot(const Shape& shape,
+                                           HloInstruction* lhs,
+                                           HloInstruction* rhs) {
+   DotDimensionNumbers contraction;
+   contraction.add_rhs_contracting_dimensions(0);
+   contraction.add_lhs_contracting_dimensions(1);
+   const auto precision = DefaultPrecisionConfig(2);
+   return HloInstruction::CreateDot(shape, lhs, rhs, contraction, precision);
+ }
};
// Tests that BF16 can propagate through select over non-tuple buffers, but not
@@ -95,22 +109,22 @@ TEST_F(BFloat16PropagationTest, PropagateThroughSelectButNotAdd) {
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, b));
HloInstruction* add1 = builder.AddInstruction(
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, add0, b));
- HloInstruction* pred = builder.AddInstruction(
- HloInstruction::CreateBinary(shape, HloOpcode::kEq, a, b));
+ HloInstruction* pred = builder.AddInstruction(HloInstruction::CreateBinary(
+ ShapeUtil::MakeShape(PRED, {2, 4}), HloOpcode::kEq, a, b));
HloInstruction* sel = builder.AddInstruction(
HloInstruction::CreateTernary(shape, HloOpcode::kSelect, pred, c, add1));
HloInstruction* xpose =
builder.AddInstruction(HloInstruction::CreateTranspose(
ShapeUtil::MakeShape(F32, {4, 2}), sel, {1, 0}));
- HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary(
- ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, xpose, a));
- HloInstruction* root = builder.AddInstruction(
- HloInstruction::CreateBinary(shape, HloOpcode::kAdd, dot, dot));
+ HloInstruction* dot = builder.AddInstruction(
+ CreateDot(ShapeUtil::MakeShape(F32, {4, 4}), xpose, a));
+ HloInstruction* root = builder.AddInstruction(HloInstruction::CreateBinary(
+ ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kAdd, dot, dot));
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(PropagatePrecision(module.get()));
+ EXPECT_TRUE(PropagatePrecision(module));
EXPECT_EQ(computation->root_instruction(), root);
EXPECT_TRUE(OutputsBF16(xpose));
@@ -136,13 +150,12 @@ TEST_F(BFloat16PropagationTest, ConvertConstantLiteral) {
HloInstruction::CreateConstant(LiteralUtil::CreateFromArray(array_a)));
HloInstruction* b = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateFromArray(array_b)));
- HloInstruction* dot = builder.AddInstruction(
- HloInstruction::CreateBinary(shape, HloOpcode::kDot, a, b));
+ HloInstruction* dot = builder.AddInstruction(CreateDot(shape, a, b));
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(PropagatePrecision(module.get()));
+ EXPECT_TRUE(PropagatePrecision(module));
EXPECT_EQ(computation->root_instruction(), dot);
EXPECT_TRUE(OutputsBF16(dot->operand(0)));
@@ -189,8 +202,8 @@ TEST_F(BFloat16PropagationTest, PropagateThroughTuples) {
builder.AddInstruction(HloInstruction::CreateGetTupleElement(
tuple0->shape(), tuple1, 0)),
0));
- HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary(
- ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, lhs, rhs));
+ HloInstruction* dot = builder.AddInstruction(
+ CreateDot(ShapeUtil::MakeShape(F32, {4, 4}), lhs, rhs));
HloInstruction* output_tuple =
builder.AddInstruction(HloInstruction::CreateTuple({dot, add2}));
@@ -198,7 +211,7 @@ TEST_F(BFloat16PropagationTest, PropagateThroughTuples) {
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(PropagatePrecision(module.get()));
+ EXPECT_TRUE(PropagatePrecision(module));
EXPECT_EQ(computation->root_instruction(), output_tuple);
EXPECT_TRUE(OutputsBF16(xpose));
@@ -231,13 +244,13 @@ TEST_F(BFloat16PropagationTest, SameValueReferencedTwice) {
HloInstruction::CreateGetTupleElement(add1->shape(), tuple, 1));
// lhs is the transpose of add1, and rhs is a get-tuple-element aliasing add1.
- HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary(
- ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, lhs, rhs));
+ HloInstruction* dot = builder.AddInstruction(
+ CreateDot(ShapeUtil::MakeShape(F32, {4, 4}), lhs, rhs));
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(PropagatePrecision(module.get()));
+ EXPECT_TRUE(PropagatePrecision(module));
EXPECT_EQ(computation->root_instruction(), dot);
EXPECT_TRUE(OutputsBF16(add1));
@@ -249,7 +262,7 @@ TEST_F(BFloat16PropagationTest, SameValueReferencedTwice) {
// Tests that a non-fusion computation's root should not be changed.
TEST_F(BFloat16PropagationTest, DoNotChangeComputationRoot) {
auto builder = HloComputation::Builder(TestName());
- Shape shape = ShapeUtil::MakeShape(F32, {2, 4});
+ Shape shape = ShapeUtil::MakeShape(F32, {4, 4});
HloInstruction* a =
builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
@@ -258,8 +271,7 @@ TEST_F(BFloat16PropagationTest, DoNotChangeComputationRoot) {
HloInstruction* add = builder.AddInstruction(
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, b));
- HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary(
- ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, add, add));
+ HloInstruction* dot = builder.AddInstruction(CreateDot(shape, add, add));
HloInstruction* tuple =
builder.AddInstruction(HloInstruction::CreateTuple({add, dot}));
@@ -267,7 +279,7 @@ TEST_F(BFloat16PropagationTest, DoNotChangeComputationRoot) {
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_FALSE(PropagatePrecision(module.get()));
+ EXPECT_FALSE(PropagatePrecision(module));
EXPECT_EQ(computation->root_instruction(), tuple);
EXPECT_FALSE(OutputsBF16(add));
@@ -277,7 +289,7 @@ TEST_F(BFloat16PropagationTest, DoNotChangeComputationRoot) {
TEST_F(BFloat16PropagationTest, PropagateThroughFusion) {
auto module = CreateNewModule();
auto builder = HloComputation::Builder(TestName());
- Shape shape = ShapeUtil::MakeShape(F32, {2, 4});
+ Shape shape = ShapeUtil::MakeShape(F32, {4, 4});
HloInstruction* param = builder.AddInstruction(
HloInstruction::CreateParameter(0, shape, "param"));
@@ -303,15 +315,14 @@ TEST_F(BFloat16PropagationTest, PropagateThroughFusion) {
HloInstruction::CreateGetTupleElement(shape, p_f1, 0));
HloInstruction* b_f1 = builder_f1.AddInstruction(
HloInstruction::CreateGetTupleElement(shape, p_f1, 1));
- HloInstruction* dot = builder_f1.AddInstruction(HloInstruction::CreateBinary(
- ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, a_f1, b_f1));
+ HloInstruction* dot = builder_f1.AddInstruction(CreateDot(shape, a_f1, b_f1));
auto comp_f1 = module->AddEmbeddedComputation(builder_f1.Build());
auto fusion1 = builder.AddInstruction(HloInstruction::CreateFusion(
dot->shape(), HloInstruction::FusionKind::kCustom, {fusion0}, comp_f1));
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(PropagatePrecision(module.get()));
+ EXPECT_TRUE(PropagatePrecision(module));
EXPECT_EQ(computation->root_instruction(), fusion1);
EXPECT_TRUE(OutputsBF16(add));
@@ -326,7 +337,7 @@ TEST_F(BFloat16PropagationTest, PropagateThroughFusion) {
TEST_F(BFloat16PropagationTest, DiscardFusionInternalBF16Changes) {
auto module = CreateNewModule();
auto builder = HloComputation::Builder(TestName());
- Shape shape = ShapeUtil::MakeShape(F32, {2, 4});
+ Shape shape = ShapeUtil::MakeShape(F32, {4, 4});
HloInstruction* param = builder.AddInstruction(
HloInstruction::CreateParameter(0, shape, "param"));
@@ -340,15 +351,15 @@ TEST_F(BFloat16PropagationTest, DiscardFusionInternalBF16Changes) {
builder_f.AddInstruction(HloInstruction::CreateParameter(1, shape, "b"));
HloInstruction* add_f = builder_f.AddInstruction(
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a_f, b_f));
- HloInstruction* dot_f = builder_f.AddInstruction(HloInstruction::CreateBinary(
- ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, add_f, add_f));
+ HloInstruction* dot_f =
+ builder_f.AddInstruction(CreateDot(shape, add_f, add_f));
auto comp_f = module->AddEmbeddedComputation(builder_f.Build());
auto fusion = builder.AddInstruction(HloInstruction::CreateFusion(
dot_f->shape(), HloInstruction::FusionKind::kCustom, {add, add}, comp_f));
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_FALSE(PropagatePrecision(module.get()));
+ EXPECT_FALSE(PropagatePrecision(module));
EXPECT_EQ(computation->root_instruction(), fusion);
}
@@ -390,12 +401,11 @@ TEST_F(BFloat16PropagationTest, ConvertTupleFusionElementIfUsedByAdd) {
HloInstruction::CreateGetTupleElement(shape, fusion, 0));
HloInstruction* gte1 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(shape, fusion, 1));
- HloInstruction* dot = builder.AddInstruction(
- HloInstruction::CreateBinary(shape, HloOpcode::kDot, gte0, gte1));
+ HloInstruction* dot = builder.AddInstruction(CreateDot(shape, gte0, gte1));
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(PropagatePrecision(module.get()));
+ EXPECT_TRUE(PropagatePrecision(module));
EXPECT_EQ(computation->root_instruction(), dot);
EXPECT_TRUE(OutputsBF16(gte0));
@@ -440,12 +450,12 @@ TEST_F(BFloat16PropagationTest, SelectOverTuples) {
HloInstruction* xpose =
builder.AddInstruction(HloInstruction::CreateTranspose(
ShapeUtil::MakeShape(F32, {4, 2}), gte0, {1, 0}));
- HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary(
- ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, xpose, gte1));
+ HloInstruction* dot = builder.AddInstruction(
+ CreateDot(ShapeUtil::MakeShape(F32, {4, 4}), xpose, gte1));
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(PropagatePrecision(module.get()));
+ EXPECT_TRUE(PropagatePrecision(module));
EXPECT_EQ(computation->root_instruction(), dot);
EXPECT_FALSE(OutputsBF16(add0));
@@ -472,31 +482,36 @@ TEST_F(BFloat16PropagationTest, PropagateThroughSimpleWhile) {
auto builder_cond = HloComputation::Builder("cond");
auto cond_param = builder_cond.AddInstruction(
HloInstruction::CreateParameter(0, shape, "cond_param"));
- auto cond_dot = builder_cond.AddInstruction(HloInstruction::CreateBinary(
- shape, HloOpcode::kDot, cond_param, cond_param));
+ auto cond_dot =
+ builder_cond.AddInstruction(CreateDot(shape, cond_param, cond_param));
auto cond_root = builder_cond.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt,
- builder_cond.AddInstruction(HloInstruction::CreateSlice(
- ShapeUtil::MakeShape(F32, {}), cond_dot, {0, 0}, {1, 1}, {1, 1})),
- builder_cond.AddInstruction(HloInstruction::CreateSlice(
- ShapeUtil::MakeShape(F32, {}), cond_dot, {1, 1}, {2, 2}, {1, 1}))));
+ builder_cond.AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(F32, {}),
+ builder_cond.AddInstruction(
+ HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}),
+ cond_dot, {0, 0}, {1, 1}, {1, 1})))),
+ builder_cond.AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(F32, {}),
+ builder_cond.AddInstruction(HloInstruction::CreateSlice(
+ ShapeUtil::MakeShape(F32, {1, 1}), cond_dot, {1, 1}, {2, 2},
+ {1, 1}))))));
auto cond = module->AddEmbeddedComputation(builder_cond.Build());
auto builder_body = HloComputation::Builder("body");
auto body_param = builder_body.AddInstruction(
HloInstruction::CreateParameter(0, shape, "body_param"));
- auto body_dot = builder_body.AddInstruction(HloInstruction::CreateBinary(
- shape, HloOpcode::kDot, body_param, body_param));
+ auto body_dot =
+ builder_body.AddInstruction(CreateDot(shape, body_param, body_param));
auto body = module->AddEmbeddedComputation(builder_body.Build());
auto while_hlo = builder.AddInstruction(
HloInstruction::CreateWhile(shape, cond, body, add));
- auto dot = builder.AddInstruction(HloInstruction::CreateBinary(
- shape, HloOpcode::kDot, while_hlo, while_hlo));
+ auto dot = builder.AddInstruction(CreateDot(shape, while_hlo, while_hlo));
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(PropagatePrecision(module.get()));
+ EXPECT_TRUE(PropagatePrecision(module));
EXPECT_EQ(computation->root_instruction(), dot);
EXPECT_TRUE(
@@ -528,10 +543,16 @@ TEST_F(BFloat16PropagationTest,
HloInstruction::CreateParameter(0, shape, "cond_param"));
builder_cond.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt,
- builder_cond.AddInstruction(HloInstruction::CreateSlice(
- ShapeUtil::MakeShape(F32, {}), cond_param, {0, 0}, {1, 1}, {1, 1})),
- builder_cond.AddInstruction(HloInstruction::CreateSlice(
- ShapeUtil::MakeShape(F32, {}), cond_param, {1, 1}, {2, 2}, {1, 1}))));
+ builder_cond.AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(F32, {}),
+ builder_cond.AddInstruction(HloInstruction::CreateSlice(
+ ShapeUtil::MakeShape(F32, {1, 1}), cond_param, {0, 0}, {1, 1},
+ {1, 1})))),
+ builder_cond.AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(F32, {}),
+ builder_cond.AddInstruction(HloInstruction::CreateSlice(
+ ShapeUtil::MakeShape(F32, {1, 1}), cond_param, {1, 1}, {2, 2},
+ {1, 1}))))));
auto cond = module->AddEmbeddedComputation(builder_cond.Build());
auto builder_body = HloComputation::Builder("body");
@@ -552,11 +573,10 @@ TEST_F(BFloat16PropagationTest,
auto while_hlo = builder.AddInstruction(
HloInstruction::CreateWhile(shape, cond, body, add));
- auto dot = builder.AddInstruction(HloInstruction::CreateBinary(
- shape, HloOpcode::kDot, while_hlo, while_hlo));
+ auto dot = builder.AddInstruction(CreateDot(shape, while_hlo, while_hlo));
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_FALSE(PropagatePrecision(module.get()));
+ EXPECT_FALSE(PropagatePrecision(module));
EXPECT_EQ(computation->root_instruction(), dot);
EXPECT_FALSE(OutputsBF16(add));
EXPECT_FALSE(OutputsBF16(body_fusion));
@@ -593,14 +613,20 @@ TEST_F(BFloat16PropagationTest, PropagateThroughTupleWhile) {
// This add should prevent RHS from using BF16
auto cond_add_rhs = builder_cond.AddInstruction(
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, cond_rhs, cond_rhs));
- auto cond_dot = builder_cond.AddInstruction(HloInstruction::CreateBinary(
- shape, HloOpcode::kDot, cond_lhs, cond_add_rhs));
+ auto cond_dot =
+ builder_cond.AddInstruction(CreateDot(shape, cond_lhs, cond_add_rhs));
builder_cond.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt,
- builder_cond.AddInstruction(HloInstruction::CreateSlice(
- ShapeUtil::MakeShape(F32, {}), cond_dot, {0, 0}, {1, 1}, {1, 1})),
- builder_cond.AddInstruction(HloInstruction::CreateSlice(
- ShapeUtil::MakeShape(F32, {}), cond_dot, {1, 1}, {2, 2}, {1, 1}))));
+ builder_cond.AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(F32, {}),
+ builder_cond.AddInstruction(
+ HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}),
+ cond_dot, {0, 0}, {1, 1}, {1, 1})))),
+ builder_cond.AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(F32, {}),
+ builder_cond.AddInstruction(HloInstruction::CreateSlice(
+ ShapeUtil::MakeShape(F32, {1, 1}), cond_dot, {1, 1}, {2, 2},
+ {1, 1}))))));
auto cond = module->AddEmbeddedComputation(builder_cond.Build());
auto builder_body = HloComputation::Builder("body");
@@ -610,10 +636,10 @@ TEST_F(BFloat16PropagationTest, PropagateThroughTupleWhile) {
HloInstruction::CreateGetTupleElement(shape, body_param, 0));
auto body_rhs = builder_body.AddInstruction(
HloInstruction::CreateGetTupleElement(shape, body_param, 1));
- auto body_dot1 = builder_body.AddInstruction(
- HloInstruction::CreateBinary(shape, HloOpcode::kDot, body_lhs, body_rhs));
- auto body_dot2 = builder_body.AddInstruction(
- HloInstruction::CreateBinary(shape, HloOpcode::kDot, body_rhs, body_lhs));
+ auto body_dot1 =
+ builder_body.AddInstruction(CreateDot(shape, body_lhs, body_rhs));
+ auto body_dot2 =
+ builder_body.AddInstruction(CreateDot(shape, body_rhs, body_lhs));
auto body_transpose = builder_body.AddInstruction(
HloInstruction::CreateTranspose(shape, body_dot2, {0, 1}));
builder_body.AddInstruction(
@@ -627,11 +653,10 @@ TEST_F(BFloat16PropagationTest, PropagateThroughTupleWhile) {
HloInstruction::CreateGetTupleElement(shape, while_hlo, 0));
auto rhs = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(shape, while_hlo, 1));
- auto dot = builder.AddInstruction(
- HloInstruction::CreateBinary(shape, HloOpcode::kDot, lhs, rhs));
+ auto dot = builder.AddInstruction(CreateDot(shape, lhs, rhs));
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(PropagatePrecision(module.get()));
+ EXPECT_TRUE(PropagatePrecision(module));
EXPECT_EQ(computation->root_instruction(), dot);
EXPECT_TRUE(OutputsBF16(lhs));
@@ -683,14 +708,20 @@ TEST_F(BFloat16PropagationTest, DoNotPropagateWhilesCallingSameComputation) {
auto cond0_add_rhs =
builder_cond0.AddInstruction(HloInstruction::CreateBinary(
shape, HloOpcode::kAdd, cond0_rhs, cond0_rhs));
- auto cond0_dot = builder_cond0.AddInstruction(HloInstruction::CreateBinary(
- shape, HloOpcode::kDot, cond0_lhs, cond0_add_rhs));
+ auto cond0_dot =
+ builder_cond0.AddInstruction(CreateDot(shape, cond0_lhs, cond0_add_rhs));
builder_cond0.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt,
- builder_cond0.AddInstruction(HloInstruction::CreateSlice(
- ShapeUtil::MakeShape(F32, {}), cond0_dot, {0, 0}, {1, 1}, {1, 1})),
- builder_cond0.AddInstruction(HloInstruction::CreateSlice(
- ShapeUtil::MakeShape(F32, {}), cond0_dot, {1, 1}, {2, 2}, {1, 1}))));
+ builder_cond0.AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(F32, {}),
+ builder_cond0.AddInstruction(
+ HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}),
+ cond0_dot, {0, 0}, {1, 1}, {1, 1})))),
+ builder_cond0.AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(F32, {}),
+ builder_cond0.AddInstruction(HloInstruction::CreateSlice(
+ ShapeUtil::MakeShape(F32, {1, 1}), cond0_dot, {1, 1}, {2, 2},
+ {1, 1}))))));
auto cond0 = module->AddEmbeddedComputation(builder_cond0.Build());
// Condition computation for the second while.
@@ -705,14 +736,20 @@ TEST_F(BFloat16PropagationTest, DoNotPropagateWhilesCallingSameComputation) {
auto cond1_add_lhs =
builder_cond1.AddInstruction(HloInstruction::CreateBinary(
shape, HloOpcode::kAdd, cond1_lhs, cond1_lhs));
- auto cond1_dot = builder_cond1.AddInstruction(HloInstruction::CreateBinary(
- shape, HloOpcode::kDot, cond1_add_lhs, cond1_rhs));
+ auto cond1_dot =
+ builder_cond1.AddInstruction(CreateDot(shape, cond1_add_lhs, cond1_rhs));
builder_cond1.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt,
- builder_cond1.AddInstruction(HloInstruction::CreateSlice(
- ShapeUtil::MakeShape(F32, {}), cond1_dot, {0, 0}, {1, 1}, {1, 1})),
- builder_cond1.AddInstruction(HloInstruction::CreateSlice(
- ShapeUtil::MakeShape(F32, {}), cond1_dot, {1, 1}, {2, 2}, {1, 1}))));
+ builder_cond1.AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(F32, {}),
+ builder_cond1.AddInstruction(
+ HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}),
+ cond1_dot, {0, 0}, {1, 1}, {1, 1})))),
+ builder_cond1.AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(F32, {}),
+ builder_cond1.AddInstruction(HloInstruction::CreateSlice(
+ ShapeUtil::MakeShape(F32, {1, 1}), cond1_dot, {1, 1}, {2, 2},
+ {1, 1}))))));
auto cond1 = module->AddEmbeddedComputation(builder_cond1.Build());
// Body computation shared by both whiles.
@@ -723,8 +760,8 @@ TEST_F(BFloat16PropagationTest, DoNotPropagateWhilesCallingSameComputation) {
HloInstruction::CreateGetTupleElement(shape, body_param, 0));
auto body_rhs = builder_body.AddInstruction(
HloInstruction::CreateGetTupleElement(shape, body_param, 1));
- auto body_dot = builder_body.AddInstruction(
- HloInstruction::CreateBinary(shape, HloOpcode::kDot, body_lhs, body_rhs));
+ auto body_dot =
+ builder_body.AddInstruction(CreateDot(shape, body_lhs, body_rhs));
builder_body.AddInstruction(
HloInstruction::CreateTuple({body_dot, body_rhs}));
auto body = module->AddEmbeddedComputation(builder_body.Build());
@@ -734,23 +771,22 @@ TEST_F(BFloat16PropagationTest, DoNotPropagateWhilesCallingSameComputation) {
auto while1 = builder.AddInstruction(
HloInstruction::CreateWhile(tuple1->shape(), cond1, body, tuple1));
- auto lhs = builder.AddInstruction(HloInstruction::CreateBinary(
- shape, HloOpcode::kDot,
- builder.AddInstruction(
- HloInstruction::CreateGetTupleElement(shape, while0, 0)),
- builder.AddInstruction(
- HloInstruction::CreateGetTupleElement(shape, while0, 1))));
- auto rhs = builder.AddInstruction(HloInstruction::CreateBinary(
- shape, HloOpcode::kDot,
- builder.AddInstruction(
- HloInstruction::CreateGetTupleElement(shape, while1, 0)),
- builder.AddInstruction(
- HloInstruction::CreateGetTupleElement(shape, while1, 1))));
- auto dot = builder.AddInstruction(
- HloInstruction::CreateBinary(shape, HloOpcode::kDot, lhs, rhs));
+ auto lhs = builder.AddInstruction(
+ CreateDot(shape,
+ builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(shape, while0, 0)),
+ builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(shape, while0, 1))));
+ auto rhs = builder.AddInstruction(
+ CreateDot(shape,
+ builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(shape, while1, 0)),
+ builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(shape, while1, 1))));
+ auto dot = builder.AddInstruction(CreateDot(shape, lhs, rhs));
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(PropagatePrecision(module.get()));
+ EXPECT_TRUE(PropagatePrecision(module));
EXPECT_FALSE(OutputsBF16(body_dot));
EXPECT_FALSE(OutputsBF16(body_rhs));
EXPECT_FALSE(OutputsBF16(body_lhs));
@@ -792,7 +828,7 @@ TEST_F(BFloat16PropagationTest, NoopConversionRemoved) {
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(PropagatePrecision(module.get()));
+ EXPECT_TRUE(PropagatePrecision(module));
EXPECT_EQ(computation->root_instruction(), add2);
EXPECT_EQ(add2->operand(0), add0);
@@ -821,15 +857,14 @@ TEST_F(BFloat16PropagationTest, TupleDomain) {
HloInstruction::CreateGetTupleElement(shape, domain, 0));
HloInstruction* b_gte = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(shape, domain, 1));
- HloInstruction* dot = builder.AddInstruction(
- HloInstruction::CreateBinary(shape, HloOpcode::kDot, a_gte, b_gte));
+ HloInstruction* dot = builder.AddInstruction(CreateDot(shape, a_gte, b_gte));
HloInstruction* root = builder.AddInstruction(
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, dot, dot));
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(PropagatePrecision(module.get()));
+ EXPECT_TRUE(PropagatePrecision(module));
EXPECT_EQ(computation->root_instruction(), root);
// test BF16 propagated through domain
@@ -867,15 +902,15 @@ TEST_F(BFloat16PropagationTest, TupleDomainNoPropagation) {
HloInstruction::CreateTranspose(shape, a_gte, {0, 1}));
HloInstruction* b_trans = builder.AddInstruction(
HloInstruction::CreateTranspose(shape, b_gte, {0, 1}));
- HloInstruction* dot = builder.AddInstruction(
- HloInstruction::CreateBinary(shape, HloOpcode::kDot, a_trans, b_trans));
+ HloInstruction* dot =
+ builder.AddInstruction(CreateDot(shape, a_trans, b_trans));
HloInstruction* root = builder.AddInstruction(
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, dot, dot));
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(PropagatePrecision(module.get()));
+ EXPECT_TRUE(PropagatePrecision(module));
EXPECT_EQ(computation->root_instruction(), root);
EXPECT_TRUE(OutputsBF16(a_trans));
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index d412578619..2368ac8c6a 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -670,6 +670,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/service:transpose_folding",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"@com_google_absl//absl/strings",
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
index 0fea462c85..7d99b914d4 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/transpose_folding.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/test_utils.h"
namespace op = xla::testing::opcode_matchers;
@@ -696,8 +697,8 @@ void CreateComputationForDotAddOutputFusionTest(const string& test_name,
auto* addend = builder.AddInstruction(
HloInstruction::CreateParameter(2, dot_shape, "param2"));
- auto* dot = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(dot_shape, dot_lhs, dot_rhs));
+ auto* dot =
+ builder.AddInstruction(CreateCanonicalDot(dot_shape, dot_lhs, dot_rhs));
builder.AddInstruction(
HloInstruction::CreateBinary(dot_shape, HloOpcode::kAdd, dot, addend));
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc
index 9363af3b89..4668f3872d 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc
@@ -70,7 +70,7 @@ TEST_F(CpuLayoutAssignmentTest, DotWithConstantRhsTensor) {
auto dot_rhs = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateFromShape(rhs_shape)));
auto result = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(result_shape, dot_lhs, dot_rhs));
+ CreateCanonicalDot(result_shape, dot_lhs, dot_rhs));
auto module = CreateNewModule();
HloComputation* computation = module->AddEntryComputation(builder.Build());
@@ -107,9 +107,9 @@ TEST_F(CpuLayoutAssignmentTest, MultipleDotsWithSameConstantRhsTensor0) {
auto dot_rhs = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateFromShape(rhs_shape)));
auto dot_a_result = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(result_shape, dot_a_lhs, dot_rhs));
+ CreateCanonicalDot(result_shape, dot_a_lhs, dot_rhs));
auto dot_b_result = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(result_shape, dot_b_lhs, dot_rhs));
+ CreateCanonicalDot(result_shape, dot_b_lhs, dot_rhs));
builder.AddInstruction(HloInstruction::CreateBinary(
result_shape, HloOpcode::kAdd, dot_a_result, dot_b_result));
@@ -151,9 +151,9 @@ TEST_F(CpuLayoutAssignmentTest, MultipleDotsWithSameConstantRhsTensor1) {
auto dot_rhs = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateFromShape(rhs_shape)));
auto dot_a_result = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(result_a_shape, dot_a_lhs, dot_rhs));
+ CreateCanonicalDot(result_a_shape, dot_a_lhs, dot_rhs));
auto dot_b_result = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(result_b_shape, dot_b_lhs, dot_rhs));
+ CreateCanonicalDot(result_b_shape, dot_b_lhs, dot_rhs));
auto tuple_result = builder.AddInstruction(
HloInstruction::CreateTuple({dot_a_result, dot_b_result}));
@@ -189,7 +189,7 @@ TEST_F(CpuLayoutAssignmentTest, DotWithConstantLhsTensor) {
auto dot_rhs = builder.AddInstruction(
HloInstruction::CreateParameter(0, rhs_shape, "param0"));
auto dot_result = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(result_shape, dot_lhs, dot_rhs));
+ CreateCanonicalDot(result_shape, dot_lhs, dot_rhs));
auto module = CreateNewModule();
HloComputation* computation = module->AddEntryComputation(builder.Build());
@@ -229,7 +229,7 @@ TEST_F(CpuLayoutAssignmentTest, DotWithConstantRhsTensorThroughGTE) {
auto dot_rhs = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(rhs_shape, constant, 1));
auto dot_result = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(result_shape, dot_lhs, dot_rhs));
+ CreateCanonicalDot(result_shape, dot_lhs, dot_rhs));
auto module = CreateNewModule();
HloComputation* computation = module->AddEntryComputation(builder.Build());
@@ -276,8 +276,8 @@ static StatusOr<DotOutputFusionLayoutAssignmentResult> RunDotOutputFusion(
HloInstruction::CreateParameter(1, dot_shape, "param1"));
HloInstruction* dot_rhs = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateFromShape(dot_rhs_shape)));
- HloInstruction* dot_result = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(dot_shape, dot_lhs, dot_rhs));
+ HloInstruction* dot_result =
+ builder.AddInstruction(CreateCanonicalDot(dot_shape, dot_lhs, dot_rhs));
HloInstruction* add_result;
if (dot_operand_idx_in_add == 0) {
add_result = builder.AddInstruction(HloInstruction::CreateBinary(
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc
index a84ee78b19..fad76338a5 100644
--- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc
@@ -35,9 +35,7 @@ class ParallelTaskAssignmentTest : public HloVerifiedTestBase {
cpu::TargetMachineFeaturesWithFakeAlignmentLogic target_machine_features_;
ParallelTaskAssignmentTest()
- : HloVerifiedTestBase(/*layout_sensitive=*/false,
- /*allow_mixed_precision=*/false),
- target_machine_features_([](int64 shape_size) {
+ : HloVerifiedTestBase(), target_machine_features_([](int64 shape_size) {
return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment;
}) {}
diff --git a/tensorflow/compiler/xla/service/cpu/tests/BUILD b/tensorflow/compiler/xla/service/cpu/tests/BUILD
index 2384166fd2..f11aff0573 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/tests/BUILD
@@ -121,6 +121,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service/cpu:cpu_compiler",
"//tensorflow/compiler/xla/service/cpu/tests:cpu_codegen_test",
+ "//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc
index fcd87b36b3..18ee25ba91 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
#include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/core/platform/test.h"
namespace xla {
@@ -69,8 +70,7 @@ TEST_P(CpuEigenDotOperationTest, SimpleDotOp) {
HloInstruction* rhs = builder.AddInstruction(
HloInstruction::CreateParameter(1, param_shape, "input"));
- builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(param_shape, lhs, rhs));
+ builder.AddInstruction(CreateCanonicalDot(param_shape, lhs, rhs));
CompileAndCheck(builder.Build(), spec.filecheck_lines);
}
@@ -87,8 +87,7 @@ TEST_P(CpuEigenDotOperationTest, DotTransposeOp) {
HloInstruction* lhs_transposed = builder.AddInstruction(
HloInstruction::CreateTranspose(param_shape, lhs, {1, 0}));
- builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(param_shape, lhs_transposed, rhs));
+ builder.AddInstruction(CreateCanonicalDot(param_shape, lhs_transposed, rhs));
CompileAndCheck(builder.Build(), spec.filecheck_lines);
}
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 13ccff35f8..6791e15ee0 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -108,6 +108,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"@com_google_absl//absl/memory",
@@ -480,6 +481,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/service:hlo_matchers",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
],
)
@@ -813,7 +815,6 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_ordering",
"//tensorflow/compiler/xla/service:hlo_reachability",
- "//tensorflow/compiler/xla/service:hlo_schedule",
"//tensorflow/compiler/xla/service:hlo_scheduling",
"@com_google_absl//absl/memory",
],
@@ -831,6 +832,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings:str_format",
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc
index 0922e44a12..59ade96f7d 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc
@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/types.h"
namespace xla {
@@ -73,10 +74,10 @@ TEST_F(GpuHloScheduleTest, SequentialMatMul) {
/*parameter_number=*/1, f32_2x2_, /*name=*/"y"));
HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter(
/*parameter_number=*/2, f32_2x2_, /*name=*/"z"));
- HloInstruction* dot1 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, x, y));
- HloInstruction* dot2 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, dot1, z));
+ HloInstruction* dot1 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, x, y));
+ HloInstruction* dot2 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, dot1, z));
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build(dot2));
@@ -201,12 +202,12 @@ TEST_F(GpuHloScheduleTest, ConcurrentMatMul) {
/*parameter_number=*/0, f32_2x2_, /*name=*/"x"));
HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
/*parameter_number=*/1, f32_2x2_, /*name=*/"y"));
- HloInstruction* dot1 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, x, y));
- HloInstruction* dot2 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, y, x));
- HloInstruction* add = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, dot1, dot2));
+ HloInstruction* dot1 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, x, y));
+ HloInstruction* dot2 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, y, x));
+ HloInstruction* add =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, dot1, dot2));
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build(add));
@@ -269,23 +270,23 @@ TEST_F(GpuHloScheduleTest, LatticeMatMul) {
i, f32_2x2_, /*name=*/absl::StrFormat("param%d", i))));
}
HloInstruction* d00 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, params[2], params[3]));
- HloInstruction* d10 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, params[1], d00));
- HloInstruction* d11 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, d00, params[4]));
- HloInstruction* d20 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, params[0], d10));
- HloInstruction* d21 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, d10, d11));
- HloInstruction* d22 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, d11, params[5]));
- HloInstruction* d30 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, d20, d21));
- HloInstruction* d31 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, d21, d22));
- HloInstruction* d40 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, d30, d31));
+ CreateCanonicalDot(f32_2x2_, params[2], params[3]));
+ HloInstruction* d10 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, params[1], d00));
+ HloInstruction* d11 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d00, params[4]));
+ HloInstruction* d20 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, params[0], d10));
+ HloInstruction* d21 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d10, d11));
+ HloInstruction* d22 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d11, params[5]));
+ HloInstruction* d30 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d20, d21));
+ HloInstruction* d31 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d21, d22));
+ HloInstruction* d40 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d30, d31));
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build(d40));
diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
index bca775c475..96bfe0c12e 100644
--- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/util.h"
namespace op = xla::testing::opcode_matchers;
@@ -111,8 +112,8 @@ TEST_F(InstructionFusionTest, PotentialBitcastReshapeOfDotUnfused) {
HloComputation::Builder builder(TestName());
auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(S32, {1, 1}), "0"));
- auto dot1 = builder.AddInstruction(HloInstruction::CreateCanonicalDot(
- ShapeUtil::MakeShape(S32, {1, 1}), param0, param0));
+ auto dot1 = builder.AddInstruction(
+ CreateCanonicalDot(ShapeUtil::MakeShape(S32, {1, 1}), param0, param0));
auto reshape2 = builder.AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(S32, {1, 1, 1}), dot1));
@@ -128,8 +129,8 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfDotUnfused) {
HloComputation::Builder builder(TestName());
auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(S32, {1, 1}), "0"));
- auto dot1 = builder.AddInstruction(HloInstruction::CreateCanonicalDot(
- ShapeUtil::MakeShape(S32, {1, 1}), param0, param0));
+ auto dot1 = builder.AddInstruction(
+ CreateCanonicalDot(ShapeUtil::MakeShape(S32, {1, 1}), param0, param0));
auto transpose2 = builder.AddInstruction(HloInstruction::CreateTranspose(
ShapeUtil::MakeShape(S32, {1, 1}), dot1, {0, 1}));
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
index ffca5d6549..b7c37bcf3c 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
@@ -764,5 +764,20 @@ StatusOr<llvm::Value*> IrEmitter::ComputeNestedElement(
return Load(return_buffer);
}
+std::vector<llvm_ir::IrArray> IrEmitter::ConstructIrArrayForOutputs(
+ const HloInstruction& hlo) {
+ std::vector<llvm_ir::IrArray> output_arrays;
+ if (ShapeUtil::IsTuple(hlo.shape())) {
+ int64 num_outputs = ShapeUtil::TupleElementCount(hlo.shape());
+ output_arrays.reserve(num_outputs);
+ for (int64 i = 0; i < num_outputs; ++i) {
+ output_arrays.push_back(GetIrArray(hlo, hlo, {i}));
+ }
+ } else {
+ output_arrays.push_back(GetIrArray(hlo, hlo));
+ }
+ return output_arrays;
+}
+
} // namespace gpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h
index 579268f071..8805201480 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h
@@ -124,6 +124,12 @@ class IrEmitter : public DfsHloVisitorWithDefault,
llvm::Value* GetBasePointer(const HloInstruction& inst) const {
return bindings_.GetBasePointer(inst);
}
+
+ // Generates the IrArray for each output of an hlo instruction and returns
+ // a vector containing such IrArrays.
+ std::vector<llvm_ir::IrArray> ConstructIrArrayForOutputs(
+ const HloInstruction& hlo);
+
// A convenient helper for calling BufferAssignment::GetUniqueSlice.
BufferAllocation::Slice GetAllocationSlice(
const HloInstruction& hlo, const ShapeIndex& index = {}) const {
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc
index 5c827e5f9c..66c65f6975 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc
@@ -119,21 +119,11 @@ Status IrEmitterNested::EmitTargetElementLoop(
// For MOF we give the loop emitter an array for every output it should
// generate.
if (hlo.IsMultiOutputFusion()) {
- const int64 num_elems = ShapeUtil::TupleElementCount(hlo.shape());
- std::vector<llvm_ir::IrArray> target_arrays;
- target_arrays.reserve(num_elems);
- for (int64 i = 0; i != num_elems; ++i) {
- target_arrays.push_back(GetIrArray(hlo, hlo, {i}));
- }
+ std::vector<llvm_ir::IrArray> target_arrays =
+ ConstructIrArrayForOutputs(hlo);
TF_RETURN_IF_ERROR(
llvm_ir::LoopEmitter(element_generator, target_arrays, &b_).EmitLoop());
-
- std::vector<llvm::Value*> tuple_operand_ptrs;
- tuple_operand_ptrs.reserve(num_elems);
- for (const llvm_ir::IrArray& array : target_arrays) {
- tuple_operand_ptrs.push_back(array.GetBasePointer());
- }
- llvm_ir::EmitTuple(GetIrArray(hlo, hlo), tuple_operand_ptrs, &b_, module_);
+ llvm_ir::EmitTuple(GetIrArray(hlo, hlo), target_arrays, &b_, module_);
return Status::OK();
}
return llvm_ir::LoopEmitter(element_generator, GetIrArray(hlo, hlo), &b_)
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index 389a98facb..f91cc00d71 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -2521,15 +2521,15 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildFftThunk(
}
StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
- const HloInstruction* hlo, const ShapeIndex& index) {
+ HloInstruction* hlo, const ShapeIndex& index) {
bool fused = HloOpcode::kFusion == hlo->opcode();
- const HloInstruction* inst = fused ? hlo->fused_expression_root() : hlo;
- const HloInstruction* init_value_operand = [&] {
+ HloInstruction* inst = fused ? hlo->fused_expression_root() : hlo;
+ HloInstruction* init_value_operand = [&] {
switch (inst->opcode()) {
case HloOpcode::kSelectAndScatter:
- return inst->operand(2);
+ return inst->mutable_operand(2);
case HloOpcode::kReduce:
- return inst->operand(1);
+ return inst->mutable_operand(1);
case HloOpcode::kTuple:
CHECK(hlo->IsMultiOutputFusion())
<< ": " << hlo->ToString() << " is not a multi-output fusion.";
@@ -2537,7 +2537,7 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
<< ": Found '" << inst->operand(index.back())->opcode() << "' in "
<< inst->ToString() << " but expected 'reduce'.";
// For multi-output fusion look through the tuple.
- return inst->operand(index.back())->operand(1);
+ return inst->mutable_operand(index.back())->mutable_operand(1);
default:
LOG(FATAL) << "Opcode " << inst->opcode()
<< " should not need an initializer.";
@@ -2609,28 +2609,35 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
ir_emitter_context_->device_description());
UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(),
ir_emitter_context_->llvm_module());
- // If the init_value was fused into this reduce we have to generate it first.
- if (fused && init_value_operand->opcode() != HloOpcode::kParameter) {
- CHECK_EQ(HloOpcode::kConstant, init_value_operand->opcode());
- const Literal& literal = init_value_operand->literal();
- llvm::Constant* initializer =
- llvm_ir::ConvertLiteralToIrConstant(literal, module_);
+ if (fused) {
+ // If init_value was fused into this reduce we have to generate it first.
+ std::vector<IrArray> parameter_arrays;
+ for (HloInstruction* operand : hlo->operands()) {
+ parameter_arrays.push_back(GetIrArray(*operand, *hlo));
+ }
+ GpuElementalIrEmitter elemental_emitter(hlo_module_config_,
+ ir_emitter_context_->llvm_module(),
+ &b_, GetNestedComputer());
- llvm::GlobalVariable* global_for_const = new llvm::GlobalVariable(
- *module_, initializer->getType(),
- /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, initializer,
- /*Name=*/"");
- global_for_const->setAlignment(kConstantBufferAlignBytes);
- bindings_.BindHloToIrValue(*init_value_operand, global_for_const);
+ FusedIrEmitter fused_emitter(parameter_arrays, &elemental_emitter);
+ TF_RETURN_IF_ERROR(init_value_operand->Accept(&fused_emitter));
+ TF_RETURN_IF_ERROR(
+ ParallelLoopEmitter(fused_emitter.GetGenerator(init_value_operand),
+ GetIrArray(*hlo, *hlo, index), launch_dimensions,
+ &b_)
+ .EmitLoop(IrName(hlo)));
+ } else {
+ // In the unfused case the element is already there, just read from it.
+ TF_RETURN_IF_ERROR(ParallelLoopEmitter(
+ [=](const IrArray::Index& index) {
+ return GetIrArray(*init_value, *hlo)
+ .EmitReadArrayElement(index, &b_);
+ },
+ GetIrArray(*hlo, *hlo, index), launch_dimensions,
+ &b_)
+ .EmitLoop(IrName(hlo)));
}
- TF_RETURN_IF_ERROR(ParallelLoopEmitter(
- [=](const IrArray::Index& index) {
- return GetIrArray(*init_value, *hlo)
- .EmitReadArrayElement(index, &b_);
- },
- GetIrArray(*hlo, *hlo, index), launch_dimensions, &b_)
- .EmitLoop(IrName(hlo)));
// Clean up state left behind by emitting the loop above. (This is normally
// done in IrEmitterUnnested::Postprocess().)
@@ -2819,10 +2826,7 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk(
}
// For multioutput fusion, we need to emit each operand and the root.
- std::vector<IrArray> output_arrays;
- for (int64 i = 0; i < ShapeUtil::TupleElementCount(hlo.shape()); ++i) {
- output_arrays.push_back(GetIrArray(hlo, hlo, {i}));
- }
+ std::vector<IrArray> output_arrays = ConstructIrArrayForOutputs(hlo);
TF_RETURN_IF_ERROR(
ParallelLoopEmitter(element_generator, output_arrays, launch_dimensions,
&b_, unroll_factor)
@@ -2830,12 +2834,9 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk(
GetIndexTypeForKernel(
&hlo, launch_dimensions.launch_bound(), &b_)));
- std::vector<llvm::Value*> tuple_operand_ptrs;
- for (int64 i = 0; i < output_arrays.size(); ++i) {
- tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer());
- }
b_.SetInsertPoint(b_.GetInsertBlock()->getTerminator());
- llvm_ir::EmitTuple(GetIrArray(hlo, hlo), tuple_operand_ptrs, &b_, module_);
+ llvm_ir::EmitTuple(GetIrArray(hlo, hlo), output_arrays, &b_, module_);
+
return Status::OK();
}
@@ -2847,29 +2848,14 @@ Status IrEmitterUnnested::EmitTargetElementLoop(
static_cast<KernelThunk*>(LastThunk()));
}
-int IrEmitterUnnested::ConstructIrArrayForOutputs(
- const HloInstruction& hlo, std::vector<IrArray>* output_arrays) {
- int64 num_outputs = 1;
- if (hlo.IsMultiOutputFusion()) {
- num_outputs = ShapeUtil::TupleElementCount(hlo.shape());
- output_arrays->reserve(num_outputs);
- for (int64 i = 0; i < num_outputs; ++i) {
- output_arrays->push_back(GetIrArray(hlo, hlo, {i}));
- }
- } else {
- output_arrays->push_back(GetIrArray(hlo, hlo));
- }
- return num_outputs;
-}
-
-int IrEmitterUnnested::ConstructIrArrayForInputs(
- const HloInstruction& hlo, std::vector<IrArray>* param_arrays) {
- int64 num_params = hlo.operands().size();
- param_arrays->reserve(num_params);
+std::vector<IrArray> IrEmitterUnnested::ConstructIrArrayForInputs(
+ const HloInstruction& hlo) {
+ std::vector<IrArray> param_arrays;
+ param_arrays.reserve(hlo.operands().size());
for (const HloInstruction* param : hlo.operands()) {
- param_arrays->push_back(GetIrArray(*param, hlo));
+ param_arrays.push_back(GetIrArray(*param, hlo));
}
- return num_params;
+ return param_arrays;
}
int IrEmitterUnnested::ConstructOutputReducedShapeAndCastOutputIrArrayToShape(
@@ -3050,10 +3036,10 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile(
constexpr int64 kThreadsPerTile = kTileSize * kNumRows;
// Construct IrArrays for the inputs and outputs.
- std::vector<IrArray> output_arrays;
- int64 num_outputs = ConstructIrArrayForOutputs(*hlo, &output_arrays);
- std::vector<IrArray> param_arrays;
- int64 num_params = ConstructIrArrayForInputs(*hlo, &param_arrays);
+ std::vector<IrArray> output_arrays = ConstructIrArrayForOutputs(*hlo);
+ int64 num_outputs = output_arrays.size();
+ std::vector<IrArray> param_arrays = ConstructIrArrayForInputs(*hlo);
+ int64 num_params = param_arrays.size();
// Allocate shared memory buffers to store the tiled inputs.
std::vector<llvm::Value*> param_shmem_buffers(num_params, nullptr);
@@ -3251,12 +3237,7 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile(
// For multioutput fusion, emit a tuple with all the individual outputs.
if (hlo->IsMultiOutputFusion()) {
- std::vector<llvm::Value*> tuple_operand_ptrs;
- for (int64 i = 0; i < output_arrays.size(); ++i) {
- tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer());
- }
- llvm_ir::EmitTuple(GetIrArray(*hlo, *hlo), tuple_operand_ptrs, &b_,
- module_);
+ llvm_ir::EmitTuple(GetIrArray(*hlo, *hlo), output_arrays, &b_, module_);
}
return launch_dimensions;
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
index 084462330e..bd5db72051 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
@@ -193,14 +193,12 @@ class IrEmitterUnnested : public IrEmitter {
LaunchDimensions EmitHlo021Tile(HloInstruction* hlo,
absl::Span<const int64> reduced_output_dims,
absl::Span<const int64> tiled_param_ids);
- // Generates the IrArray for each output of hlo and returns the number of
- // outputs.
- int ConstructIrArrayForOutputs(const HloInstruction& hlo,
- std::vector<llvm_ir::IrArray>* output_arrays);
- // Generates the IrArray for each input of hlo and returns the number of
- // inputs.
- int ConstructIrArrayForInputs(const HloInstruction& hlo,
- std::vector<llvm_ir::IrArray>* param_arrays);
+
+ // Generates the IrArray for each input of an hlo and returns a vector that
+ // contains such IrArrays.
+ std::vector<llvm_ir::IrArray> ConstructIrArrayForInputs(
+ const HloInstruction& hlo);
+
// For each output of the `hlo` instruction, constructs the reduced shape for
// the output with the given `reduced_output_dims` and cast the original
// output IrArray element in `output_arrays` to the reduced shape. Returns
@@ -244,7 +242,7 @@ class IrEmitterUnnested : public IrEmitter {
// Returns a thunk that, given a reduce or select-and-scatter op, initializes
// its memory to the appropriate initial value.
StatusOr<std::unique_ptr<Thunk>> BuildInitializerThunk(
- const HloInstruction* hlo, const ShapeIndex& index = {});
+ HloInstruction* hlo, const ShapeIndex& index = {});
// Returns a thunk that calls host-to-device cuMemcpy to implement `inst`.
std::unique_ptr<Thunk> BuildHostToDeviceCopyThunk(const HloInstruction* inst);
diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc
index 091aca23e5..8f0dedfa40 100644
--- a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/types.h"
namespace xla {
@@ -49,10 +50,10 @@ TEST_F(StreamAssignmentTest, SequentialMatMul) {
/*parameter_number=*/1, f32_2x2_, /*name=*/"y"));
HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter(
/*parameter_number=*/2, f32_2x2_, /*name=*/"z"));
- HloInstruction* dot1 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, x, y));
- HloInstruction* dot2 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, dot1, z));
+ HloInstruction* dot1 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, x, y));
+ HloInstruction* dot2 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, dot1, z));
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build(dot2));
@@ -68,10 +69,10 @@ TEST_F(StreamAssignmentTest, ConcurrentMatMul) {
/*parameter_number=*/0, f32_2x2_, /*name=*/"x"));
HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
/*parameter_number=*/1, f32_2x2_, /*name=*/"y"));
- HloInstruction* dot1 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, x, y));
- HloInstruction* dot2 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, y, x));
+ HloInstruction* dot1 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, x, y));
+ HloInstruction* dot2 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, y, x));
HloInstruction* add = builder.AddInstruction(
HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, dot1, dot2));
@@ -101,23 +102,23 @@ TEST_F(StreamAssignmentTest, LatticeMatMul) {
i, f32_2x2_, /*name=*/absl::StrFormat("param%d", i))));
}
HloInstruction* d00 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, params[2], params[3]));
- HloInstruction* d10 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, params[1], d00));
- HloInstruction* d11 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, d00, params[4]));
- HloInstruction* d20 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, params[0], d10));
- HloInstruction* d21 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, d10, d11));
- HloInstruction* d22 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, d11, params[5]));
- HloInstruction* d30 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, d20, d21));
- HloInstruction* d31 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, d21, d22));
- HloInstruction* d40 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, d30, d31));
+ CreateCanonicalDot(f32_2x2_, params[2], params[3]));
+ HloInstruction* d10 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, params[1], d00));
+ HloInstruction* d11 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d00, params[4]));
+ HloInstruction* d20 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, params[0], d10));
+ HloInstruction* d21 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d10, d11));
+ HloInstruction* d22 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d11, params[5]));
+ HloInstruction* d30 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d20, d21));
+ HloInstruction* d31 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d21, d22));
+ HloInstruction* d40 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d30, d31));
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build(d40));
diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto
index 99d0cf50ca..93ec2c9438 100644
--- a/tensorflow/compiler/xla/service/hlo.proto
+++ b/tensorflow/compiler/xla/service/hlo.proto
@@ -199,6 +199,17 @@ message HloComputationProto {
int64 root_id = 6;
}
+// Serialization of an HLO schedule. An HLO schedule contains a total order of
+// instructions for each non-fusion computation in the module.
+message HloScheduleProto {
+ message InstructionSequence {
+ repeated int64 instruction_ids = 1;
+ }
+
+ // Map from computation id to sequence.
+ map<int64, InstructionSequence> sequences = 1;
+}
+
// Serialization of HloModule.
message HloModuleProto {
string name = 1;
@@ -214,16 +225,9 @@ message HloModuleProto {
// The id of this module.
int64 id = 5;
-}
-// Serialization of HloOrdering.
-message HloOrderingProto {
- // NOTE: currently only sequential orderings are serialized.
- message SequentialComputation {
- string computation_name = 1;
- repeated string instruction_names = 2;
- }
- repeated SequentialComputation sequential_computations = 1;
+ // The schedule for this module.
+ HloScheduleProto schedule = 7;
}
// Serialization of LogicalBuffer.
@@ -322,8 +326,10 @@ message BufferAssignmentProto {
// Grouping message that contains all of the information above.
message HloProto {
+ reserved 2;
+ reserved "hlo_ordering";
+
HloModuleProto hlo_module = 1;
- HloOrderingProto hlo_ordering = 2;
BufferAssignmentProto buffer_assignment = 3;
}
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index fe7f2be888..233d2199d1 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -464,6 +464,14 @@ std::vector<HloComputation*> HloComputation::MakeEmbeddedComputationsList()
}
string HloComputation::ToString(const HloPrintOptions& options) const {
+ return ToString(options, MakeInstructionPostOrder());
+}
+
+string HloComputation::ToString(
+ const HloPrintOptions& options,
+ absl::Span<const HloInstruction* const> instruction_order) const {
+ CHECK_EQ(instruction_order.size(), instruction_count());
+
std::ostringstream s;
for (int i = 0; i < options.indent_amount(); i++) {
s << " ";
@@ -486,7 +494,9 @@ string HloComputation::ToString(const HloPrintOptions& options) const {
new_options.set_indent_amount(options.indent_amount() + 1)
.set_is_in_nested_computation(true);
CanonicalNameMap name_map;
- for (const HloInstruction* instruction : MakeInstructionPostOrder()) {
+ for (const HloInstruction* instruction : instruction_order) {
+ CHECK_EQ(this, instruction->parent());
+
for (int i = 0; i < new_options.indent_amount(); i++) {
s << " ";
}
diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h
index fe2d3bbbe5..91c5234a6f 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.h
+++ b/tensorflow/compiler/xla/service/hlo_computation.h
@@ -170,6 +170,11 @@ class HloComputation {
string ToString() const { return ToString(HloPrintOptions()); }
string ToString(const HloPrintOptions& options) const;
+ // Overload which accepts an order to emit the instructions in.
+ string ToString(
+ const HloPrintOptions& options,
+ absl::Span<const HloInstruction* const> instruction_order) const;
+
// Returns a serialized representation of this computation.
HloComputationProto ToProto() const;
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
index 939b5114c3..a502fff9a0 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
@@ -227,6 +227,14 @@ Status HloCostAnalysis::HandleCopy(const HloInstruction*) {
return Status::OK();
}
+Status HloCostAnalysis::HandleDomain(const HloInstruction* domain) {
+ // Domain does not have any computation or data transfer.
+ current_should_compute_bottleneck_time_ = false;
+ current_properties_[kBytesAccessedKey] = 0;
+ current_properties_[kOptimalSecondsKey] = 0;
+ return Status::OK();
+}
+
Status HloCostAnalysis::HandleDot(const HloInstruction* dot) {
const Shape& lhs_shape = dot->operand(0)->shape();
const Shape& rhs_shape = dot->operand(1)->shape();
@@ -507,8 +515,9 @@ Status HloCostAnalysis::HandleConvolution(const HloInstruction* convolution) {
valid_position_counts.push_back(valid_position_count);
}
- const int64 fma_count =
- input_feature * output_feature * batch * Product(valid_position_counts);
+ const int64 fma_count = (input_feature / convolution->feature_group_count()) *
+ output_feature * batch *
+ Product(valid_position_counts);
current_properties_[kFlopsKey] = fma_count * kFmaFlops;
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h
index 9bb3f12ee2..46b4bbeef2 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h
@@ -67,6 +67,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor {
Status HandleRecvDone(const HloInstruction* recv_done) override;
Status HandleConvert(const HloInstruction* convert) override;
Status HandleCopy(const HloInstruction* copy) override;
+ Status HandleDomain(const HloInstruction* domain) override;
Status HandleDot(const HloInstruction* dot) override;
Status HandleConvolution(const HloInstruction* convolution) override;
Status HandleFft(const HloInstruction* fft) override;
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc
index 2c854eea18..d76ce9ecbc 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc
@@ -203,6 +203,35 @@ TEST_F(HloCostAnalysisTest, Convolution) {
sizeof(float) * (10 * 20 + 3 * 3 + 8 * 18));
}
+TEST_F(HloCostAnalysisTest, ConvolutionWithFeatureGroup) {
+ XlaBuilder builder("convolution");
+ auto input = Parameter(
+ &builder, 0,
+ ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/120, /*y_dim=*/10,
+ /*x_dim=*/20}),
+ "input");
+ auto kernel = Parameter(
+ &builder, 1,
+ ShapeUtil::MakeShape(F32, {/*p_dim=*/120, /*z_dim=*/1, /*y_dim=*/3,
+ /*x_dim=*/3}),
+ "kernel");
+ Conv(input, kernel, {1, 1}, Padding::kValid, /*feature_group_count=*/120);
+
+ // Run HLO cost analysis.
+ auto hlo_module = BuildHloGraph(&builder);
+ HloCostAnalysis analysis(ShapeSize);
+ ASSERT_IS_OK(
+ hlo_module->entry_computation()->root_instruction()->Accept(&analysis));
+
+ // Output shape is [1x120x8x18] and each output element requires (3x3)
+ // FMAs and one FMA is 2 flops.
+ EXPECT_EQ(analysis.flop_count(), 120 * 8 * 18 * 2 * 3 * 3);
+
+ // Bytes accessed is sum of inputs and output.
+ EXPECT_EQ(analysis.bytes_accessed(),
+ sizeof(float) * (120 * 10 * 20 + 120 * 3 * 3 + 120 * 8 * 18));
+}
+
TEST_F(HloCostAnalysisTest, Reduce) {
XlaBuilder builder("reduce");
auto input =
@@ -415,7 +444,7 @@ TEST_F(FusionCostAnalysis, NoLayout) {
TEST_F(HloCostAnalysisTest, TupleCost) {
HloCostAnalysis analysis(ShapeSize);
{
- XlaBuilder builder("matmul");
+ XlaBuilder builder("tuple");
auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {123}), "x");
auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {42}), "y");
Tuple(&builder, {x, y});
@@ -430,6 +459,30 @@ TEST_F(HloCostAnalysisTest, TupleCost) {
EXPECT_EQ(analysis.bytes_accessed(), kPointerSize * 2);
}
+using DomainCostAnalysis = HloTestBase;
+TEST_F(DomainCostAnalysis, DomainCost) {
+ HloCostAnalysis analysis(ShapeSize);
+
+ HloComputation::Builder builder("domain");
+ auto x = builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {123}), "x"));
+ auto y = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(F32, {42}), "y"));
+ auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({x, y}));
+ auto domain = builder.AddInstruction(
+ HloInstruction::CreateDomain(tuple->shape(), tuple, nullptr, nullptr));
+
+ auto hlo_module = CreateNewModule();
+ hlo_module->AddEntryComputation(builder.Build());
+
+ EXPECT_EQ(hlo_module->entry_computation()->root_instruction(), domain);
+ ASSERT_IS_OK(domain->Accept(&analysis));
+
+ EXPECT_EQ(analysis.flop_count(*domain), 0);
+ EXPECT_EQ(analysis.transcendental_count(*domain), 0);
+ EXPECT_EQ(analysis.bytes_accessed(*domain), 0);
+}
+
TEST_F(HloCostAnalysisTest, BaseDilatedConvolution) {
XlaBuilder builder("BaseDilatedConvolution");
auto input = Parameter(
diff --git a/tensorflow/compiler/xla/service/hlo_cse_test.cc b/tensorflow/compiler/xla/service/hlo_cse_test.cc
index 406d712ec6..e09d5868f2 100644
--- a/tensorflow/compiler/xla/service/hlo_cse_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_cse_test.cc
@@ -29,7 +29,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/util.h"
@@ -44,7 +44,7 @@ namespace op = xla::testing::opcode_matchers;
namespace xla {
namespace {
-class HloCseTest : public HloTestBase {
+class HloCseTest : public HloVerifiedTestBase {
protected:
HloCseTest() {}
};
@@ -65,13 +65,13 @@ TEST_F(HloCseTest, CombineTwoConstants) {
EXPECT_EQ(3, computation->instruction_count());
HloCSE cse(/*is_layout_sensitive=*/false);
- EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());
+ EXPECT_TRUE(cse.Run(module).ValueOrDie());
EXPECT_EQ(2, computation->instruction_count());
HloInstruction* constant = *computation->instructions().begin();
EXPECT_EQ(42.0f, constant->literal().Get<float>({}));
- auto result = ExecuteAndTransfer(std::move(module), {});
+ auto result = ExecuteAndTransfer(module->Clone(), {});
auto expected = LiteralUtil::CreateR0<float>(84.0);
EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4)));
}
@@ -96,14 +96,14 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) {
EXPECT_THAT(add, op::Add(constant1, constant2));
HloCSE cse(/*is_layout_sensitive=*/false);
- EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());
+ EXPECT_TRUE(cse.Run(module).ValueOrDie());
EXPECT_EQ(2, computation->instruction_count());
auto first_operand = add->operand(0);
EXPECT_THAT(first_operand, ::testing::AnyOf(constant1, constant2));
EXPECT_THAT(add, op::Add(first_operand, first_operand));
- auto result = ExecuteAndTransfer(std::move(module), {});
+ auto result = ExecuteAndTransfer(module->Clone(), {});
auto expected = LiteralUtil::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}});
EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4)));
}
@@ -128,12 +128,12 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) {
EXPECT_THAT(add, op::Add(constant1, constant2));
HloCSE cse(/*is_layout_sensitive=*/true);
- EXPECT_FALSE(cse.Run(module.get()).ValueOrDie());
+ EXPECT_FALSE(cse.Run(module).ValueOrDie());
EXPECT_EQ(3, computation->instruction_count());
EXPECT_THAT(add, op::Add(constant1, constant2));
- auto result = ExecuteAndTransfer(std::move(module), {});
+ auto result = ExecuteAndTransfer(module->Clone(), {});
auto expected = LiteralUtil::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}});
EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4)));
}
@@ -177,7 +177,7 @@ TEST_F(HloCseTest, ConstantsSameValueDifferentType) {
EXPECT_EQ(20, computation->instruction_count());
HloCSE cse(/*is_layout_sensitive=*/false);
- EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());
+ EXPECT_TRUE(cse.Run(module).ValueOrDie());
// CSE will remove both the second float(42.0f) and the corresponding
// convert/cast.
@@ -209,7 +209,7 @@ TEST_F(HloCseTest, NonscalarConstants) {
op::Tuple(common_constant1, common_constant2, uncommon_constant));
HloCSE cse(/*is_layout_sensitive=*/false);
- EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());
+ EXPECT_TRUE(cse.Run(module).ValueOrDie());
EXPECT_EQ(3, computation->instruction_count());
auto first_operand = tuple->operand(0);
@@ -240,7 +240,7 @@ TEST_F(HloCseTest, IdenticalInstructions) {
EXPECT_THAT(tuple, op::Tuple(exp1, exp2, exp3));
HloCSE cse(/*is_layout_sensitive=*/true);
- EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());
+ EXPECT_TRUE(cse.Run(module).ValueOrDie());
EXPECT_EQ(3, computation->instruction_count());
auto first_operand = tuple->operand(0);
@@ -250,7 +250,7 @@ TEST_F(HloCseTest, IdenticalInstructions) {
// Test two identical while loops with same inputs
TEST_F(HloCseTest, WhileLoopsIdenticalConditionsAndBodiesSameInput) {
- auto module = ParseHloString(R"(
+ ParseAndVerifyModule(R"(
HloModule WhileLoopsIdenticalConditionsAndBodiesSameInput
%body (param: (f32[], f32[])) -> (f32[], f32[]) {
@@ -278,21 +278,20 @@ f32[]) while((f32[], f32[]) %tuple.1), condition=%condition, body=%body ROOT
%while.1 = (f32[], f32[]) while((f32[], f32[]) %tuple.1),
condition=%condition.1, body=%body
}
- )")
- .ValueOrDie();
+ )");
- auto computation = module->entry_computation();
+ auto computation = module().entry_computation();
EXPECT_EQ(5, computation->instruction_count());
HloCSE cse(true);
- EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());
+ EXPECT_TRUE(cse.Run(&module()).ValueOrDie());
EXPECT_EQ(4, computation->instruction_count());
}
// Test two while loops with same conditions, same inputs, but different
// bodies
TEST_F(HloCseTest, WhileLoopsIdenticalConditionsSameInputAndDifferentBodies) {
- auto module = ParseHloString(R"(
+ ParseAndVerifyModule(R"(
HloModule WhileLoopsIdenticalConditionsSameInputAndDifferentBodies
%body (param: (f32[], f32[])) -> (f32[], f32[]) {
@@ -329,20 +328,19 @@ index=1 %sub = f32[] subtract(f32[] %get-tuple-element.2, f32[]
condition=%condition, body=%body ROOT %while.1 = (f32[], f32[]) while((f32[],
f32[]) %tuple.1), condition=%condition.1, body=%body2
}
- )")
- .ValueOrDie();
+ )");
- auto computation = module->entry_computation();
+ auto computation = module().entry_computation();
EXPECT_EQ(5, computation->instruction_count());
HloCSE cse(true);
- EXPECT_FALSE(cse.Run(module.get()).ValueOrDie());
+ EXPECT_FALSE(cse.Run(&module()).ValueOrDie());
EXPECT_EQ(5, computation->instruction_count());
}
// Test two identical while loops with different inputs
TEST_F(HloCseTest, WhileLoopsIdenticalConditionsAndBodiesDifferentInput) {
- auto module = ParseHloString(R"(
+ ParseAndVerifyModule(R"(
HloModule WhileLoopsIdenticalConditionsAndBodiesDifferentInput
%body (param: (f32[], f32[])) -> (f32[], f32[]) {
@@ -373,21 +371,20 @@ f32[] constant(2) %tuple.2 = (f32[], f32[]) tuple(f32[] %constant.4, f32[]
condition=%condition.1, body=%body
}
- )")
- .ValueOrDie();
+ )");
- auto computation = module->entry_computation();
+ auto computation = module().entry_computation();
EXPECT_EQ(8, computation->instruction_count());
HloCSE cse(true);
- EXPECT_FALSE(cse.Run(module.get()).ValueOrDie());
+ EXPECT_FALSE(cse.Run(&module()).ValueOrDie());
EXPECT_EQ(8, computation->instruction_count());
}
// Test two while loops with identical bodies and same inputs, but different
// conditions
TEST_F(HloCseTest, WhileLoopsIdenticalBodiesAndInputDifferntConditions) {
- auto module = ParseHloString(R"(
+ ParseAndVerifyModule(R"(
HloModule WhileLoopsIdenticalBodiesAndInputDifferntConditions
%body (param: (f32[], f32[])) -> (f32[], f32[]) {
@@ -414,14 +411,13 @@ f32[]) { %constant.2 = f32[] constant(1) %constant.3 = f32[] constant(2)
%while = (f32[], f32[]) while((f32[], f32[]) %tuple.1),
condition=%condition, body=%body ROOT %while.1 = (f32[], f32[]) while((f32[],
f32[]) %tuple.1), condition=%condition.1, body=%body
- })")
- .ValueOrDie();
+ })");
- auto computation = module->entry_computation();
+ auto computation = module().entry_computation();
EXPECT_EQ(5, computation->instruction_count());
HloCSE cse(true);
- EXPECT_FALSE(cse.Run(module.get()).ValueOrDie());
+ EXPECT_FALSE(cse.Run(&module()).ValueOrDie());
EXPECT_EQ(5, computation->instruction_count());
}
@@ -450,7 +446,7 @@ TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsSensitive) {
EXPECT_THAT(tuple, op::Tuple(exp1, exp2));
HloCSE cse(/*is_layout_sensitive=*/true);
- EXPECT_FALSE(cse.Run(module.get()).ValueOrDie());
+ EXPECT_FALSE(cse.Run(module).ValueOrDie());
EXPECT_EQ(4, computation->instruction_count());
EXPECT_THAT(tuple, op::Tuple(exp1, exp2));
@@ -481,7 +477,7 @@ TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsInsensitive) {
EXPECT_THAT(tuple, op::Tuple(exp1, exp2));
HloCSE cse(/*is_layout_sensitive=*/false);
- EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());
+ EXPECT_TRUE(cse.Run(module).ValueOrDie());
EXPECT_EQ(3, computation->instruction_count());
auto first_operand = tuple->operand(0);
@@ -516,7 +512,7 @@ TEST_F(HloCseTest, FusionInternalCSE) {
EXPECT_EQ(5, fused_computation->instruction_count());
HloCSE cse(/*is_layout_sensitive=*/false);
- EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());
+ EXPECT_TRUE(cse.Run(module).ValueOrDie());
EXPECT_EQ(4, fused_computation->instruction_count());
auto root = fused_computation->root_instruction();
@@ -565,7 +561,7 @@ TEST_F(HloCseTest, IdenticalExpressions) {
EXPECT_THAT(tuple, op::Tuple(op::Add(negate1, exp1), op::Add(negate2, exp2)));
HloCSE cse(/*is_layout_sensitive=*/false);
- EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());
+ EXPECT_TRUE(cse.Run(module).ValueOrDie());
EXPECT_EQ(5, computation->instruction_count());
auto operand = tuple->operand(0);
@@ -599,7 +595,7 @@ TEST_F(HloCseTest, DoNotCombineRng) {
uint32 count_before = computation->instruction_count();
HloCSE cse(/*is_layout_sensitive=*/false);
- EXPECT_FALSE(cse.Run(module.get()).ValueOrDie());
+ EXPECT_FALSE(cse.Run(module).ValueOrDie());
uint32 count_after = computation->instruction_count();
EXPECT_EQ(count_before, count_after);
@@ -653,7 +649,7 @@ TEST_F(HloCseTest, DoNotCombineCallsToImpureFunctions) {
VLOG(3) << "before: " << module->ToString();
HloCSE cse(/*is_layout_sensitive=*/false);
- EXPECT_FALSE(cse.Run(module.get()).ValueOrDie());
+ EXPECT_FALSE(cse.Run(module).ValueOrDie());
VLOG(3) << "after: " << module->ToString();
@@ -663,7 +659,7 @@ TEST_F(HloCseTest, DoNotCombineCallsToImpureFunctions) {
}
TEST_F(HloCseTest, CompareComputations) {
- auto module = ParseHloString(R"(
+ ParseAndVerifyModule(R"(
HloModule m
add_computation {
@@ -684,12 +680,11 @@ TEST_F(HloCseTest, CompareComputations) {
r1 = f32[] reduce(p, c), dimensions={0}, to_apply=add_computation
r2 = f32[] reduce(p, c), dimensions={0}, to_apply=add_computation2
ROOT f2 = (f32[],f32[]) tuple(r1, r2)
- })")
- .ValueOrDie();
+ })");
HloCSE cse(/*is_layout_sensitive=*/false);
- EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());
- HloInstruction* root = module->entry_computation()->root_instruction();
+ EXPECT_TRUE(cse.Run(&module()).ValueOrDie());
+ HloInstruction* root = module().entry_computation()->root_instruction();
EXPECT_EQ(root->operand(0), root->operand(1));
}
@@ -708,13 +703,13 @@ TEST_F(HloCseTest, ConstantsSameValueInDifferentDomains) {
EXPECT_EQ(2, computation->instruction_count());
HloCSE cse(/*is_layout_sensitive=*/false);
- EXPECT_FALSE(cse.Run(module.get()).ValueOrDie());
+ EXPECT_FALSE(cse.Run(module).ValueOrDie());
EXPECT_EQ(2, computation->instruction_count());
}
TEST_F(HloCseTest, Domain) {
- auto module = ParseHloString(R"(
+ ParseAndVerifyModule(R"(
HloModule module
ENTRY %entry {
%param = f32[] parameter(0), sharding={maximal device=0}
@@ -735,13 +730,11 @@ ENTRY %entry {
domain={kind="sharding", entry={maximal device=2}, exit={maximal device=0}}
%add = f32[] add(%domain.3, %domain.4)
ROOT %sub = f32[] subtract(%add, %domain.5)
-})")
- .ValueOrDie();
+})");
HloCSE cse(/*is_layout_sensitive=*/false);
- EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());
- LOG(INFO) << "AAAAA " << module->ToString();
- const HloInstruction* sub = module->entry_computation()->root_instruction();
+ EXPECT_TRUE(cse.Run(&module()).ValueOrDie());
+ const HloInstruction* sub = module().entry_computation()->root_instruction();
const HloInstruction* add = sub->operand(0);
EXPECT_EQ(add->operand(0), add->operand(1));
EXPECT_NE(add->operand(0), sub->operand(1));
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
index abd4bb1f73..102ebb24ab 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
@@ -52,10 +52,7 @@ static std::array<bool, 2> use_bf16_params{true, false};
class HloEvaluatorTest : public ::testing::WithParamInterface<bool>,
public HloVerifiedTestBase {
protected:
- HloEvaluatorTest()
- : HloVerifiedTestBase(/*layout_sensitive=*/false,
- /*allow_mixed_precision=*/false),
- use_bfloat16_(GetParam()) {
+ HloEvaluatorTest() : HloVerifiedTestBase(), use_bfloat16_(GetParam()) {
evaluator_ = absl::make_unique<HloEvaluator>();
}
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
index 6a09bb08f4..63303aef1e 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
@@ -1052,7 +1052,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
auto func = [&window_shape, &dnums, &lhs_shape, &rhs_shape, &window,
&lhs_dim_multipliers, &rhs_dim_multipliers, lhs_literal_data,
rhs_literal_data,
- feature_group_count](absl::Span<const int64> out_index) {
+ feature_group_count](const absl::Span<const int64> out_index) {
// Dimension number applicable for input (lhs).
const int64 input_batch_dim = dnums.input_batch_dimension();
const int64 input_z_dim = dnums.input_feature_dimension();
@@ -1063,9 +1063,22 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
const int64 output_batch_dim = dnums.output_batch_dimension();
const int64 output_z_dim = dnums.output_feature_dimension();
- const int64 z_size = ShapeUtil::GetDimension(lhs_shape, input_z_dim);
+ const int64 input_z_size =
+ ShapeUtil::GetDimension(lhs_shape, input_z_dim);
+ // The size of an input feature group.
+ const int64 input_feature_group_size = input_z_size / feature_group_count;
+
const int64 output_z_size =
ShapeUtil::GetDimension(rhs_shape, kernel_output_z_dim);
+ // The output feature dimension is a concatenation of convolution results
+ // from the different groups.
+ const int64 output_feature_group_size =
+ output_z_size / feature_group_count;
+
+ // Calculate the group index to which the current output index
+ // belongs.
+ const int64 feature_group_index =
+ out_index[output_z_dim] / output_feature_group_size;
ElementwiseT result_val = static_cast<ElementwiseT>(0);
DimensionVector rhs_spatial_index(dnums.kernel_spatial_dimensions_size(),
@@ -1073,33 +1086,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
// Convolve input feature with kernel.
do {
- for (int64 iz = 0; iz < z_size; ++iz) {
- int64 rhs_iz = iz;
- // Handle grouped convolutions.
- if (feature_group_count > 1) {
- // The size of a feature group.
- int64 feature_group_size = z_size / feature_group_count;
- rhs_iz = iz % feature_group_size;
-
- // The output feature dimension is a concatenation of convolution
- // results from the different groups.
- int64 output_feature_group_size =
- output_z_size / feature_group_count;
-
- // Calculate the group index to which the current input feature
- // index belongs.
- int64 input_group_index = iz / feature_group_size;
-
- // Calculate the group index to which the current output index
- // belongs.
- int64 output_group_index =
- out_index[output_z_dim] / output_feature_group_size;
- if (input_group_index != output_group_index) {
- // If the current output index does not belong to the current
- // feature group, skip it.
- continue;
- }
- }
+ for (int64 rhs_iz = 0; rhs_iz < input_feature_group_size; ++rhs_iz) {
+ const int64 iz =
+ feature_group_index * input_feature_group_size + rhs_iz;
int64 lhs_linear_index = 0;
lhs_linear_index += out_index[output_batch_dim] *
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 471a12d6aa..25ae344ea5 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -451,6 +451,28 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
<< proto.dimensions_size();
instruction = CreateIota(proto.shape(), proto.dimensions(0));
break;
+ case HloOpcode::kDot: {
+ TF_RET_CHECK(proto.has_dot_dimension_numbers())
+ << "Dot instruction should have dot_dimension_numbers.";
+ TF_RET_CHECK(proto.operand_ids_size() == 2)
+ << "Dot instruction should have 2 operands but sees "
+ << proto.operand_ids_size();
+ PrecisionConfig precision_config = proto.precision_config();
+ precision_config.mutable_operand_precision()->Resize(
+ proto.operand_ids_size(), PrecisionConfig::DEFAULT);
+ instruction = absl::make_unique<HloDotInstruction>(
+ proto.shape(), operands(0), operands(1),
+ proto.dot_dimension_numbers(), precision_config);
+ break;
+ }
+ case HloOpcode::kDomain:
+ TF_RET_CHECK(proto.operand_ids_size() == 1)
+ << "Domain instruction should have 1 operand but sees "
+ << proto.operand_ids_size();
+ instruction = absl::make_unique<HloDomainInstruction>(
+ proto.shape(), operands(0), /*operand_side_metadata=*/nullptr,
+ /*user_side_metadata=*/nullptr);
+ break;
default: {
instruction = absl::WrapUnique(new HloInstruction(opcode, proto.shape()));
for (const int64 operand_id : proto.operand_ids()) {
@@ -472,20 +494,9 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
computation_map.at(computation_id));
}
}
- if (instruction->opcode() == HloOpcode::kDot) {
- instruction->precision_config_ = proto.precision_config();
- instruction->precision_config_.mutable_operand_precision()->Resize(
- instruction->operand_count(), PrecisionConfig::DEFAULT);
- TF_RET_CHECK(proto.has_dot_dimension_numbers());
- instruction->dot_dimension_numbers_ =
- absl::make_unique<DotDimensionNumbers>(
- proto.dot_dimension_numbers());
- } else {
- TF_RET_CHECK(!proto.has_precision_config())
- << instruction->opcode() << proto.DebugString();
- TF_RET_CHECK(!proto.has_dot_dimension_numbers())
- << instruction->opcode();
- }
+ TF_RET_CHECK(!proto.has_precision_config())
+ << instruction->opcode() << proto.DebugString();
+ TF_RET_CHECK(!proto.has_dot_dimension_numbers()) << instruction->opcode();
break;
}
}
@@ -564,7 +575,6 @@ HloInstruction::CreateGetTupleElement(const Shape& shape,
case HloOpcode::kCopy:
case HloOpcode::kCos:
case HloOpcode::kClz:
- case HloOpcode::kDomain:
case HloOpcode::kExp:
case HloOpcode::kExpm1:
case HloOpcode::kFloor:
@@ -596,7 +606,6 @@ HloInstruction::CreateGetTupleElement(const Shape& shape,
case HloOpcode::kAtan2:
case HloOpcode::kDivide:
case HloOpcode::kComplex:
- case HloOpcode::kDot:
case HloOpcode::kEq:
case HloOpcode::kGe:
case HloOpcode::kGt:
@@ -674,30 +683,8 @@ HloInstruction::CreateGetTupleElement(const Shape& shape,
const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
const DotDimensionNumbers& dimension_numbers,
const PrecisionConfig& precision_config) {
- auto instruction =
- absl::WrapUnique(new HloInstruction(HloOpcode::kDot, shape));
- instruction->AppendOperand(lhs);
- instruction->AppendOperand(rhs);
- instruction->dot_dimension_numbers_ =
- absl::make_unique<DotDimensionNumbers>(dimension_numbers);
- instruction->set_precision_config(precision_config);
- return instruction;
-}
-
-/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCanonicalDot(
- const Shape& shape, HloInstruction* lhs, HloInstruction* rhs) {
- CHECK_EQ(ShapeUtil::Rank(lhs->shape()), 2);
- CHECK_EQ(ShapeUtil::Rank(rhs->shape()), 2);
-
- auto instruction =
- absl::WrapUnique(new HloInstruction(HloOpcode::kDot, shape));
- instruction->AppendOperand(lhs);
- instruction->AppendOperand(rhs);
- instruction->dot_dimension_numbers_ =
- absl::make_unique<DotDimensionNumbers>();
- instruction->dot_dimension_numbers_->add_lhs_contracting_dimensions(1);
- instruction->dot_dimension_numbers_->add_rhs_contracting_dimensions(0);
- return instruction;
+ return absl::make_unique<HloDotInstruction>(
+ shape, lhs, rhs, dimension_numbers, precision_config);
}
/* static */ std::unique_ptr<HloInstruction>
@@ -1157,12 +1144,9 @@ bool HloInstruction::HasSideEffect() const {
const Shape& shape, HloInstruction* operand,
std::unique_ptr<DomainMetadata> operand_side_metadata,
std::unique_ptr<DomainMetadata> user_side_metadata) {
- auto instruction =
- absl::WrapUnique(new HloInstruction(HloOpcode::kDomain, shape));
- instruction->operand_side_metadata_ = std::move(operand_side_metadata);
- instruction->user_side_metadata_ = std::move(user_side_metadata);
- instruction->AppendOperand(operand);
- return instruction;
+ return absl::make_unique<HloDomainInstruction>(
+ shape, operand, std::move(operand_side_metadata),
+ std::move(user_side_metadata));
}
std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
@@ -1218,6 +1202,8 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
case HloOpcode::kGather:
case HloOpcode::kScatter:
case HloOpcode::kIota:
+ case HloOpcode::kDot:
+ case HloOpcode::kDomain:
clone = CloneWithNewOperandsImpl(shape, new_operands, context);
break;
// Unary ops.
@@ -1290,11 +1276,6 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
CHECK_EQ(new_operands.size(), 1);
clone = CreateBitcastConvert(shape, new_operands[0]);
break;
- case HloOpcode::kDot:
- CHECK_EQ(new_operands.size(), 2);
- clone = CreateDot(shape, new_operands[0], new_operands[1],
- *dot_dimension_numbers_, precision_config());
- break;
case HloOpcode::kReshape:
CHECK_EQ(new_operands.size(), 1);
clone = CreateReshape(shape, new_operands[0]);
@@ -1319,12 +1300,6 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
true_computation(), new_operands[2],
false_computation());
break;
- case HloOpcode::kDomain:
- CHECK_EQ(new_operands.size(), 1);
- clone =
- CreateDomain(shape, new_operands[0], operand_side_metadata_->Clone(),
- user_side_metadata_->Clone());
- break;
case HloOpcode::kAfterAll:
if (new_operands.empty()) {
clone = CreateToken();
@@ -1620,11 +1595,6 @@ bool HloInstruction::IdenticalSlowPath(
case HloOpcode::kAfterAll:
return false;
- // Check dot dimension numbers.
- case HloOpcode::kDot:
- return protobuf_util::ProtobufEquals(dot_dimension_numbers(),
- other.dot_dimension_numbers());
-
// Remaining instructions with special values.
case HloOpcode::kCall:
return eq_computations(to_apply(), other.to_apply());
@@ -1640,10 +1610,6 @@ bool HloInstruction::IdenticalSlowPath(
return false;
}
- case HloOpcode::kDomain:
- return operand_side_metadata().Matches(other.operand_side_metadata()) &&
- user_side_metadata().Matches(other.user_side_metadata());
-
// Ops migrated to subclasses should never come to this line.
// TODO(b/80131774): Remove this switch when migration is complete.
case HloOpcode::kBatchNormTraining:
@@ -1683,6 +1649,8 @@ bool HloInstruction::IdenticalSlowPath(
case HloOpcode::kDynamicSlice:
case HloOpcode::kGather:
case HloOpcode::kScatter:
+ case HloOpcode::kDot:
+ case HloOpcode::kDomain:
LOG(FATAL) << "Base class impl called for opcode with subclass: "
<< opcode();
}
@@ -2052,15 +2020,6 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
const HloPrintOptions& options) const {
std::vector<string> extra = ExtraAttributesToStringImpl(options);
- if (dot_dimension_numbers_ != nullptr) {
- extra.push_back(DotDimensionNumbersToString());
- }
-
- string precision_config_string = PrecisionConfigToString();
- if (!precision_config_string.empty()) {
- extra.push_back(precision_config_string);
- }
-
if (options.print_subcomputation_mode() ==
HloPrintOptions::PrintSubcomputationMode::kNameOnly) {
if (opcode() == HloOpcode::kWhile) {
@@ -2146,11 +2105,6 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
}),
"}"));
}
- if (operand_side_metadata_ != nullptr && user_side_metadata_ != nullptr) {
- extra.push_back(StrCat("domain={kind=\"", operand_side_metadata_->Kind(),
- "\", entry=", user_side_metadata_->ToString(),
- ", exit=", operand_side_metadata_->ToString(), "}"));
- }
return extra;
}
@@ -2182,19 +2136,12 @@ HloInstructionProto HloInstruction::ToProto() const {
*proto.mutable_metadata() = metadata_;
proto.set_backend_config(backend_config_);
- if (opcode() == HloOpcode::kConvolution || opcode() == HloOpcode::kDot) {
- *proto.mutable_precision_config() = precision_config_;
- }
if (opcode() != HloOpcode::kFusion) {
for (const HloComputation* computation : called_computations_) {
proto.add_called_computation_ids(computation->unique_id());
}
}
- if (dot_dimension_numbers_ != nullptr) {
- *proto.mutable_dot_dimension_numbers() = *dot_dimension_numbers_;
- }
-
if (has_sharding()) {
*proto.mutable_sharding() = sharding().ToProto();
}
@@ -2921,31 +2868,6 @@ string ConvolutionDimensionNumbersToString(
StrJoin(output_dims, ""));
}
-string HloInstruction::DotDimensionNumbersToString() const {
- std::vector<string> result;
- if (dot_dimension_numbers_ == nullptr) {
- return "";
- }
- const DotDimensionNumbers& dnums = *dot_dimension_numbers_;
- if (!dnums.lhs_batch_dimensions().empty()) {
- result.push_back(StrCat("lhs_batch_dims={",
- StrJoin(dnums.lhs_batch_dimensions(), ","), "}"));
- }
- result.push_back(StrCat("lhs_contracting_dims={",
- StrJoin(dnums.lhs_contracting_dimensions(), ","),
- "}"));
-
- if (!dnums.rhs_batch_dimensions().empty()) {
- result.push_back(StrCat("rhs_batch_dims={",
- StrJoin(dnums.rhs_batch_dimensions(), ","), "}"));
- }
- result.push_back(StrCat("rhs_contracting_dims={",
- StrJoin(dnums.rhs_contracting_dimensions(), ","),
- "}"));
-
- return StrJoin(result, ", ");
-}
-
StatusOr<RandomDistribution> StringToRandomDistribution(const string& name) {
static std::unordered_map<string, RandomDistribution>* map = [] {
static auto* map = new std::unordered_map<string, RandomDistribution>;
@@ -2964,27 +2886,6 @@ StatusOr<RandomDistribution> StringToRandomDistribution(const string& name) {
return found->second;
}
-string HloInstruction::PrecisionConfigToString() const {
- if (absl::c_all_of(
- precision_config_.operand_precision(), [](int32 precision) {
- return static_cast<PrecisionConfig::Precision>(precision) ==
- PrecisionConfig::DEFAULT;
- })) {
- return "";
- }
- return StrCat(
- "operand_precision={",
- StrJoin(
- precision_config_.operand_precision(), ",",
- [](string* out, int32 precision) {
- CHECK(PrecisionConfig::Precision_IsValid(precision)) << precision;
- StrAppend(out,
- PrecisionToString(
- static_cast<PrecisionConfig::Precision>(precision)));
- }),
- "}");
-}
-
StatusOr<PrecisionConfig::Precision> StringToPrecision(const string& name) {
static std::unordered_map<string, PrecisionConfig::Precision>* map = [] {
static auto* map =
@@ -3044,6 +2945,16 @@ Status HloInstruction::set_backend_config(
return ret;
}
+const PrecisionConfig& HloInstruction::precision_config() const {
+ if (auto* convolution = DynCast<HloConvolutionInstruction>(this)) {
+ return convolution->precision_config();
+ }
+ if (auto* dot = DynCast<HloDotInstruction>(this)) {
+ return dot->precision_config();
+ }
+ LOG(FATAL) << "Unimplemented method.";
+}
+
HloModule* HloInstruction::GetModule() const {
if (parent_) {
return parent_->parent();
@@ -3348,4 +3259,15 @@ const ScatterDimensionNumbers& HloInstruction::scatter_dimension_numbers()
return Cast<HloScatterInstruction>(this)->scatter_dimension_numbers();
}
+const DotDimensionNumbers& HloInstruction::dot_dimension_numbers() const {
+ return Cast<HloDotInstruction>(this)->dot_dimension_numbers();
+}
+
+const DomainMetadata& HloInstruction::operand_side_metadata() const {
+ return Cast<HloDomainInstruction>(this)->operand_side_metadata();
+}
+
+const DomainMetadata& HloInstruction::user_side_metadata() const {
+ return Cast<HloDomainInstruction>(this)->user_side_metadata();
+}
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 691f8155f9..5581c17c2d 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -421,12 +421,6 @@ class HloInstruction {
const DotDimensionNumbers& dimension_numbers,
const PrecisionConfig& precision_config);
- // Creates a dot op with operands 'lhs' and 'rhs' that contracts dimension 1
- // of the LHS with dimension 0 of the RHS with no batch dimensions. Both LHS
- // and the RHS must be of rank 2.
- static std::unique_ptr<HloInstruction> CreateCanonicalDot(
- const Shape& shape, HloInstruction* lhs, HloInstruction* rhs);
-
// Creates a reduce-precision op, where operand is the data to reduce in
// precision, and exponent_bits and mantissa_bits describe the precision to
// reduce it to.
@@ -866,11 +860,6 @@ class HloInstruction {
return false;
}
- if (!absl::c_equal(precision_config_.operand_precision(),
- other.precision_config_.operand_precision())) {
- return false;
- }
-
return IdenticalSlowPath(other, eq_computations);
}
@@ -1085,15 +1074,6 @@ class HloInstruction {
return other->has_sharding() ? sharding() == other->sharding() : false;
}
- // Retrieves the operand side metadata of a kDomain instruction.
- const DomainMetadata& operand_side_metadata() const {
- return *operand_side_metadata_;
- }
- // Retrieves the user side metadata of a kDomain instruction.
- const DomainMetadata& user_side_metadata() const {
- return *user_side_metadata_;
- }
-
// When creating a new instruction which either replaces, or shifts up (kCopy
// insertion case), another instruction, we need to make sure the certain
// properties of the new instruction are copied into the derived one. As of
@@ -1101,18 +1081,6 @@ class HloInstruction {
// instruction.
void SetupDerivedInstruction(HloInstruction* derived_instruction) const;
- // Returns data on the dimension numbers used for a dot operation.
- const DotDimensionNumbers& dot_dimension_numbers() const {
- CHECK(dot_dimension_numbers_ != nullptr);
- return *dot_dimension_numbers_;
- }
-
- // Returns the dump string of the dot dimension numbers.
- string DotDimensionNumbersToString() const;
-
- // Returns the dump string of the precision configuration.
- string PrecisionConfigToString() const;
-
// Clones the HLO instruction. The clone will have the same opcode, shape, and
// operands. After creation the clone has no uses. "this" (the instruction
// cloned from) is not changed. Suffix is the string to append to the name of
@@ -1262,10 +1230,8 @@ class HloInstruction {
// information. Transformations to other HLOs will not preserve this
// information but it is presumed that the alternate lowering is strictly
// superior.
- const PrecisionConfig& precision_config() const { return precision_config_; }
- void set_precision_config(const PrecisionConfig& precision_config) {
- precision_config_ = precision_config;
- }
+ // Precondition: opcode must be kConvolution or kDot.
+ const PrecisionConfig& precision_config() const;
// Sets the debug metadata for this instruction.
void set_metadata(const OpMetadata& metadata) { metadata_ = metadata; }
@@ -1508,6 +1474,15 @@ class HloInstruction {
// Delegates to HloScatterInstruction::scatter_dimension_numbers().
const ScatterDimensionNumbers& scatter_dimension_numbers() const;
+ // Delegates to HloDotInstruction::dot_dimension_numbers().
+ const DotDimensionNumbers& dot_dimension_numbers() const;
+
+ // Delegates to HloDomainInstruction::operand_side_metadata().
+ const DomainMetadata& operand_side_metadata() const;
+
+ // Delegates to HloDomainInstruction::user_side_metadata().
+ const DomainMetadata& user_side_metadata() const;
+
// Old methods kept for smooth subclassing transition END.
protected:
@@ -1647,22 +1622,12 @@ class HloInstruction {
// Result shape of this instruction.
Shape shape_;
- // Describes the dimension numbers used for a dot.
- std::unique_ptr<DotDimensionNumbers> dot_dimension_numbers_;
-
- // Used to tag kCopy instructions that are eligible for copy elision.
- bool copy_elision_allowed_ = true;
-
// The sharding, if one exists.
// Uses std::shared_ptr to allow reuse of the same sharding object between
// HloInstructions and other components as HloSharding can be very large for
// many element tuples.
std::shared_ptr<const HloSharding> sharding_;
- // Fields used by the kDomain instruction.
- std::unique_ptr<DomainMetadata> operand_side_metadata_;
- std::unique_ptr<DomainMetadata> user_side_metadata_;
-
// Computations called by this instruction.
std::vector<HloComputation*> called_computations_;
@@ -1676,10 +1641,6 @@ class HloInstruction {
// HLO. See the documentation on backend_config().
string backend_config_;
- // Information used to communicate to the implementation about the algorithm
- // used to produce results. See the documentation on precision_config().
- PrecisionConfig precision_config_;
-
// String identifier for instruction.
string name_;
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index ad87aa1123..fb7345a2ad 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -47,6 +47,27 @@ bool IsInstructionElementwiseOnOperand(const HloInstruction* instruction,
return instruction->IsElementwiseOnOperand(operand_index);
});
}
+
+string PrecisionConfigToString(const PrecisionConfig& precision_config) {
+ if (absl::c_all_of(precision_config.operand_precision(), [](int32 precision) {
+ return static_cast<PrecisionConfig::Precision>(precision) ==
+ PrecisionConfig::DEFAULT;
+ })) {
+ return "";
+ }
+
+ return StrCat(
+ "operand_precision={",
+ StrJoin(
+ precision_config.operand_precision(), ",",
+ [](string* out, int32 precision) {
+ CHECK(PrecisionConfig::Precision_IsValid(precision)) << precision;
+ StrAppend(out,
+ PrecisionToString(
+ static_cast<PrecisionConfig::Precision>(precision)));
+ }),
+ "}");
+}
} // namespace
HloBatchNormInstruction::HloBatchNormInstruction(
@@ -1634,7 +1655,8 @@ HloConvolutionInstruction::HloConvolutionInstruction(
: HloInstruction(HloOpcode::kConvolution, shape),
feature_group_count_(feature_group_count),
window_(window),
- convolution_dimension_numbers_(dimension_numbers) {
+ convolution_dimension_numbers_(dimension_numbers),
+ precision_config_(precision_config) {
if (window_util::HasBaseDilation(window)) {
SetAndSanitizeName(StrCat(name(), "-base-dilated"));
}
@@ -1643,7 +1665,6 @@ HloConvolutionInstruction::HloConvolutionInstruction(
}
AppendOperand(lhs);
AppendOperand(rhs);
- set_precision_config(precision_config);
}
string HloConvolutionInstruction::ToCategory() const {
@@ -1663,6 +1684,7 @@ HloInstructionProto HloConvolutionInstruction::ToProto() const {
*proto.mutable_convolution_dimension_numbers() =
convolution_dimension_numbers_;
proto.set_feature_group_count(feature_group_count_);
+ *proto.mutable_precision_config() = precision_config_;
return proto;
}
@@ -1677,6 +1699,12 @@ std::vector<string> HloConvolutionInstruction::ExtraAttributesToStringImpl(
if (feature_group_count_ != 1) {
extra.push_back(StrCat("feature_group_count=", feature_group_count_));
}
+
+ string precision_config_string = PrecisionConfigToString(precision_config_);
+ if (!precision_config_string.empty()) {
+ extra.push_back(precision_config_string);
+ }
+
return extra;
}
@@ -1692,7 +1720,9 @@ bool HloConvolutionInstruction::IdenticalSlowPath(
return protobuf_util::ProtobufEquals(window(), casted_other.window()) &&
protobuf_util::ProtobufEquals(
convolution_dimension_numbers(),
- casted_other.convolution_dimension_numbers());
+ casted_other.convolution_dimension_numbers()) &&
+ protobuf_util::ProtobufEquals(precision_config(),
+ casted_other.precision_config());
}
std::unique_ptr<HloInstruction>
@@ -1702,7 +1732,7 @@ HloConvolutionInstruction::CloneWithNewOperandsImpl(
CHECK_EQ(new_operands.size(), 2);
return absl::make_unique<HloConvolutionInstruction>(
shape, new_operands[0], new_operands[1], feature_group_count_, window(),
- convolution_dimension_numbers_, precision_config());
+ convolution_dimension_numbers_, precision_config_);
}
HloReduceWindowInstruction::HloReduceWindowInstruction(
@@ -2161,4 +2191,113 @@ std::unique_ptr<HloInstruction> HloIotaInstruction::CloneWithNewOperandsImpl(
return absl::make_unique<HloIotaInstruction>(shape, iota_dimension());
}
+HloDotInstruction::HloDotInstruction(
+ const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
+ const DotDimensionNumbers& dimension_numbers,
+ const PrecisionConfig& precision_config)
+ : HloInstruction(HloOpcode::kDot, shape),
+ dot_dimension_numbers_(dimension_numbers),
+ precision_config_(precision_config) {
+ AppendOperand(lhs);
+ AppendOperand(rhs);
+}
+
+HloInstructionProto HloDotInstruction::ToProto() const {
+ HloInstructionProto proto = HloInstruction::ToProto();
+ *proto.mutable_dot_dimension_numbers() = dot_dimension_numbers_;
+ *proto.mutable_precision_config() = precision_config_;
+ return proto;
+}
+
+std::vector<string> HloDotInstruction::ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const {
+ std::vector<string> extra = {DotDimensionNumbersToString()};
+
+ string precision_config_string = PrecisionConfigToString(precision_config_);
+ if (!precision_config_string.empty()) {
+ extra.push_back(precision_config_string);
+ }
+ return extra;
+}
+
+bool HloDotInstruction::IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const {
+ const auto& casted_other = static_cast<const HloDotInstruction&>(other);
+ return protobuf_util::ProtobufEquals(dot_dimension_numbers(),
+ casted_other.dot_dimension_numbers()) &&
+ protobuf_util::ProtobufEquals(precision_config(),
+ casted_other.precision_config());
+}
+
+std::unique_ptr<HloInstruction> HloDotInstruction::CloneWithNewOperandsImpl(
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
+ HloCloneContext* context) const {
+ CHECK_EQ(new_operands.size(), 2);
+ return absl::make_unique<HloDotInstruction>(
+ shape, new_operands[0], new_operands[1], dot_dimension_numbers_,
+ precision_config_);
+}
+
+string HloDotInstruction::DotDimensionNumbersToString() const {
+ std::vector<string> result;
+ const DotDimensionNumbers& dnums = dot_dimension_numbers_;
+ if (!dnums.lhs_batch_dimensions().empty()) {
+ result.push_back(StrCat("lhs_batch_dims={",
+ StrJoin(dnums.lhs_batch_dimensions(), ","), "}"));
+ }
+ result.push_back(StrCat("lhs_contracting_dims={",
+ StrJoin(dnums.lhs_contracting_dimensions(), ","),
+ "}"));
+
+ if (!dnums.rhs_batch_dimensions().empty()) {
+ result.push_back(StrCat("rhs_batch_dims={",
+ StrJoin(dnums.rhs_batch_dimensions(), ","), "}"));
+ }
+ result.push_back(StrCat("rhs_contracting_dims={",
+ StrJoin(dnums.rhs_contracting_dimensions(), ","),
+ "}"));
+
+ return StrJoin(result, ", ");
+}
+
+HloDomainInstruction::HloDomainInstruction(
+ const Shape& shape, HloInstruction* operand,
+ std::unique_ptr<DomainMetadata> operand_side_metadata,
+ std::unique_ptr<DomainMetadata> user_side_metadata)
+ : HloInstruction(HloOpcode::kDomain, shape),
+ operand_side_metadata_(std::move(operand_side_metadata)),
+ user_side_metadata_(std::move(user_side_metadata)) {
+ AppendOperand(operand);
+}
+
+std::vector<string> HloDomainInstruction::ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const {
+ if (operand_side_metadata_ != nullptr && user_side_metadata_ != nullptr) {
+ return {StrCat("domain={kind=\"", operand_side_metadata_->Kind(),
+ "\", entry=", user_side_metadata_->ToString(),
+ ", exit=", operand_side_metadata_->ToString(), "}")};
+ }
+ return {};
+}
+
+bool HloDomainInstruction::IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const {
+ const auto& casted_other = static_cast<const HloDomainInstruction&>(other);
+ return operand_side_metadata().Matches(
+ casted_other.operand_side_metadata()) &&
+ user_side_metadata().Matches(casted_other.user_side_metadata());
+}
+
+std::unique_ptr<HloInstruction> HloDomainInstruction::CloneWithNewOperandsImpl(
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
+ HloCloneContext* context) const {
+ CHECK_EQ(new_operands.size(), 1);
+ return absl::make_unique<HloDomainInstruction>(
+ shape, new_operands[0], operand_side_metadata_->Clone(),
+ user_side_metadata_->Clone());
+}
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h
index e1215a7566..c3a7801164 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.h
+++ b/tensorflow/compiler/xla/service/hlo_instructions.h
@@ -957,6 +957,16 @@ class HloConvolutionInstruction : public HloInstruction {
// The number of feature groups. Must be a divisor of the input feature
// dimension and output feature dimension.
int64 feature_group_count() const { return feature_group_count_; }
+
+ // Returns the information used to tell the implementation information about
+ // what sort of precision is requested. The meaning of the field is backend
+ // specific. At the moment, it is only supported for kConvolution and kDot.
+ // Transformations on one kDot or kConvolution to another will preserve this
+ // information. Transformations to other HLOs will not preserve this
+ // information but it is presumed that the alternate lowering is strictly
+ // superior.
+ const PrecisionConfig& precision_config() const { return precision_config_; }
+
string ToCategory() const override;
// Returns a serialized representation of this instruction.
HloInstructionProto ToProto() const override;
@@ -979,6 +989,9 @@ class HloConvolutionInstruction : public HloInstruction {
Window window_;
// Describes the dimension numbers used for a convolution.
ConvolutionDimensionNumbers convolution_dimension_numbers_;
+ // Information used to communicate to the implementation about the algorithm
+ // used to produce results. See the documentation on precision_config().
+ PrecisionConfig precision_config_;
};
class HloReduceWindowInstruction : public HloInstruction {
@@ -1271,6 +1284,85 @@ class HloIotaInstruction : public HloInstruction {
const int64 iota_dimension_;
};
+class HloDotInstruction : public HloInstruction {
+ public:
+ // Creates a dot op with operands 'lhs' and 'rhs' with contracting and batch
+ // dimensions specified in 'dimension_numbers'.
+ explicit HloDotInstruction(const Shape& shape, HloInstruction* lhs,
+ HloInstruction* rhs,
+ const DotDimensionNumbers& dimension_numbers,
+ const PrecisionConfig& precision_config);
+
+ // Returns data on the dimension numbers used for a dot operation.
+ const DotDimensionNumbers& dot_dimension_numbers() const {
+ return dot_dimension_numbers_;
+ }
+
+ // Returns the information used to tell the implementation information about
+ // what sort of precision is requested. The meaning of the field is backend
+ // specific. At the moment, it is only supported for kConvolution and kDot.
+ // Transformations on one kDot or kConvolution to another will preserve this
+ // information. Transformations to other HLOs will not preserve this
+ // information but it is presumed that the alternate lowering is strictly
+ // superior.
+ const PrecisionConfig& precision_config() const { return precision_config_; }
+
+ // Returns a serialized representation of this instruction.
+ HloInstructionProto ToProto() const override;
+
+ private:
+ std::vector<string> ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const override;
+ bool IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const override;
+ // Implementation for non-common logic of CloneWithNewOperands.
+ std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
+ HloCloneContext* context) const override;
+ // Returns the dump string of the dot dimension numbers.
+ string DotDimensionNumbersToString() const;
+
+ // Describes the dimension numbers used for a dot.
+ DotDimensionNumbers dot_dimension_numbers_;
+
+ // Information used to communicate to the implementation about the algorithm
+ // used to produce results. See the documentation on precision_config().
+ PrecisionConfig precision_config_;
+};
+
+class HloDomainInstruction : public HloInstruction {
+ public:
+ explicit HloDomainInstruction(
+ const Shape& shape, HloInstruction* operand,
+ std::unique_ptr<DomainMetadata> operand_side_metadata,
+ std::unique_ptr<DomainMetadata> user_side_metadata);
+
+ // Retrieves the operand side metadata of a kDomain instruction.
+ const DomainMetadata& operand_side_metadata() const {
+ return *operand_side_metadata_;
+ }
+ // Retrieves the user side metadata of a kDomain instruction.
+ const DomainMetadata& user_side_metadata() const {
+ return *user_side_metadata_;
+ }
+
+ private:
+ std::vector<string> ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const override;
+ bool IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const override;
+ // Implementation for non-common logic of CloneWithNewOperands.
+ std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
+ HloCloneContext* context) const override;
+
+ std::unique_ptr<DomainMetadata> operand_side_metadata_;
+ std::unique_ptr<DomainMetadata> user_side_metadata_;
+};
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_
diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc
index 3a1bc4e328..cfe906d9c5 100644
--- a/tensorflow/compiler/xla/service/hlo_module.cc
+++ b/tensorflow/compiler/xla/service/hlo_module.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/map_util.h"
+#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/gtl/map_util.h"
@@ -50,6 +51,13 @@ StatusOr<HloInstruction*> HloModule::LaunderConstInstructionFromModule(
return const_cast<HloInstruction*>(hlo);
}
+Status HloModule::set_schedule(HloSchedule schedule) {
+ TF_RET_CHECK(schedule.module() == this);
+ TF_RETURN_IF_ERROR(schedule.Verify());
+ schedule_ = std::move(schedule);
+ return Status::OK();
+}
+
HloComputation* HloModule::AddComputationInternal(
std::unique_ptr<HloComputation> computation, bool is_entry,
bool uniquify_names) {
@@ -198,12 +206,23 @@ void HloModule::ReplaceComputations(
string HloModule::ToString(const HloPrintOptions& options) const {
std::ostringstream s;
- s << "HloModule " << name() << "\n\n";
+ s << "HloModule " << name();
+ if (has_schedule()) {
+ TF_CHECK_OK(schedule().Verify());
+ s << ", is_scheduled=true";
+ }
+ s << "\n\n";
for (const HloComputation* computation : MakeComputationPostOrder()) {
if (computation == entry_computation()) {
s << "ENTRY ";
}
- s << computation->ToString(options) << "\n\n";
+ if (has_schedule() && schedule().is_computation_scheduled(computation)) {
+ s << computation->ToString(
+ options, schedule().sequence(computation).instructions())
+ << "\n\n";
+ } else {
+ s << computation->ToString(options) << "\n\n";
+ }
}
return s.str();
}
@@ -221,6 +240,9 @@ HloModuleProto HloModule::ToProto() const {
}
proto.add_computations()->Swap(&computation_proto);
}
+ if (has_schedule()) {
+ *proto.mutable_schedule() = schedule().ToProto().ValueOrDie();
+ }
return proto;
}
@@ -309,6 +331,13 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
}
}
+ if (proto.has_schedule()) {
+ TF_ASSIGN_OR_RETURN(
+ HloSchedule schedule,
+ HloSchedule::CreateFromProto(module.get(), proto.schedule()));
+ TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule)));
+ }
+
return std::move(module);
}
diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h
index 3c3371426b..26fd1b2438 100644
--- a/tensorflow/compiler/xla/service/hlo_module.h
+++ b/tensorflow/compiler/xla/service/hlo_module.h
@@ -25,6 +25,7 @@ limitations under the License.
#include <vector>
#include "absl/strings/string_view.h"
+#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/iterator_util.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
@@ -32,6 +33,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
+#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/gtl/iterator_range.h"
@@ -235,6 +237,19 @@ class HloModule {
StatusOr<HloInstruction*> LaunderConstInstructionFromModule(
const HloInstruction* hlo);
+ // Sets the schedule of the module to the given schedule.
+ Status set_schedule(HloSchedule schedule);
+
+ // Clears the schedule of the module.
+ void clear_schedule() { schedule_.reset(); }
+
+ // Returns true if the module has a schedule set.
+ bool has_schedule() const { return schedule_.has_value(); }
+
+ // Returns the schedule of the module. CHECK fails if no schedule is set.
+ const HloSchedule& schedule() const { return *schedule_; }
+ HloSchedule& schedule() { return *schedule_; }
+
private:
HloComputation* AddComputationInternal(
std::unique_ptr<HloComputation> computation, bool is_entry,
@@ -262,6 +277,11 @@ class HloModule {
static std::atomic<int> next_unique_module_id_;
// A unique id to label modules with.
int unique_id_;
+
+ // The HloSchedule of the module. The schedule, if it exists, contains a
+ // sequential order of instructions for each non-fusion computation in the
+ // module.
+ absl::optional<HloSchedule> schedule_;
};
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_module_config.h b/tensorflow/compiler/xla/service/hlo_module_config.h
index 3f1e1cc73e..68c18836eb 100644
--- a/tensorflow/compiler/xla/service/hlo_module_config.h
+++ b/tensorflow/compiler/xla/service/hlo_module_config.h
@@ -106,9 +106,6 @@ class HloModuleConfig {
absl::optional<ComputationLayout> entry_computation_layout_;
- // Whether this is a 'host module'.
- bool is_host_module_ = false;
-
// Module/graph-level seed handle.
uint64 seed_ = 0;
diff --git a/tensorflow/compiler/xla/service/hlo_module_test.cc b/tensorflow/compiler/xla/service/hlo_module_test.cc
index 4bc1bacd7d..400bd4d947 100644
--- a/tensorflow/compiler/xla/service/hlo_module_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_test.cc
@@ -19,9 +19,12 @@ limitations under the License.
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_matchers.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/test.h"
@@ -30,6 +33,8 @@ namespace xla {
namespace {
+namespace op = ::xla::testing::opcode_matchers;
+
class HloModuleTest : public HloTestBase {
protected:
HloModuleTest() {}
@@ -194,6 +199,60 @@ TEST_F(HloModuleTest, UniqueModuleId) {
EXPECT_NE(module_a->unique_id(), module_b->unique_id());
}
+TEST_F(HloModuleTest, ProtoSerializationWithoutSchedule) {
+ const string text = R"(
+HloModule axpy_module
+
+ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
+ %alpha = f32[] parameter(0)
+ %x = f32[2,4]{1,0} parameter(1)
+ %y = f32[2,4]{1,0} parameter(2)
+ %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
+ %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
+ ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(text));
+ ASSERT_FALSE(module->has_schedule());
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<HloModule> module_copy,
+ HloModule::CreateFromProto(module->ToProto(), module->config()));
+ ASSERT_FALSE(module_copy->has_schedule());
+}
+
+TEST_F(HloModuleTest, ProtoSerializationWithSchedule) {
+ const string text = R"(
+HloModule axpy_module, is_scheduled=true
+
+ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
+ %alpha = f32[] parameter(0)
+ %x = f32[2,4]{1,0} parameter(1)
+ %y = f32[2,4]{1,0} parameter(2)
+ %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
+ %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
+ ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(text));
+ ASSERT_TRUE(module->has_schedule());
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<HloModule> module_copy,
+ HloModule::CreateFromProto(module->ToProto(), module->config()));
+ ASSERT_TRUE(module_copy->has_schedule());
+ TF_ASSERT_OK(module_copy->schedule().Verify());
+ EXPECT_EQ(module_copy->schedule().sequences().size(), 1);
+ ASSERT_TRUE(module_copy->schedule().is_computation_scheduled(
+ module_copy->entry_computation()));
+ EXPECT_THAT(
+ module_copy->schedule()
+ .sequence(module_copy->entry_computation())
+ .instructions(),
+ ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Parameter(),
+ op::Broadcast(), op::Multiply(), op::Add()));
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc
index 2105f7a349..f1dc08bafa 100644
--- a/tensorflow/compiler/xla/service/hlo_ordering.cc
+++ b/tensorflow/compiler/xla/service/hlo_ordering.cc
@@ -293,23 +293,6 @@ bool HloOrdering::MayInterfere(const HloValue& a, const HloValue& b,
!LiveRangeStrictlyBefore(b, a, dataflow);
}
-HloOrderingProto HloOrdering::ToProto() const {
- HloOrderingProto proto;
- for (const auto& computation : module_->computations()) {
- const std::vector<const HloInstruction*>* sequence =
- SequentialOrder(*computation);
- if (sequence != nullptr) {
- HloOrderingProto::SequentialComputation* proto_computation =
- proto.add_sequential_computations();
- proto_computation->set_computation_name(computation->name());
- for (const HloInstruction* instruction : *sequence) {
- *proto_computation->add_instruction_names() = instruction->name();
- }
- }
- }
- return proto;
-}
-
PredecessorHloOrdering::PredecessorHloOrdering(const HloModule* module)
: HloOrdering(module) {}
diff --git a/tensorflow/compiler/xla/service/hlo_ordering.h b/tensorflow/compiler/xla/service/hlo_ordering.h
index b21071c4b2..b0361c3f02 100644
--- a/tensorflow/compiler/xla/service/hlo_ordering.h
+++ b/tensorflow/compiler/xla/service/hlo_ordering.h
@@ -72,10 +72,6 @@ class HloOrdering {
virtual string ToString() const = 0;
- // Returns the serialized representation of this ordering.
- // Only sequential computation orders are represented.
- HloOrderingProto ToProto() const;
-
protected:
// Returns true if instruction 'a' executes before instruction 'b'.
// Precondition: 'a' and 'b' are in the same computation.
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index 0f26ed4235..c54360b063 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
@@ -44,6 +45,20 @@ using absl::StrJoin;
const double kF16max = 65504;
+// Creates and returns a schedule created using the order of the instructions in
+// the HloComputation::instructions() vectors in the module.
+HloSchedule ScheduleFromInstructionOrder(const HloModule* module) {
+ HloSchedule schedule(module);
+ for (const HloComputation* computation : module->computations()) {
+ if (!computation->IsFusionComputation()) {
+ for (const HloInstruction* instruction : computation->instructions()) {
+ schedule.GetOrCreateSequence(computation).push_back(instruction);
+ }
+ }
+ }
+ return schedule;
+}
+
// Parser for the HloModule::ToString() format text.
class HloParser {
public:
@@ -366,9 +381,25 @@ bool HloParser::ParseHloModule() {
return false;
}
+ absl::optional<bool> is_scheduled;
+ std::unordered_map<string, AttrConfig> attrs;
+ attrs["is_scheduled"] = {/*required=*/false, AttrTy::kBool, &is_scheduled};
+ if (!ParseAttributes(attrs)) {
+ return false;
+ }
+
module_ = absl::make_unique<HloModule>(name, config_);
- return ParseComputations();
+ if (!ParseComputations()) {
+ return false;
+ }
+
+ if (is_scheduled.has_value() && *is_scheduled) {
+ TF_CHECK_OK(
+ module_->set_schedule(ScheduleFromInstructionOrder(module_.get())));
+ }
+
+ return true;
}
// computations ::= (computation)+
@@ -1248,11 +1279,14 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
optional<string> custom_call_target;
optional<Window> window;
optional<ConvolutionDimensionNumbers> dnums;
+ optional<int64> feature_group_count;
attrs["custom_call_target"] = {/*required=*/true, AttrTy::kString,
&custom_call_target};
attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
attrs["dim_labels"] = {/*required=*/false,
AttrTy::kConvolutionDimensionNumbers, &dnums};
+ attrs["feature_group_count"] = {/*required=*/false, AttrTy::kInt64,
+ &feature_group_count};
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
return false;
}
@@ -1264,6 +1298,9 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
if (dnums.has_value()) {
instruction->set_convolution_dimension_numbers(*dnums);
}
+ if (feature_group_count.has_value()) {
+ instruction->set_feature_group_count(*feature_group_count);
+ }
break;
}
case HloOpcode::kDot: {
diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc
index 0dfc0a4d1c..cca50fab54 100644
--- a/tensorflow/compiler/xla/service/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc
@@ -1123,18 +1123,31 @@ ENTRY Iota {
)"
},
-// custom-call with window and dim_labels
+// custom-call with window, dim_labels and feature_group_count
{
-"CustomCallWithWindowAndDimLabels",
-R"(HloModule CustomCallWithWindowAndDimLabels
+"CustomCallWithWindowAndDimLabelsAndFeatureGroupCount",
+R"(HloModule CustomCallWithWindowAndDimLabelsAndFeatureGroupCount
ENTRY Computation {
- ROOT r = f32[100]{0} custom-call(), window={size=2x2}, dim_labels=b01f_01io->b01f, custom_call_target="target"
+ ROOT r = f32[100]{0} custom-call(), window={size=2x2}, dim_labels=b01f_01io->b01f, feature_group_count=2, custom_call_target="target"
}
)"
+ },
+// is_scheduled=true attribute
+{
+"ScheduledModule",
+R"(HloModule scheduled_module, is_scheduled=true
+
+ENTRY Sort {
+ keys = f32[1024]{0} parameter(0)
+ values = s32[1024]{0} parameter(1)
+ ROOT sorted = (f32[1024]{0}, s32[1024]{0}) sort(keys, values), dimensions={0}
}
- });
+
+)"
+}
+});
// clang-format on
}
@@ -1790,5 +1803,94 @@ TEST(HloParserSingleOpTest, ConvolutionTrivialFeatureGroupCount) {
EXPECT_EQ(convolution->feature_group_count(), 1);
}
+TEST_F(HloParserTest, IsScheduledIsFalse) {
+ const string text = R"(
+HloModule axpy_module, is_scheduled=false
+
+ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
+ %alpha = f32[] parameter(0)
+ %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
+ %x = f32[2,4]{1,0} parameter(1)
+ %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
+ %y = f32[2,4]{1,0} parameter(2)
+ ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(text));
+ ASSERT_FALSE(module->has_schedule());
+}
+
+TEST_F(HloParserTest, IsScheduledNotPresent) {
+ const string text = R"(
+HloModule axpy_module
+
+ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
+ %alpha = f32[] parameter(0)
+ %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
+ %x = f32[2,4]{1,0} parameter(1)
+ %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
+ %y = f32[2,4]{1,0} parameter(2)
+ ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(text));
+ ASSERT_FALSE(module->has_schedule());
+}
+
+TEST_F(HloParserTest, IsScheduledIsTrue) {
+ const string text = R"(
+HloModule axpy_module, is_scheduled=true
+
+ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
+ %alpha = f32[] parameter(0)
+ %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
+ %x = f32[2,4]{1,0} parameter(1)
+ %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
+ %y = f32[2,4]{1,0} parameter(2)
+ ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(text));
+ ASSERT_TRUE(module->has_schedule());
+ TF_ASSERT_OK(module->schedule().Verify());
+ EXPECT_EQ(module->schedule().sequences().size(), 1);
+ ASSERT_TRUE(
+ module->schedule().is_computation_scheduled(module->entry_computation()));
+ EXPECT_THAT(
+ module->schedule().sequence(module->entry_computation()).instructions(),
+ ::testing::ElementsAre(op::Parameter(), op::Broadcast(), op::Parameter(),
+ op::Multiply(), op::Parameter(), op::Add()));
+}
+
+TEST_F(HloParserTest, IsScheduledIsTrueDifferentOrder) {
+ // As above but with a different schedule order.
+ const string text = R"(
+HloModule axpy_module, is_scheduled=true
+
+ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
+ %alpha = f32[] parameter(0)
+ %x = f32[2,4]{1,0} parameter(1)
+ %y = f32[2,4]{1,0} parameter(2)
+ %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
+ %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
+ ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(text));
+ ASSERT_TRUE(module->has_schedule());
+ TF_ASSERT_OK(module->schedule().Verify());
+ EXPECT_EQ(module->schedule().sequences().size(), 1);
+ ASSERT_TRUE(
+ module->schedule().is_computation_scheduled(module->entry_computation()));
+ EXPECT_THAT(
+ module->schedule().sequence(module->entry_computation()).instructions(),
+ ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Parameter(),
+ op::Broadcast(), op::Multiply(), op::Add()));
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_proto_util.cc b/tensorflow/compiler/xla/service/hlo_proto_util.cc
index 3460679558..b9c0b0c4ee 100644
--- a/tensorflow/compiler/xla/service/hlo_proto_util.cc
+++ b/tensorflow/compiler/xla/service/hlo_proto_util.cc
@@ -23,11 +23,8 @@ namespace xla {
HloProto MakeHloProto(const HloModule& module,
const BufferAssignment& assignment) {
- HloOrderingProto proto_ordering =
- assignment.liveness().hlo_ordering().ToProto();
BufferAssignmentProto proto_assignment = assignment.ToProto();
HloProto proto = MakeHloProto(module);
- proto.mutable_hlo_ordering()->Swap(&proto_ordering);
proto.mutable_buffer_assignment()->Swap(&proto_assignment);
return proto;
}
diff --git a/tensorflow/compiler/xla/service/hlo_schedule.cc b/tensorflow/compiler/xla/service/hlo_schedule.cc
index a65b33bf40..3fc5dbeb02 100644
--- a/tensorflow/compiler/xla/service/hlo_schedule.cc
+++ b/tensorflow/compiler/xla/service/hlo_schedule.cc
@@ -21,12 +21,64 @@ limitations under the License.
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/map_util.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/gtl/map_util.h"
namespace xla {
+/* static */ StatusOr<HloSchedule> HloSchedule::CreateFromProto(
+ const HloModule* module, const HloScheduleProto& proto) {
+ tensorflow::gtl::FlatMap<int64, const HloComputation*> id_to_computation;
+ for (const HloComputation* computation : module->computations()) {
+ id_to_computation[computation->unique_id()] = computation;
+ }
+
+ HloSchedule schedule(module);
+ for (const auto& id_sequence : proto.sequences()) {
+ int64 computation_id = id_sequence.first;
+
+ auto comp_it = id_to_computation.find(computation_id);
+ TF_RET_CHECK(comp_it != id_to_computation.end())
+ << "No computation exists in HLO module with id " << computation_id;
+ const HloComputation* computation = comp_it->second;
+
+ tensorflow::gtl::FlatMap<int64, const HloInstruction*> id_to_instruction;
+ for (const HloInstruction* instruction : computation->instructions()) {
+ id_to_instruction[instruction->unique_id()] = instruction;
+ }
+
+ HloInstructionSequence& sequence =
+ schedule.GetOrCreateSequence(computation);
+ for (const int64 instruction_id : id_sequence.second.instruction_ids()) {
+ auto instr_it = id_to_instruction.find(instruction_id);
+ TF_RET_CHECK(instr_it != id_to_instruction.end())
+ << "No instruction exists in HLO computation " << computation->name()
+ << " with id " << instruction_id;
+ sequence.push_back(instr_it->second);
+ }
+ }
+ TF_RETURN_IF_ERROR(schedule.Verify());
+ return std::move(schedule);
+}
+
+StatusOr<HloScheduleProto> HloSchedule::ToProto() const {
+ TF_RETURN_IF_ERROR(Verify());
+ HloScheduleProto proto;
+ for (const auto& id_sequence : sequences_) {
+ int64 computation_id = id_sequence.first;
+ const HloInstructionSequence& sequence = id_sequence.second;
+ HloScheduleProto::InstructionSequence& proto_sequence =
+ (*proto.mutable_sequences())[computation_id];
+ proto_sequence.mutable_instruction_ids()->Reserve(sequence.size());
+ for (const int64 id : sequence.ids()) {
+ proto_sequence.add_instruction_ids(id);
+ }
+ }
+ return std::move(proto);
+}
+
void HloSchedule::set_sequence(
const HloComputation* computation,
absl::Span<const HloInstruction* const> sequence) {
diff --git a/tensorflow/compiler/xla/service/hlo_schedule.h b/tensorflow/compiler/xla/service/hlo_schedule.h
index 21c6988638..270fe6039f 100644
--- a/tensorflow/compiler/xla/service/hlo_schedule.h
+++ b/tensorflow/compiler/xla/service/hlo_schedule.h
@@ -21,18 +21,20 @@ limitations under the License.
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
-#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/status.h"
namespace xla {
+class HloModule;
+
// Class representing a sequence of HLO instructions such as the sequential
// execution order of an HLO computation.
class HloInstructionSequence {
public:
HloInstructionSequence() = default;
- HloInstructionSequence(absl::Span<const HloInstruction* const> instructions) {
+ explicit HloInstructionSequence(
+ absl::Span<const HloInstruction* const> instructions) {
for (const HloInstruction* instruction : instructions) {
push_back(instruction);
}
@@ -77,7 +79,12 @@ class HloInstructionSequence {
// non-fusion computation in the HLO module.
class HloSchedule {
public:
- HloSchedule(const HloModule* module) : module_(module) {}
+ explicit HloSchedule(const HloModule* module) : module_(module) {}
+
+ // (De)Serialize an HloSchedule to/from a HloScheduleProto.
+ static StatusOr<HloSchedule> CreateFromProto(const HloModule* module,
+ const HloScheduleProto& proto);
+ StatusOr<HloScheduleProto> ToProto() const;
// Returns a reference to the sequence for the given computation.
const HloInstructionSequence& sequence(
diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc
index 6e17711f57..082bf8bffe 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment.cc
@@ -855,8 +855,7 @@ void LayoutAssignment::SetupCopiedInstruction(const HloInstruction& instruction,
? instruction.sharding().GetSubSharding(instruction.shape(), index)
: instruction.sharding();
// We propagate the sharding to the copied instruction only if it is a
- // special sharding, like tiled ones, or special devices like the
- // HostCompute module.
+ // special sharding, like tiled ones.
// Otherwise it is preferable to leave the new instruction without device,
// and let the automatic device placer to choose the best location.
auto device = sharding.UniqueDevice();
diff --git a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc
index 7d49b8d6c2..a60643bc75 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc
@@ -75,6 +75,16 @@ void EmitTuple(const IrArray& tuple, absl::Span<llvm::Value* const> operands,
}
}
+// Overload of EmitTuple that takes the element buffers as IrArrays: collects
+// each buffer's base pointer and forwards to the llvm::Value* overload.
+void EmitTuple(const IrArray& tuple, absl::Span<const IrArray> buffers,
+               llvm::IRBuilder<>* b, llvm::Module* module) {
+  std::vector<llvm::Value*> base_pointers;
+  base_pointers.reserve(buffers.size());
+  for (const llvm_ir::IrArray& buffer : buffers) {
+    base_pointers.push_back(buffer.GetBasePointer());
+  }
+  llvm_ir::EmitTuple(tuple, base_pointers, b, module);
+}
+
llvm::Value* EmitGetTupleElement(const Shape& target_shape, int64 index,
int alignment, llvm::Value* operand,
llvm::IRBuilder<>* b, llvm::Module* module) {
diff --git a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h
index 887fb61371..94340b91d8 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h
@@ -68,6 +68,11 @@ void EmitTupleSelect(const IrArray& select, const IrArray& pred,
void EmitTuple(const IrArray& tuple, absl::Span<llvm::Value* const> operands,
llvm::IRBuilder<>* b, llvm::Module* module);
+// Same as EmitTuple above, except that the tuple's element buffers are passed
+// as IrArrays; each buffer's base pointer becomes the corresponding operand.
+void EmitTuple(const IrArray& tuple, absl::Span<const IrArray> buffers,
+ llvm::IRBuilder<>* b, llvm::Module* module);
+
// A tuple is an array of pointers, one for each operand. Each pointer points to
// the output buffer of its corresponding operand. A GetTupleElement instruction
// forwards the pointer to underlying tuple element buffer at the given index.
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 36b8fb2644..d0bda45cf8 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -75,7 +75,6 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_verifier",
"//tensorflow/compiler/xla/service:transfer_manager",
"//tensorflow/core:lib",
- "//tensorflow/core:stream_executor_headers_lib",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/types:span",
],
diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc
index 8c62adea23..57f7fed61f 100644
--- a/tensorflow/compiler/xla/tests/reduce_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_test.cc
@@ -866,10 +866,7 @@ INSTANTIATE_TEST_CASE_P(
BoundsLayout{{2, 300, 784}, {2, 1, 0}, {1}},
BoundsLayout{{2, 300, 784}, {2, 1, 0}, {0}}));
-// TODO(b/64093391) Disabled on GPU due to an assertion failure when running
-// IrEmitterUnnested::EmitInitializer() for the Reduce operator. Failed on
-// 2017-07-26.
-XLA_TEST_F(ReduceTest, DISABLED_ON_GPU(OperationOnConstantAsInitValue)) {
+XLA_TEST_F(ReduceTest, OperationOnConstantAsInitValue) {
XlaBuilder builder(TestName());
XlaComputation max_f32 = CreateScalarMaxComputation(F32, &builder);
diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc
index c20a7c8fe4..3ae31191a0 100644
--- a/tensorflow/compiler/xla/tests/test_utils.cc
+++ b/tensorflow/compiler/xla/tests/test_utils.cc
@@ -417,4 +417,18 @@ Status VerifyHloModule(HloModule* const module, bool layout_sensitive,
.status();
}
+// Builds a rank-2 by rank-2 dot: contracts LHS dimension 1 with RHS dimension
+// 0, no batch dimensions, and PrecisionConfig::DEFAULT for both operands.
+// CHECK-fails unless `lhs` and `rhs` both have rank 2.
+std::unique_ptr<HloDotInstruction> CreateCanonicalDot(const Shape& shape,
+                                                      HloInstruction* lhs,
+                                                      HloInstruction* rhs) {
+  CHECK_EQ(ShapeUtil::Rank(lhs->shape()), 2);
+  CHECK_EQ(ShapeUtil::Rank(rhs->shape()), 2);
+
+  DotDimensionNumbers dimension_numbers;
+  dimension_numbers.add_lhs_contracting_dimensions(1);
+  dimension_numbers.add_rhs_contracting_dimensions(0);
+
+  // One precision entry per operand.
+  PrecisionConfig precision;
+  auto* operand_precision = precision.mutable_operand_precision();
+  operand_precision->Resize(2, PrecisionConfig::DEFAULT);
+
+  return absl::make_unique<HloDotInstruction>(shape, lhs, rhs,
+                                              dimension_numbers, precision);
+}
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/test_utils.h b/tensorflow/compiler/xla/tests/test_utils.h
index 7790737c09..a260271b1b 100644
--- a/tensorflow/compiler/xla/tests/test_utils.h
+++ b/tensorflow/compiler/xla/tests/test_utils.h
@@ -24,10 +24,10 @@ limitations under the License.
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/types.h"
-#include "tensorflow/stream_executor/platform.h"
namespace xla {
@@ -98,6 +98,12 @@ StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments(
Status VerifyHloModule(HloModule* const module, bool layout_sensitive,
bool allow_mixed_precision);
+// Creates a dot op with operands 'lhs' and 'rhs' that contracts dimension 1 of
+// the LHS with dimension 0 of the RHS, with no batch dimensions and default
+// operand precision. Both the LHS and the RHS must be of rank 2.
+std::unique_ptr<HloDotInstruction> CreateCanonicalDot(const Shape& shape,
+ HloInstruction* lhs,
+ HloInstruction* rhs);
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_TESTS_TEST_UTILS_H_
diff --git a/tensorflow/compiler/xla/tools/hex_floats_to_packed_literal.cc b/tensorflow/compiler/xla/tools/hex_floats_to_packed_literal.cc
index 23ce1d235b..0c3ec5934e 100644
--- a/tensorflow/compiler/xla/tools/hex_floats_to_packed_literal.cc
+++ b/tensorflow/compiler/xla/tools/hex_floats_to_packed_literal.cc
@@ -67,8 +67,8 @@ int main(int argc, char** argv) {
floats.push_back(value);
}
- absl::string_view content(absl::bit_cast<const char*>(floats.data()),
- floats.size() * sizeof(float));
+ tensorflow::StringPiece content(absl::bit_cast<const char*>(floats.data()),
+ floats.size() * sizeof(float));
TF_CHECK_OK(tensorflow::WriteStringToFile(tensorflow::Env::Default(),
output_file, content));
return 0;