author     A. Unique TensorFlower <gardener@tensorflow.org>   2018-09-06 20:09:38 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>    2018-09-06 20:14:24 -0700
commit     ac8cf2ad5d01010b978c5b41c2fac22ee69a90c4 (patch)
tree       06840591db9d2a077b28fe28f73baae913065550
parent     1cc48be8da90c2d5d3a2ebdf6ed46be623fa0c03 (diff)
Split out HloDotInstruction as subclass from HloInstruction.
PiperOrigin-RevId: 211912785
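
For context, a minimal sketch (not part of the patch) of the call-site change this
commit applies across the tests; the builder setup and the 2x2 shape are illustrative:

#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"

namespace xla {
void CanonicalDotExample() {
  HloComputation::Builder builder("canonical_dot_example");
  Shape f32_2x2 = ShapeUtil::MakeShape(F32, {2, 2});
  HloInstruction* lhs = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_2x2, "lhs"));
  HloInstruction* rhs = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32_2x2, "rhs"));
  // Before this commit, tests called the static factory on the base class:
  //   HloInstruction::CreateCanonicalDot(f32_2x2, lhs, rhs)
  // After it, CreateCanonicalDot is a free function declared in
  // tests/test_utils.h that returns the new subclass
  // (std::unique_ptr<HloDotInstruction>), contracting dimension 1 of the LHS
  // with dimension 0 of the RHS; both operands must be rank 2.
  HloInstruction* dot =
      builder.AddInstruction(CreateCanonicalDot(f32_2x2, lhs, rhs));
  (void)dot;
}
}  // namespace xla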
-rw-r--r--  tensorflow/compiler/xla/service/cpu/BUILD | 1
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc | 5
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc | 18
-rw-r--r--  tensorflow/compiler/xla/service/cpu/tests/BUILD | 1
-rw-r--r--  tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc | 7
-rw-r--r--  tensorflow/compiler/xla/service/gpu/BUILD | 3
-rw-r--r--  tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc | 55
-rw-r--r--  tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc | 9
-rw-r--r--  tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc | 51
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction.cc | 110
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction.h | 24
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instructions.cc | 63
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instructions.h | 35
-rw-r--r--  tensorflow/compiler/xla/tests/BUILD | 1
-rw-r--r--  tensorflow/compiler/xla/tests/test_utils.cc | 14
-rw-r--r--  tensorflow/compiler/xla/tests/test_utils.h | 8
16 files changed, 226 insertions, 179 deletions
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index d412578619..2368ac8c6a 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -670,6 +670,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/service:transpose_folding",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"@com_google_absl//absl/strings",
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
index 0fea462c85..7d99b914d4 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/transpose_folding.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/test_utils.h"
namespace op = xla::testing::opcode_matchers;
@@ -696,8 +697,8 @@ void CreateComputationForDotAddOutputFusionTest(const string& test_name,
auto* addend = builder.AddInstruction(
HloInstruction::CreateParameter(2, dot_shape, "param2"));
- auto* dot = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(dot_shape, dot_lhs, dot_rhs));
+ auto* dot =
+ builder.AddInstruction(CreateCanonicalDot(dot_shape, dot_lhs, dot_rhs));
builder.AddInstruction(
HloInstruction::CreateBinary(dot_shape, HloOpcode::kAdd, dot, addend));
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc
index 9363af3b89..4668f3872d 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc
@@ -70,7 +70,7 @@ TEST_F(CpuLayoutAssignmentTest, DotWithConstantRhsTensor) {
auto dot_rhs = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateFromShape(rhs_shape)));
auto result = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(result_shape, dot_lhs, dot_rhs));
+ CreateCanonicalDot(result_shape, dot_lhs, dot_rhs));
auto module = CreateNewModule();
HloComputation* computation = module->AddEntryComputation(builder.Build());
@@ -107,9 +107,9 @@ TEST_F(CpuLayoutAssignmentTest, MultipleDotsWithSameConstantRhsTensor0) {
auto dot_rhs = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateFromShape(rhs_shape)));
auto dot_a_result = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(result_shape, dot_a_lhs, dot_rhs));
+ CreateCanonicalDot(result_shape, dot_a_lhs, dot_rhs));
auto dot_b_result = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(result_shape, dot_b_lhs, dot_rhs));
+ CreateCanonicalDot(result_shape, dot_b_lhs, dot_rhs));
builder.AddInstruction(HloInstruction::CreateBinary(
result_shape, HloOpcode::kAdd, dot_a_result, dot_b_result));
@@ -151,9 +151,9 @@ TEST_F(CpuLayoutAssignmentTest, MultipleDotsWithSameConstantRhsTensor1) {
auto dot_rhs = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateFromShape(rhs_shape)));
auto dot_a_result = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(result_a_shape, dot_a_lhs, dot_rhs));
+ CreateCanonicalDot(result_a_shape, dot_a_lhs, dot_rhs));
auto dot_b_result = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(result_b_shape, dot_b_lhs, dot_rhs));
+ CreateCanonicalDot(result_b_shape, dot_b_lhs, dot_rhs));
auto tuple_result = builder.AddInstruction(
HloInstruction::CreateTuple({dot_a_result, dot_b_result}));
@@ -189,7 +189,7 @@ TEST_F(CpuLayoutAssignmentTest, DotWithConstantLhsTensor) {
auto dot_rhs = builder.AddInstruction(
HloInstruction::CreateParameter(0, rhs_shape, "param0"));
auto dot_result = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(result_shape, dot_lhs, dot_rhs));
+ CreateCanonicalDot(result_shape, dot_lhs, dot_rhs));
auto module = CreateNewModule();
HloComputation* computation = module->AddEntryComputation(builder.Build());
@@ -229,7 +229,7 @@ TEST_F(CpuLayoutAssignmentTest, DotWithConstantRhsTensorThroughGTE) {
auto dot_rhs = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(rhs_shape, constant, 1));
auto dot_result = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(result_shape, dot_lhs, dot_rhs));
+ CreateCanonicalDot(result_shape, dot_lhs, dot_rhs));
auto module = CreateNewModule();
HloComputation* computation = module->AddEntryComputation(builder.Build());
@@ -276,8 +276,8 @@ static StatusOr<DotOutputFusionLayoutAssignmentResult> RunDotOutputFusion(
HloInstruction::CreateParameter(1, dot_shape, "param1"));
HloInstruction* dot_rhs = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateFromShape(dot_rhs_shape)));
- HloInstruction* dot_result = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(dot_shape, dot_lhs, dot_rhs));
+ HloInstruction* dot_result =
+ builder.AddInstruction(CreateCanonicalDot(dot_shape, dot_lhs, dot_rhs));
HloInstruction* add_result;
if (dot_operand_idx_in_add == 0) {
add_result = builder.AddInstruction(HloInstruction::CreateBinary(
diff --git a/tensorflow/compiler/xla/service/cpu/tests/BUILD b/tensorflow/compiler/xla/service/cpu/tests/BUILD
index 2384166fd2..f11aff0573 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/tests/BUILD
@@ -121,6 +121,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service/cpu:cpu_compiler",
"//tensorflow/compiler/xla/service/cpu/tests:cpu_codegen_test",
+ "//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc
index fcd87b36b3..18ee25ba91 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
#include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/core/platform/test.h"
namespace xla {
@@ -69,8 +70,7 @@ TEST_P(CpuEigenDotOperationTest, SimpleDotOp) {
HloInstruction* rhs = builder.AddInstruction(
HloInstruction::CreateParameter(1, param_shape, "input"));
- builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(param_shape, lhs, rhs));
+ builder.AddInstruction(CreateCanonicalDot(param_shape, lhs, rhs));
CompileAndCheck(builder.Build(), spec.filecheck_lines);
}
@@ -87,8 +87,7 @@ TEST_P(CpuEigenDotOperationTest, DotTransposeOp) {
HloInstruction* lhs_transposed = builder.AddInstruction(
HloInstruction::CreateTranspose(param_shape, lhs, {1, 0}));
- builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(param_shape, lhs_transposed, rhs));
+ builder.AddInstruction(CreateCanonicalDot(param_shape, lhs_transposed, rhs));
CompileAndCheck(builder.Build(), spec.filecheck_lines);
}
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index a68b7a1bef..6791e15ee0 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -108,6 +108,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"@com_google_absl//absl/memory",
@@ -480,6 +481,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/service:hlo_matchers",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
],
)
@@ -830,6 +832,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings:str_format",
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc
index 0922e44a12..59ade96f7d 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc
@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/types.h"
namespace xla {
@@ -73,10 +74,10 @@ TEST_F(GpuHloScheduleTest, SequentialMatMul) {
/*parameter_number=*/1, f32_2x2_, /*name=*/"y"));
HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter(
/*parameter_number=*/2, f32_2x2_, /*name=*/"z"));
- HloInstruction* dot1 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, x, y));
- HloInstruction* dot2 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, dot1, z));
+ HloInstruction* dot1 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, x, y));
+ HloInstruction* dot2 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, dot1, z));
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build(dot2));
@@ -201,12 +202,12 @@ TEST_F(GpuHloScheduleTest, ConcurrentMatMul) {
/*parameter_number=*/0, f32_2x2_, /*name=*/"x"));
HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
/*parameter_number=*/1, f32_2x2_, /*name=*/"y"));
- HloInstruction* dot1 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, x, y));
- HloInstruction* dot2 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, y, x));
- HloInstruction* add = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, dot1, dot2));
+ HloInstruction* dot1 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, x, y));
+ HloInstruction* dot2 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, y, x));
+ HloInstruction* add =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, dot1, dot2));
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build(add));
@@ -269,23 +270,23 @@ TEST_F(GpuHloScheduleTest, LatticeMatMul) {
i, f32_2x2_, /*name=*/absl::StrFormat("param%d", i))));
}
HloInstruction* d00 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, params[2], params[3]));
- HloInstruction* d10 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, params[1], d00));
- HloInstruction* d11 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, d00, params[4]));
- HloInstruction* d20 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, params[0], d10));
- HloInstruction* d21 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, d10, d11));
- HloInstruction* d22 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, d11, params[5]));
- HloInstruction* d30 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, d20, d21));
- HloInstruction* d31 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, d21, d22));
- HloInstruction* d40 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, d30, d31));
+ CreateCanonicalDot(f32_2x2_, params[2], params[3]));
+ HloInstruction* d10 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, params[1], d00));
+ HloInstruction* d11 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d00, params[4]));
+ HloInstruction* d20 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, params[0], d10));
+ HloInstruction* d21 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d10, d11));
+ HloInstruction* d22 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d11, params[5]));
+ HloInstruction* d30 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d20, d21));
+ HloInstruction* d31 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d21, d22));
+ HloInstruction* d40 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d30, d31));
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build(d40));
diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
index bca775c475..96bfe0c12e 100644
--- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/util.h"
namespace op = xla::testing::opcode_matchers;
@@ -111,8 +112,8 @@ TEST_F(InstructionFusionTest, PotentialBitcastReshapeOfDotUnfused) {
HloComputation::Builder builder(TestName());
auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(S32, {1, 1}), "0"));
- auto dot1 = builder.AddInstruction(HloInstruction::CreateCanonicalDot(
- ShapeUtil::MakeShape(S32, {1, 1}), param0, param0));
+ auto dot1 = builder.AddInstruction(
+ CreateCanonicalDot(ShapeUtil::MakeShape(S32, {1, 1}), param0, param0));
auto reshape2 = builder.AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(S32, {1, 1, 1}), dot1));
@@ -128,8 +129,8 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfDotUnfused) {
HloComputation::Builder builder(TestName());
auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(S32, {1, 1}), "0"));
- auto dot1 = builder.AddInstruction(HloInstruction::CreateCanonicalDot(
- ShapeUtil::MakeShape(S32, {1, 1}), param0, param0));
+ auto dot1 = builder.AddInstruction(
+ CreateCanonicalDot(ShapeUtil::MakeShape(S32, {1, 1}), param0, param0));
auto transpose2 = builder.AddInstruction(HloInstruction::CreateTranspose(
ShapeUtil::MakeShape(S32, {1, 1}), dot1, {0, 1}));
diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc
index 091aca23e5..8f0dedfa40 100644
--- a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/types.h"
namespace xla {
@@ -49,10 +50,10 @@ TEST_F(StreamAssignmentTest, SequentialMatMul) {
/*parameter_number=*/1, f32_2x2_, /*name=*/"y"));
HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter(
/*parameter_number=*/2, f32_2x2_, /*name=*/"z"));
- HloInstruction* dot1 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, x, y));
- HloInstruction* dot2 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, dot1, z));
+ HloInstruction* dot1 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, x, y));
+ HloInstruction* dot2 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, dot1, z));
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build(dot2));
@@ -68,10 +69,10 @@ TEST_F(StreamAssignmentTest, ConcurrentMatMul) {
/*parameter_number=*/0, f32_2x2_, /*name=*/"x"));
HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
/*parameter_number=*/1, f32_2x2_, /*name=*/"y"));
- HloInstruction* dot1 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, x, y));
- HloInstruction* dot2 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, y, x));
+ HloInstruction* dot1 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, x, y));
+ HloInstruction* dot2 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, y, x));
HloInstruction* add = builder.AddInstruction(
HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, dot1, dot2));
@@ -101,23 +102,23 @@ TEST_F(StreamAssignmentTest, LatticeMatMul) {
i, f32_2x2_, /*name=*/absl::StrFormat("param%d", i))));
}
HloInstruction* d00 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, params[2], params[3]));
- HloInstruction* d10 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, params[1], d00));
- HloInstruction* d11 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, d00, params[4]));
- HloInstruction* d20 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, params[0], d10));
- HloInstruction* d21 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, d10, d11));
- HloInstruction* d22 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, d11, params[5]));
- HloInstruction* d30 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, d20, d21));
- HloInstruction* d31 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, d21, d22));
- HloInstruction* d40 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, d30, d31));
+ CreateCanonicalDot(f32_2x2_, params[2], params[3]));
+ HloInstruction* d10 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, params[1], d00));
+ HloInstruction* d11 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d00, params[4]));
+ HloInstruction* d20 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, params[0], d10));
+ HloInstruction* d21 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d10, d11));
+ HloInstruction* d22 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d11, params[5]));
+ HloInstruction* d30 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d20, d21));
+ HloInstruction* d31 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d21, d22));
+ HloInstruction* d40 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d30, d31));
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build(d40));
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 471a12d6aa..563aa695c9 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -451,6 +451,20 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
<< proto.dimensions_size();
instruction = CreateIota(proto.shape(), proto.dimensions(0));
break;
+ case HloOpcode::kDot: {
+ TF_RET_CHECK(proto.has_dot_dimension_numbers())
+ << "Dot instruction should have dot_dimension_numbers.";
+ TF_RET_CHECK(proto.operand_ids_size() == 2)
+ << "Dot instruction should have 2 operands but sees "
+ << proto.operand_ids_size();
+ PrecisionConfig precision_config = proto.precision_config();
+ precision_config.mutable_operand_precision()->Resize(
+ proto.operand_ids_size(), PrecisionConfig::DEFAULT);
+ instruction = absl::make_unique<HloDotInstruction>(
+ proto.shape(), operands(0), operands(1),
+ proto.dot_dimension_numbers(), precision_config);
+ break;
+ }
default: {
instruction = absl::WrapUnique(new HloInstruction(opcode, proto.shape()));
for (const int64 operand_id : proto.operand_ids()) {
@@ -472,20 +486,9 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
computation_map.at(computation_id));
}
}
- if (instruction->opcode() == HloOpcode::kDot) {
- instruction->precision_config_ = proto.precision_config();
- instruction->precision_config_.mutable_operand_precision()->Resize(
- instruction->operand_count(), PrecisionConfig::DEFAULT);
- TF_RET_CHECK(proto.has_dot_dimension_numbers());
- instruction->dot_dimension_numbers_ =
- absl::make_unique<DotDimensionNumbers>(
- proto.dot_dimension_numbers());
- } else {
- TF_RET_CHECK(!proto.has_precision_config())
- << instruction->opcode() << proto.DebugString();
- TF_RET_CHECK(!proto.has_dot_dimension_numbers())
- << instruction->opcode();
- }
+ TF_RET_CHECK(!proto.has_precision_config())
+ << instruction->opcode() << proto.DebugString();
+ TF_RET_CHECK(!proto.has_dot_dimension_numbers()) << instruction->opcode();
break;
}
}
@@ -596,7 +599,6 @@ HloInstruction::CreateGetTupleElement(const Shape& shape,
case HloOpcode::kAtan2:
case HloOpcode::kDivide:
case HloOpcode::kComplex:
- case HloOpcode::kDot:
case HloOpcode::kEq:
case HloOpcode::kGe:
case HloOpcode::kGt:
@@ -674,30 +676,8 @@ HloInstruction::CreateGetTupleElement(const Shape& shape,
const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
const DotDimensionNumbers& dimension_numbers,
const PrecisionConfig& precision_config) {
- auto instruction =
- absl::WrapUnique(new HloInstruction(HloOpcode::kDot, shape));
- instruction->AppendOperand(lhs);
- instruction->AppendOperand(rhs);
- instruction->dot_dimension_numbers_ =
- absl::make_unique<DotDimensionNumbers>(dimension_numbers);
- instruction->set_precision_config(precision_config);
- return instruction;
-}
-
-/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCanonicalDot(
- const Shape& shape, HloInstruction* lhs, HloInstruction* rhs) {
- CHECK_EQ(ShapeUtil::Rank(lhs->shape()), 2);
- CHECK_EQ(ShapeUtil::Rank(rhs->shape()), 2);
-
- auto instruction =
- absl::WrapUnique(new HloInstruction(HloOpcode::kDot, shape));
- instruction->AppendOperand(lhs);
- instruction->AppendOperand(rhs);
- instruction->dot_dimension_numbers_ =
- absl::make_unique<DotDimensionNumbers>();
- instruction->dot_dimension_numbers_->add_lhs_contracting_dimensions(1);
- instruction->dot_dimension_numbers_->add_rhs_contracting_dimensions(0);
- return instruction;
+ return absl::make_unique<HloDotInstruction>(
+ shape, lhs, rhs, dimension_numbers, precision_config);
}
/* static */ std::unique_ptr<HloInstruction>
@@ -1218,6 +1198,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
case HloOpcode::kGather:
case HloOpcode::kScatter:
case HloOpcode::kIota:
+ case HloOpcode::kDot:
clone = CloneWithNewOperandsImpl(shape, new_operands, context);
break;
// Unary ops.
@@ -1290,11 +1271,6 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
CHECK_EQ(new_operands.size(), 1);
clone = CreateBitcastConvert(shape, new_operands[0]);
break;
- case HloOpcode::kDot:
- CHECK_EQ(new_operands.size(), 2);
- clone = CreateDot(shape, new_operands[0], new_operands[1],
- *dot_dimension_numbers_, precision_config());
- break;
case HloOpcode::kReshape:
CHECK_EQ(new_operands.size(), 1);
clone = CreateReshape(shape, new_operands[0]);
@@ -1620,11 +1596,6 @@ bool HloInstruction::IdenticalSlowPath(
case HloOpcode::kAfterAll:
return false;
- // Check dot dimension numbers.
- case HloOpcode::kDot:
- return protobuf_util::ProtobufEquals(dot_dimension_numbers(),
- other.dot_dimension_numbers());
-
// Remaining instructions with special values.
case HloOpcode::kCall:
return eq_computations(to_apply(), other.to_apply());
@@ -1683,6 +1654,7 @@ bool HloInstruction::IdenticalSlowPath(
case HloOpcode::kDynamicSlice:
case HloOpcode::kGather:
case HloOpcode::kScatter:
+ case HloOpcode::kDot:
LOG(FATAL) << "Base class impl called for opcode with subclass: "
<< opcode();
}
@@ -2052,10 +2024,6 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
const HloPrintOptions& options) const {
std::vector<string> extra = ExtraAttributesToStringImpl(options);
- if (dot_dimension_numbers_ != nullptr) {
- extra.push_back(DotDimensionNumbersToString());
- }
-
string precision_config_string = PrecisionConfigToString();
if (!precision_config_string.empty()) {
extra.push_back(precision_config_string);
@@ -2182,19 +2150,12 @@ HloInstructionProto HloInstruction::ToProto() const {
*proto.mutable_metadata() = metadata_;
proto.set_backend_config(backend_config_);
- if (opcode() == HloOpcode::kConvolution || opcode() == HloOpcode::kDot) {
- *proto.mutable_precision_config() = precision_config_;
- }
if (opcode() != HloOpcode::kFusion) {
for (const HloComputation* computation : called_computations_) {
proto.add_called_computation_ids(computation->unique_id());
}
}
- if (dot_dimension_numbers_ != nullptr) {
- *proto.mutable_dot_dimension_numbers() = *dot_dimension_numbers_;
- }
-
if (has_sharding()) {
*proto.mutable_sharding() = sharding().ToProto();
}
@@ -2921,31 +2882,6 @@ string ConvolutionDimensionNumbersToString(
StrJoin(output_dims, ""));
}
-string HloInstruction::DotDimensionNumbersToString() const {
- std::vector<string> result;
- if (dot_dimension_numbers_ == nullptr) {
- return "";
- }
- const DotDimensionNumbers& dnums = *dot_dimension_numbers_;
- if (!dnums.lhs_batch_dimensions().empty()) {
- result.push_back(StrCat("lhs_batch_dims={",
- StrJoin(dnums.lhs_batch_dimensions(), ","), "}"));
- }
- result.push_back(StrCat("lhs_contracting_dims={",
- StrJoin(dnums.lhs_contracting_dimensions(), ","),
- "}"));
-
- if (!dnums.rhs_batch_dimensions().empty()) {
- result.push_back(StrCat("rhs_batch_dims={",
- StrJoin(dnums.rhs_batch_dimensions(), ","), "}"));
- }
- result.push_back(StrCat("rhs_contracting_dims={",
- StrJoin(dnums.rhs_contracting_dimensions(), ","),
- "}"));
-
- return StrJoin(result, ", ");
-}
-
StatusOr<RandomDistribution> StringToRandomDistribution(const string& name) {
static std::unordered_map<string, RandomDistribution>* map = [] {
static auto* map = new std::unordered_map<string, RandomDistribution>;
@@ -3348,4 +3284,8 @@ const ScatterDimensionNumbers& HloInstruction::scatter_dimension_numbers()
return Cast<HloScatterInstruction>(this)->scatter_dimension_numbers();
}
+const DotDimensionNumbers& HloInstruction::dot_dimension_numbers() const {
+ return Cast<HloDotInstruction>(this)->dot_dimension_numbers();
+}
+
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 691f8155f9..de60ddf42d 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -421,12 +421,6 @@ class HloInstruction {
const DotDimensionNumbers& dimension_numbers,
const PrecisionConfig& precision_config);
- // Creates a dot op with operands 'lhs' and 'rhs' that contracts dimension 1
- // of the LHS with dimension 0 of the RHS with no batch dimensions. Both LHS
- // and the RHS must be of rank 2.
- static std::unique_ptr<HloInstruction> CreateCanonicalDot(
- const Shape& shape, HloInstruction* lhs, HloInstruction* rhs);
-
// Creates a reduce-precision op, where operand is the data to reduce in
// precision, and exponent_bits and mantissa_bits describe the precision to
// reduce it to.
@@ -1101,15 +1095,6 @@ class HloInstruction {
// instruction.
void SetupDerivedInstruction(HloInstruction* derived_instruction) const;
- // Returns data on the dimension numbers used for a dot operation.
- const DotDimensionNumbers& dot_dimension_numbers() const {
- CHECK(dot_dimension_numbers_ != nullptr);
- return *dot_dimension_numbers_;
- }
-
- // Returns the dump string of the dot dimension numbers.
- string DotDimensionNumbersToString() const;
-
// Returns the dump string of the precision configuration.
string PrecisionConfigToString() const;
@@ -1508,6 +1493,9 @@ class HloInstruction {
// Delegates to HloScatterInstruction::scatter_dimension_numbers().
const ScatterDimensionNumbers& scatter_dimension_numbers() const;
+ // Delegates to HloDotInstruction::dot_dimension_numbers().
+ const DotDimensionNumbers& dot_dimension_numbers() const;
+
// Old methods kept for smooth subclassing transition END.
protected:
@@ -1647,12 +1635,6 @@ class HloInstruction {
// Result shape of this instruction.
Shape shape_;
- // Describes the dimension numbers used for a dot.
- std::unique_ptr<DotDimensionNumbers> dot_dimension_numbers_;
-
- // Used to tag kCopy instructions that are eligible for copy elision.
- bool copy_elision_allowed_ = true;
-
// The sharding, if one exists.
// Uses std::shared_ptr to allow reuse of the same sharding object between
// HloInstructions and other components as HloSharding can be very large for
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index ad87aa1123..4e3e0c055e 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -1663,6 +1663,7 @@ HloInstructionProto HloConvolutionInstruction::ToProto() const {
*proto.mutable_convolution_dimension_numbers() =
convolution_dimension_numbers_;
proto.set_feature_group_count(feature_group_count_);
+ *proto.mutable_precision_config() = precision_config();
return proto;
}
@@ -2161,4 +2162,66 @@ std::unique_ptr<HloInstruction> HloIotaInstruction::CloneWithNewOperandsImpl(
return absl::make_unique<HloIotaInstruction>(shape, iota_dimension());
}
+HloDotInstruction::HloDotInstruction(
+ const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
+ const DotDimensionNumbers& dimension_numbers,
+ const PrecisionConfig& precision_config)
+ : HloInstruction(HloOpcode::kDot, shape),
+ dot_dimension_numbers_(dimension_numbers) {
+ AppendOperand(lhs);
+ AppendOperand(rhs);
+ set_precision_config(precision_config);
+}
+
+HloInstructionProto HloDotInstruction::ToProto() const {
+ HloInstructionProto proto = HloInstruction::ToProto();
+ *proto.mutable_dot_dimension_numbers() = dot_dimension_numbers_;
+ *proto.mutable_precision_config() = precision_config();
+ return proto;
+}
+
+std::vector<string> HloDotInstruction::ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const {
+ return {DotDimensionNumbersToString()};
+}
+
+bool HloDotInstruction::IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const {
+ const auto& casted_other = static_cast<const HloDotInstruction&>(other);
+ return protobuf_util::ProtobufEquals(dot_dimension_numbers(),
+ casted_other.dot_dimension_numbers());
+}
+
+std::unique_ptr<HloInstruction> HloDotInstruction::CloneWithNewOperandsImpl(
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
+ HloCloneContext* context) const {
+ CHECK_EQ(new_operands.size(), 2);
+ return absl::make_unique<HloDotInstruction>(
+ shape, new_operands[0], new_operands[1], dot_dimension_numbers_,
+ precision_config());
+}
+
+string HloDotInstruction::DotDimensionNumbersToString() const {
+ std::vector<string> result;
+ const DotDimensionNumbers& dnums = dot_dimension_numbers_;
+ if (!dnums.lhs_batch_dimensions().empty()) {
+ result.push_back(StrCat("lhs_batch_dims={",
+ StrJoin(dnums.lhs_batch_dimensions(), ","), "}"));
+ }
+ result.push_back(StrCat("lhs_contracting_dims={",
+ StrJoin(dnums.lhs_contracting_dimensions(), ","),
+ "}"));
+
+ if (!dnums.rhs_batch_dimensions().empty()) {
+ result.push_back(StrCat("rhs_batch_dims={",
+ StrJoin(dnums.rhs_batch_dimensions(), ","), "}"));
+ }
+ result.push_back(StrCat("rhs_contracting_dims={",
+ StrJoin(dnums.rhs_contracting_dimensions(), ","),
+ "}"));
+
+ return StrJoin(result, ", ");
+}
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h
index e1215a7566..e72ddabff9 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.h
+++ b/tensorflow/compiler/xla/service/hlo_instructions.h
@@ -1271,6 +1271,41 @@ class HloIotaInstruction : public HloInstruction {
const int64 iota_dimension_;
};
+class HloDotInstruction : public HloInstruction {
+ public:
+ // Creates a dot op with operands 'lhs' and 'rhs' with contracting and batch
+ // dimensions specified in 'dimension_numbers'.
+ explicit HloDotInstruction(const Shape& shape, HloInstruction* lhs,
+ HloInstruction* rhs,
+ const DotDimensionNumbers& dimension_numbers,
+ const PrecisionConfig& precision_config);
+
+ // Returns data on the dimension numbers used for a dot operation.
+ const DotDimensionNumbers& dot_dimension_numbers() const {
+ return dot_dimension_numbers_;
+ }
+
+ // Returns a serialized representation of this instruction.
+ HloInstructionProto ToProto() const override;
+
+ private:
+ std::vector<string> ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const override;
+ bool IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const override;
+ // Implementation for non-common logic of CloneWithNewOperands.
+ std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
+ HloCloneContext* context) const override;
+ // Returns the dump string of the dot dimension numbers.
+ string DotDimensionNumbersToString() const;
+
+ // Describes the dimension numbers used for a dot.
+ DotDimensionNumbers dot_dimension_numbers_;
+};
+
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 36b8fb2644..d0bda45cf8 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -75,7 +75,6 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_verifier",
"//tensorflow/compiler/xla/service:transfer_manager",
"//tensorflow/core:lib",
- "//tensorflow/core:stream_executor_headers_lib",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/types:span",
],
diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc
index c20a7c8fe4..3ae31191a0 100644
--- a/tensorflow/compiler/xla/tests/test_utils.cc
+++ b/tensorflow/compiler/xla/tests/test_utils.cc
@@ -417,4 +417,18 @@ Status VerifyHloModule(HloModule* const module, bool layout_sensitive,
.status();
}
+std::unique_ptr<HloDotInstruction> CreateCanonicalDot(const Shape& shape,
+ HloInstruction* lhs,
+ HloInstruction* rhs) {
+ CHECK_EQ(ShapeUtil::Rank(lhs->shape()), 2);
+ CHECK_EQ(ShapeUtil::Rank(rhs->shape()), 2);
+ PrecisionConfig precision_config;
+ precision_config.mutable_operand_precision()->Resize(
+ 2, PrecisionConfig::DEFAULT);
+ DotDimensionNumbers dot_dimension_numbers;
+ dot_dimension_numbers.add_lhs_contracting_dimensions(1);
+ dot_dimension_numbers.add_rhs_contracting_dimensions(0);
+ return absl::make_unique<HloDotInstruction>(
+ shape, lhs, rhs, dot_dimension_numbers, precision_config);
+}
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/test_utils.h b/tensorflow/compiler/xla/tests/test_utils.h
index 7790737c09..a260271b1b 100644
--- a/tensorflow/compiler/xla/tests/test_utils.h
+++ b/tensorflow/compiler/xla/tests/test_utils.h
@@ -24,10 +24,10 @@ limitations under the License.
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/types.h"
-#include "tensorflow/stream_executor/platform.h"
namespace xla {
@@ -98,6 +98,12 @@ StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments(
Status VerifyHloModule(HloModule* const module, bool layout_sensitive,
bool allow_mixed_precision);
+// Creates a dot op with operands 'lhs' and 'rhs' that contracts dimension 1 of
+// the LHS with dimension 0 of the RHS with no batch dimensions.
+// Both LHS and the RHS must be of rank 2.
+std::unique_ptr<HloDotInstruction> CreateCanonicalDot(const Shape& shape,
+ HloInstruction* lhs,
+ HloInstruction* rhs);
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_TESTS_TEST_UTILS_H_
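
A note on the delegation pattern in the hunks above: hlo_instruction.h keeps a
dot_dimension_numbers() accessor on the base class ("Old methods kept for smooth
subclassing transition"), and hlo_instruction.cc implements it with a checked
downcast via Cast<HloDotInstruction>. A hedged sketch of what a caller sees;
instr is assumed to point at a kDot instruction:

// Delegates to HloDotInstruction::dot_dimension_numbers() through
// Cast<HloDotInstruction>(this); calling it on a non-dot instruction is a
// runtime error.
const DotDimensionNumbers& dnums = instr->dot_dimension_numbers();
// For a dot built by the CreateCanonicalDot test helper:
CHECK_EQ(dnums.lhs_contracting_dimensions(0), 1);
CHECK_EQ(dnums.rhs_contracting_dimensions(0), 0);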