author     A. Unique TensorFlower <gardener@tensorflow.org>   2018-09-06 20:09:38 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>    2018-09-06 20:14:24 -0700
commit     ac8cf2ad5d01010b978c5b41c2fac22ee69a90c4 (patch)
tree       06840591db9d2a077b28fe28f73baae913065550
parent     1cc48be8da90c2d5d3a2ebdf6ed46be623fa0c03 (diff)
Split out HloDotInstruction as subclass from HloInstruction.
PiperOrigin-RevId: 211912785
16 files changed, 226 insertions, 179 deletions
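
In outline: dot-specific state leaves the HloInstruction base class (the dot_dimension_numbers_ member, plus the kDot special cases in proto round-tripping, printing, equality checking, and cloning) and is concentrated in a new HloDotInstruction subclass. The base class keeps a dot_dimension_numbers() accessor that delegates through a checked cast, and the test-only CreateCanonicalDot helper moves to xla/tests/test_utils. A minimal, self-contained sketch of that pattern, with hypothetical Instruction and DotInstruction types standing in for the real XLA classes:

    // Sketch of the refactoring pattern only; not the XLA API.
    #include <cassert>
    #include <memory>

    enum class Opcode { kDot, kAdd };

    struct DotDimensionNumbers {
      int lhs_contracting = 1;  // contract lhs dimension 1...
      int rhs_contracting = 0;  // ...with rhs dimension 0
    };

    class Instruction {
     public:
      explicit Instruction(Opcode opcode) : opcode_(opcode) {}
      virtual ~Instruction() = default;
      Opcode opcode() const { return opcode_; }
      // Kept on the base class for a smooth transition; only valid on kDot.
      const DotDimensionNumbers& dot_dimension_numbers() const;

     private:
      Opcode opcode_;
    };

    class DotInstruction : public Instruction {
     public:
      explicit DotInstruction(DotDimensionNumbers dnums)
          : Instruction(Opcode::kDot), dnums_(dnums) {}
      const DotDimensionNumbers& dnums() const { return dnums_; }

     private:
      DotDimensionNumbers dnums_;  // payload lives only on the subclass now
    };

    const DotDimensionNumbers& Instruction::dot_dimension_numbers() const {
      assert(opcode() == Opcode::kDot);  // mirrors XLA's checked Cast<>
      return static_cast<const DotInstruction*>(this)->dnums();
    }

    int main() {
      std::unique_ptr<Instruction> dot =
          std::make_unique<DotInstruction>(DotDimensionNumbers{});
      return dot->dot_dimension_numbers().rhs_contracting;  // 0
    }

Keeping the delegating accessor means existing call sites compile unchanged while the payload migrates; only the storage and the per-opcode logic move.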
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index d412578619..2368ac8c6a 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -670,6 +670,7 @@ tf_cc_test(
         "//tensorflow/compiler/xla/service:hlo_parser",
         "//tensorflow/compiler/xla/service:transpose_folding",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:test_utils",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
         "//tensorflow/core:lib",
         "@com_google_absl//absl/strings",
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
index 0fea462c85..7d99b914d4 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
@@ -24,6 +24,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/hlo_parser.h"
 #include "tensorflow/compiler/xla/service/transpose_folding.h"
 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/test_utils.h"
 
 namespace op = xla::testing::opcode_matchers;
 
@@ -696,8 +697,8 @@ void CreateComputationForDotAddOutputFusionTest(const string& test_name,
   auto* addend = builder.AddInstruction(
       HloInstruction::CreateParameter(2, dot_shape, "param2"));
-  auto* dot = builder.AddInstruction(
-      HloInstruction::CreateCanonicalDot(dot_shape, dot_lhs, dot_rhs));
+  auto* dot =
+      builder.AddInstruction(CreateCanonicalDot(dot_shape, dot_lhs, dot_rhs));
   builder.AddInstruction(
       HloInstruction::CreateBinary(dot_shape, HloOpcode::kAdd, dot, addend));
 
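Every test call site in the files below changes the same way: HloInstruction::CreateCanonicalDot is deleted as a static member, tests call the free CreateCanonicalDot from tests/test_utils.h instead, and each affected test target therefore gains a //tensorflow/compiler/xla/tests:test_utils dependency. Roughly, a test body written against the new helper looks like this (a hypothetical fragment; names such as "canonical_dot_example" are illustrative):

    // Hypothetical test fragment using the relocated helper.
    Shape shape = ShapeUtil::MakeShape(F32, {2, 2});
    HloComputation::Builder builder("canonical_dot_example");
    HloInstruction* lhs = builder.AddInstruction(
        HloInstruction::CreateParameter(0, shape, "lhs"));
    HloInstruction* rhs = builder.AddInstruction(
        HloInstruction::CreateParameter(1, shape, "rhs"));
    // CreateCanonicalDot returns std::unique_ptr<HloDotInstruction>, which
    // converts implicitly to the std::unique_ptr<HloInstruction> that
    // AddInstruction expects.
    builder.AddInstruction(CreateCanonicalDot(shape, lhs, rhs));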
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc
index 9363af3b89..4668f3872d 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc
@@ -70,7 +70,7 @@ TEST_F(CpuLayoutAssignmentTest, DotWithConstantRhsTensor) {
   auto dot_rhs = builder.AddInstruction(
       HloInstruction::CreateConstant(Literal::CreateFromShape(rhs_shape)));
   auto result = builder.AddInstruction(
-      HloInstruction::CreateCanonicalDot(result_shape, dot_lhs, dot_rhs));
+      CreateCanonicalDot(result_shape, dot_lhs, dot_rhs));
 
   auto module = CreateNewModule();
   HloComputation* computation = module->AddEntryComputation(builder.Build());
@@ -107,9 +107,9 @@ TEST_F(CpuLayoutAssignmentTest, MultipleDotsWithSameConstantRhsTensor0) {
   auto dot_rhs = builder.AddInstruction(
       HloInstruction::CreateConstant(Literal::CreateFromShape(rhs_shape)));
   auto dot_a_result = builder.AddInstruction(
-      HloInstruction::CreateCanonicalDot(result_shape, dot_a_lhs, dot_rhs));
+      CreateCanonicalDot(result_shape, dot_a_lhs, dot_rhs));
   auto dot_b_result = builder.AddInstruction(
-      HloInstruction::CreateCanonicalDot(result_shape, dot_b_lhs, dot_rhs));
+      CreateCanonicalDot(result_shape, dot_b_lhs, dot_rhs));
   builder.AddInstruction(HloInstruction::CreateBinary(
       result_shape, HloOpcode::kAdd, dot_a_result, dot_b_result));
 
@@ -151,9 +151,9 @@ TEST_F(CpuLayoutAssignmentTest, MultipleDotsWithSameConstantRhsTensor1) {
   auto dot_rhs = builder.AddInstruction(
       HloInstruction::CreateConstant(Literal::CreateFromShape(rhs_shape)));
   auto dot_a_result = builder.AddInstruction(
-      HloInstruction::CreateCanonicalDot(result_a_shape, dot_a_lhs, dot_rhs));
+      CreateCanonicalDot(result_a_shape, dot_a_lhs, dot_rhs));
   auto dot_b_result = builder.AddInstruction(
-      HloInstruction::CreateCanonicalDot(result_b_shape, dot_b_lhs, dot_rhs));
+      CreateCanonicalDot(result_b_shape, dot_b_lhs, dot_rhs));
   auto tuple_result = builder.AddInstruction(
       HloInstruction::CreateTuple({dot_a_result, dot_b_result}));
 
@@ -189,7 +189,7 @@ TEST_F(CpuLayoutAssignmentTest, DotWithConstantLhsTensor) {
   auto dot_rhs = builder.AddInstruction(
       HloInstruction::CreateParameter(0, rhs_shape, "param0"));
   auto dot_result = builder.AddInstruction(
-      HloInstruction::CreateCanonicalDot(result_shape, dot_lhs, dot_rhs));
+      CreateCanonicalDot(result_shape, dot_lhs, dot_rhs));
 
   auto module = CreateNewModule();
   HloComputation* computation = module->AddEntryComputation(builder.Build());
@@ -229,7 +229,7 @@ TEST_F(CpuLayoutAssignmentTest, DotWithConstantRhsTensorThroughGTE) {
   auto dot_rhs = builder.AddInstruction(
       HloInstruction::CreateGetTupleElement(rhs_shape, constant, 1));
   auto dot_result = builder.AddInstruction(
-      HloInstruction::CreateCanonicalDot(result_shape, dot_lhs, dot_rhs));
+      CreateCanonicalDot(result_shape, dot_lhs, dot_rhs));
 
   auto module = CreateNewModule();
   HloComputation* computation = module->AddEntryComputation(builder.Build());
@@ -276,8 +276,8 @@ static StatusOr<DotOutputFusionLayoutAssignmentResult> RunDotOutputFusion(
       HloInstruction::CreateParameter(1, dot_shape, "param1"));
   HloInstruction* dot_rhs = builder.AddInstruction(
       HloInstruction::CreateConstant(Literal::CreateFromShape(dot_rhs_shape)));
-  HloInstruction* dot_result = builder.AddInstruction(
-      HloInstruction::CreateCanonicalDot(dot_shape, dot_lhs, dot_rhs));
+  HloInstruction* dot_result =
+      builder.AddInstruction(CreateCanonicalDot(dot_shape, dot_lhs, dot_rhs));
   HloInstruction* add_result;
   if (dot_operand_idx_in_add == 0) {
     add_result = builder.AddInstruction(HloInstruction::CreateBinary(
diff --git a/tensorflow/compiler/xla/service/cpu/tests/BUILD b/tensorflow/compiler/xla/service/cpu/tests/BUILD
index 2384166fd2..f11aff0573 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/tests/BUILD
@@ -121,6 +121,7 @@ tf_cc_test(
         "//tensorflow/compiler/xla/service:hlo",
         "//tensorflow/compiler/xla/service/cpu:cpu_compiler",
         "//tensorflow/compiler/xla/service/cpu/tests:cpu_codegen_test",
+        "//tensorflow/compiler/xla/tests:test_utils",
        "//tensorflow/core:lib",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" #include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/core/platform/test.h" namespace xla { @@ -69,8 +70,7 @@ TEST_P(CpuEigenDotOperationTest, SimpleDotOp) { HloInstruction* rhs = builder.AddInstruction( HloInstruction::CreateParameter(1, param_shape, "input")); - builder.AddInstruction( - HloInstruction::CreateCanonicalDot(param_shape, lhs, rhs)); + builder.AddInstruction(CreateCanonicalDot(param_shape, lhs, rhs)); CompileAndCheck(builder.Build(), spec.filecheck_lines); } @@ -87,8 +87,7 @@ TEST_P(CpuEigenDotOperationTest, DotTransposeOp) { HloInstruction* lhs_transposed = builder.AddInstruction( HloInstruction::CreateTranspose(param_shape, lhs, {1, 0})); - builder.AddInstruction( - HloInstruction::CreateCanonicalDot(param_shape, lhs_transposed, rhs)); + builder.AddInstruction(CreateCanonicalDot(param_shape, lhs_transposed, rhs)); CompileAndCheck(builder.Build(), spec.filecheck_lines); } diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index a68b7a1bef..6791e15ee0 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -108,6 +108,7 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "@com_google_absl//absl/memory", @@ -480,6 +481,7 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo_matchers", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], ) @@ -830,6 +832,7 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings:str_format", diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc index 0922e44a12..59ade96f7d 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc @@ -25,6 +25,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/types.h" namespace xla { @@ -73,10 +74,10 @@ TEST_F(GpuHloScheduleTest, SequentialMatMul) { /*parameter_number=*/1, f32_2x2_, /*name=*/"y")); HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/2, f32_2x2_, /*name=*/"z")); - HloInstruction* dot1 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, x, y)); - HloInstruction* dot2 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, dot1, z)); + HloInstruction* dot1 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, x, y)); + HloInstruction* dot2 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, dot1, z)); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build(dot2)); @@ -201,12 +202,12 @@ TEST_F(GpuHloScheduleTest, ConcurrentMatMul) { /*parameter_number=*/0, f32_2x2_, /*name=*/"x")); HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/1, f32_2x2_, /*name=*/"y")); - HloInstruction* dot1 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, x, y)); - HloInstruction* dot2 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, y, x)); - HloInstruction* add = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, dot1, dot2)); + HloInstruction* dot1 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, x, y)); + HloInstruction* dot2 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, y, x)); + HloInstruction* add = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, dot1, dot2)); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build(add)); @@ -269,23 +270,23 @@ TEST_F(GpuHloScheduleTest, LatticeMatMul) { i, f32_2x2_, /*name=*/absl::StrFormat("param%d", i)))); } HloInstruction* d00 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, params[2], params[3])); - HloInstruction* d10 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, params[1], d00)); - HloInstruction* d11 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, d00, params[4])); - HloInstruction* d20 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, params[0], d10)); - HloInstruction* d21 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, d10, d11)); - HloInstruction* d22 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, d11, params[5])); - HloInstruction* d30 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, d20, d21)); - HloInstruction* d31 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, d21, d22)); - HloInstruction* d40 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, d30, d31)); + CreateCanonicalDot(f32_2x2_, params[2], params[3])); + HloInstruction* d10 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, params[1], d00)); + HloInstruction* d11 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d00, params[4])); + HloInstruction* d20 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, params[0], d10)); + HloInstruction* d21 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d10, d11)); + HloInstruction* d22 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d11, params[5])); + 
HloInstruction* d30 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d20, d21)); + HloInstruction* d31 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d21, d22)); + HloInstruction* d40 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d30, d31)); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build(d40)); diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc index bca775c475..96bfe0c12e 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/util.h" namespace op = xla::testing::opcode_matchers; @@ -111,8 +112,8 @@ TEST_F(InstructionFusionTest, PotentialBitcastReshapeOfDotUnfused) { HloComputation::Builder builder(TestName()); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(S32, {1, 1}), "0")); - auto dot1 = builder.AddInstruction(HloInstruction::CreateCanonicalDot( - ShapeUtil::MakeShape(S32, {1, 1}), param0, param0)); + auto dot1 = builder.AddInstruction( + CreateCanonicalDot(ShapeUtil::MakeShape(S32, {1, 1}), param0, param0)); auto reshape2 = builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(S32, {1, 1, 1}), dot1)); @@ -128,8 +129,8 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfDotUnfused) { HloComputation::Builder builder(TestName()); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(S32, {1, 1}), "0")); - auto dot1 = builder.AddInstruction(HloInstruction::CreateCanonicalDot( - ShapeUtil::MakeShape(S32, {1, 1}), param0, param0)); + auto dot1 = builder.AddInstruction( + CreateCanonicalDot(ShapeUtil::MakeShape(S32, {1, 1}), param0, param0)); auto transpose2 = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(S32, {1, 1}), dot1, {0, 1})); diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc index 091aca23e5..8f0dedfa40 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc +++ b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc @@ -22,6 +22,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/types.h" namespace xla { @@ -49,10 +50,10 @@ TEST_F(StreamAssignmentTest, SequentialMatMul) { /*parameter_number=*/1, f32_2x2_, /*name=*/"y")); HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/2, f32_2x2_, /*name=*/"z")); - HloInstruction* dot1 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, x, y)); - HloInstruction* dot2 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, dot1, z)); + HloInstruction* dot1 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, x, y)); + HloInstruction* dot2 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, dot1, z)); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build(dot2)); @@ -68,10 +69,10 @@ TEST_F(StreamAssignmentTest, ConcurrentMatMul) { /*parameter_number=*/0, f32_2x2_, /*name=*/"x")); HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/1, f32_2x2_, /*name=*/"y")); - HloInstruction* dot1 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, x, y)); - HloInstruction* dot2 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, y, x)); + HloInstruction* dot1 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, x, y)); + HloInstruction* dot2 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, y, x)); HloInstruction* add = builder.AddInstruction( HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, dot1, dot2)); @@ -101,23 +102,23 @@ TEST_F(StreamAssignmentTest, LatticeMatMul) { i, f32_2x2_, /*name=*/absl::StrFormat("param%d", i)))); } HloInstruction* d00 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, params[2], params[3])); - HloInstruction* d10 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, params[1], d00)); - HloInstruction* d11 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, d00, params[4])); - HloInstruction* d20 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, params[0], d10)); - HloInstruction* d21 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, d10, d11)); - HloInstruction* d22 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, d11, params[5])); - HloInstruction* d30 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, d20, d21)); - HloInstruction* d31 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, d21, d22)); - HloInstruction* d40 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, d30, d31)); + CreateCanonicalDot(f32_2x2_, params[2], params[3])); + HloInstruction* d10 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, params[1], d00)); + HloInstruction* d11 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d00, params[4])); + HloInstruction* d20 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, params[0], d10)); + HloInstruction* d21 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d10, d11)); + HloInstruction* d22 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d11, params[5])); + HloInstruction* d30 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d20, d21)); + HloInstruction* d31 = + 
builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d21, d22)); + HloInstruction* d40 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d30, d31)); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build(d40)); diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 471a12d6aa..563aa695c9 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -451,6 +451,20 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( << proto.dimensions_size(); instruction = CreateIota(proto.shape(), proto.dimensions(0)); break; + case HloOpcode::kDot: { + TF_RET_CHECK(proto.has_dot_dimension_numbers()) + << "Dot instruction should have dot_dimension_numbers."; + TF_RET_CHECK(proto.operand_ids_size() == 2) + << "Dot instruction should have 2 operands but sees " + << proto.operand_ids_size(); + PrecisionConfig precision_config = proto.precision_config(); + precision_config.mutable_operand_precision()->Resize( + proto.operand_ids_size(), PrecisionConfig::DEFAULT); + instruction = absl::make_unique<HloDotInstruction>( + proto.shape(), operands(0), operands(1), + proto.dot_dimension_numbers(), precision_config); + break; + } default: { instruction = absl::WrapUnique(new HloInstruction(opcode, proto.shape())); for (const int64 operand_id : proto.operand_ids()) { @@ -472,20 +486,9 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( computation_map.at(computation_id)); } } - if (instruction->opcode() == HloOpcode::kDot) { - instruction->precision_config_ = proto.precision_config(); - instruction->precision_config_.mutable_operand_precision()->Resize( - instruction->operand_count(), PrecisionConfig::DEFAULT); - TF_RET_CHECK(proto.has_dot_dimension_numbers()); - instruction->dot_dimension_numbers_ = - absl::make_unique<DotDimensionNumbers>( - proto.dot_dimension_numbers()); - } else { - TF_RET_CHECK(!proto.has_precision_config()) - << instruction->opcode() << proto.DebugString(); - TF_RET_CHECK(!proto.has_dot_dimension_numbers()) - << instruction->opcode(); - } + TF_RET_CHECK(!proto.has_precision_config()) + << instruction->opcode() << proto.DebugString(); + TF_RET_CHECK(!proto.has_dot_dimension_numbers()) << instruction->opcode(); break; } } @@ -596,7 +599,6 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, case HloOpcode::kAtan2: case HloOpcode::kDivide: case HloOpcode::kComplex: - case HloOpcode::kDot: case HloOpcode::kEq: case HloOpcode::kGe: case HloOpcode::kGt: @@ -674,30 +676,8 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, const DotDimensionNumbers& dimension_numbers, const PrecisionConfig& precision_config) { - auto instruction = - absl::WrapUnique(new HloInstruction(HloOpcode::kDot, shape)); - instruction->AppendOperand(lhs); - instruction->AppendOperand(rhs); - instruction->dot_dimension_numbers_ = - absl::make_unique<DotDimensionNumbers>(dimension_numbers); - instruction->set_precision_config(precision_config); - return instruction; -} - -/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCanonicalDot( - const Shape& shape, HloInstruction* lhs, HloInstruction* rhs) { - CHECK_EQ(ShapeUtil::Rank(lhs->shape()), 2); - CHECK_EQ(ShapeUtil::Rank(rhs->shape()), 2); - - auto instruction = - absl::WrapUnique(new HloInstruction(HloOpcode::kDot, shape)); - instruction->AppendOperand(lhs); - 
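The hlo_instruction.cc hunks that follow move dot deserialization out of the shared tail of CreateFromProto: a dedicated kDot case builds the subclass directly and normalizes the precision config so that operand_precision has exactly one entry per operand, which in turn lets the shared default path assert that no other opcode carries precision_config or dot_dimension_numbers. The normalization step, restated as a sketch (field names as in xla.proto; proto stands for the incoming HloInstructionProto):

    // Pad or truncate operand_precision to one entry per operand; entries
    // absent from the serialized proto become PrecisionConfig::DEFAULT.
    PrecisionConfig precision_config = proto.precision_config();
    precision_config.mutable_operand_precision()->Resize(
        proto.operand_ids_size(), PrecisionConfig::DEFAULT);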
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 471a12d6aa..563aa695c9 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -451,6 +451,20 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
           << proto.dimensions_size();
       instruction = CreateIota(proto.shape(), proto.dimensions(0));
       break;
+    case HloOpcode::kDot: {
+      TF_RET_CHECK(proto.has_dot_dimension_numbers())
+          << "Dot instruction should have dot_dimension_numbers.";
+      TF_RET_CHECK(proto.operand_ids_size() == 2)
+          << "Dot instruction should have 2 operands but sees "
+          << proto.operand_ids_size();
+      PrecisionConfig precision_config = proto.precision_config();
+      precision_config.mutable_operand_precision()->Resize(
+          proto.operand_ids_size(), PrecisionConfig::DEFAULT);
+      instruction = absl::make_unique<HloDotInstruction>(
+          proto.shape(), operands(0), operands(1),
+          proto.dot_dimension_numbers(), precision_config);
+      break;
+    }
     default: {
       instruction = absl::WrapUnique(new HloInstruction(opcode, proto.shape()));
       for (const int64 operand_id : proto.operand_ids()) {
@@ -472,20 +486,9 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
               computation_map.at(computation_id));
         }
       }
-      if (instruction->opcode() == HloOpcode::kDot) {
-        instruction->precision_config_ = proto.precision_config();
-        instruction->precision_config_.mutable_operand_precision()->Resize(
-            instruction->operand_count(), PrecisionConfig::DEFAULT);
-        TF_RET_CHECK(proto.has_dot_dimension_numbers());
-        instruction->dot_dimension_numbers_ =
-            absl::make_unique<DotDimensionNumbers>(
-                proto.dot_dimension_numbers());
-      } else {
-        TF_RET_CHECK(!proto.has_precision_config())
-            << instruction->opcode() << proto.DebugString();
-        TF_RET_CHECK(!proto.has_dot_dimension_numbers())
-            << instruction->opcode();
-      }
+      TF_RET_CHECK(!proto.has_precision_config())
+          << instruction->opcode() << proto.DebugString();
+      TF_RET_CHECK(!proto.has_dot_dimension_numbers()) << instruction->opcode();
       break;
     }
   }
@@ -596,7 +599,6 @@ HloInstruction::CreateGetTupleElement(const Shape& shape,
     case HloOpcode::kAtan2:
     case HloOpcode::kDivide:
     case HloOpcode::kComplex:
-    case HloOpcode::kDot:
     case HloOpcode::kEq:
     case HloOpcode::kGe:
     case HloOpcode::kGt:
@@ -674,30 +676,8 @@ HloInstruction::CreateGetTupleElement(const Shape& shape,
     const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
     const DotDimensionNumbers& dimension_numbers,
     const PrecisionConfig& precision_config) {
-  auto instruction =
-      absl::WrapUnique(new HloInstruction(HloOpcode::kDot, shape));
-  instruction->AppendOperand(lhs);
-  instruction->AppendOperand(rhs);
-  instruction->dot_dimension_numbers_ =
-      absl::make_unique<DotDimensionNumbers>(dimension_numbers);
-  instruction->set_precision_config(precision_config);
-  return instruction;
-}
-
-/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCanonicalDot(
-    const Shape& shape, HloInstruction* lhs, HloInstruction* rhs) {
-  CHECK_EQ(ShapeUtil::Rank(lhs->shape()), 2);
-  CHECK_EQ(ShapeUtil::Rank(rhs->shape()), 2);
-
-  auto instruction =
-      absl::WrapUnique(new HloInstruction(HloOpcode::kDot, shape));
-  instruction->AppendOperand(lhs);
-  instruction->AppendOperand(rhs);
-  instruction->dot_dimension_numbers_ =
-      absl::make_unique<DotDimensionNumbers>();
-  instruction->dot_dimension_numbers_->add_lhs_contracting_dimensions(1);
-  instruction->dot_dimension_numbers_->add_rhs_contracting_dimensions(0);
-  return instruction;
+  return absl::make_unique<HloDotInstruction>(
+      shape, lhs, rhs, dimension_numbers, precision_config);
 }
 
 /* static */ std::unique_ptr<HloInstruction>
@@ -1218,6 +1198,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
     case HloOpcode::kGather:
     case HloOpcode::kScatter:
     case HloOpcode::kIota:
+    case HloOpcode::kDot:
       clone = CloneWithNewOperandsImpl(shape, new_operands, context);
       break;
     // Unary ops.
@@ -1290,11 +1271,6 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
       CHECK_EQ(new_operands.size(), 1);
       clone = CreateBitcastConvert(shape, new_operands[0]);
       break;
-    case HloOpcode::kDot:
-      CHECK_EQ(new_operands.size(), 2);
-      clone = CreateDot(shape, new_operands[0], new_operands[1],
-                        *dot_dimension_numbers_, precision_config());
-      break;
     case HloOpcode::kReshape:
       CHECK_EQ(new_operands.size(), 1);
       clone = CreateReshape(shape, new_operands[0]);
@@ -1620,11 +1596,6 @@ bool HloInstruction::IdenticalSlowPath(
     case HloOpcode::kAfterAll:
       return false;
 
-    // Check dot dimension numbers.
-    case HloOpcode::kDot:
-      return protobuf_util::ProtobufEquals(dot_dimension_numbers(),
-                                           other.dot_dimension_numbers());
-
     // Remaining instructions with special values.
     case HloOpcode::kCall:
       return eq_computations(to_apply(), other.to_apply());
@@ -1683,6 +1654,7 @@ bool HloInstruction::IdenticalSlowPath(
     case HloOpcode::kDynamicSlice:
     case HloOpcode::kGather:
     case HloOpcode::kScatter:
+    case HloOpcode::kDot:
       LOG(FATAL) << "Base class impl called for opcode with subclass: "
                  << opcode();
   }
@@ -2052,10 +2024,6 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
     const HloPrintOptions& options) const {
   std::vector<string> extra = ExtraAttributesToStringImpl(options);
 
-  if (dot_dimension_numbers_ != nullptr) {
-    extra.push_back(DotDimensionNumbersToString());
-  }
-
   string precision_config_string = PrecisionConfigToString();
   if (!precision_config_string.empty()) {
     extra.push_back(precision_config_string);
@@ -2182,19 +2150,12 @@ HloInstructionProto HloInstruction::ToProto() const {
   *proto.mutable_metadata() = metadata_;
   proto.set_backend_config(backend_config_);
-  if (opcode() == HloOpcode::kConvolution || opcode() == HloOpcode::kDot) {
-    *proto.mutable_precision_config() = precision_config_;
-  }
   if (opcode() != HloOpcode::kFusion) {
     for (const HloComputation* computation : called_computations_) {
       proto.add_called_computation_ids(computation->unique_id());
     }
   }
-  if (dot_dimension_numbers_ != nullptr) {
-    *proto.mutable_dot_dimension_numbers() = *dot_dimension_numbers_;
-  }
-
   if (has_sharding()) {
     *proto.mutable_sharding() = sharding().ToProto();
   }
@@ -2921,31 +2882,6 @@ string ConvolutionDimensionNumbersToString(
                 StrJoin(output_dims, ""));
 }
 
-string HloInstruction::DotDimensionNumbersToString() const {
-  std::vector<string> result;
-  if (dot_dimension_numbers_ == nullptr) {
-    return "";
-  }
-  const DotDimensionNumbers& dnums = *dot_dimension_numbers_;
-  if (!dnums.lhs_batch_dimensions().empty()) {
-    result.push_back(StrCat("lhs_batch_dims={",
-                            StrJoin(dnums.lhs_batch_dimensions(), ","), "}"));
-  }
-  result.push_back(StrCat("lhs_contracting_dims={",
-                          StrJoin(dnums.lhs_contracting_dimensions(), ","),
-                          "}"));
-
-  if (!dnums.rhs_batch_dimensions().empty()) {
-    result.push_back(StrCat("rhs_batch_dims={",
-                            StrJoin(dnums.rhs_batch_dimensions(), ","), "}"));
-  }
-  result.push_back(StrCat("rhs_contracting_dims={",
-                          StrJoin(dnums.rhs_contracting_dimensions(), ","),
-                          "}"));
-
-  return StrJoin(result, ", ");
-}
-
 StatusOr<RandomDistribution> StringToRandomDistribution(const string& name) {
   static std::unordered_map<string, RandomDistribution>* map = [] {
     static auto* map = new std::unordered_map<string, RandomDistribution>;
@@ -3348,4 +3284,8 @@ const ScatterDimensionNumbers& HloInstruction::scatter_dimension_numbers()
   return Cast<HloScatterInstruction>(this)->scatter_dimension_numbers();
 }
 
+const DotDimensionNumbers& HloInstruction::dot_dimension_numbers() const {
+  return Cast<HloDotInstruction>(this)->dot_dimension_numbers();
+}
+
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 691f8155f9..de60ddf42d 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -421,12 +421,6 @@ class HloInstruction {
       const DotDimensionNumbers& dimension_numbers,
       const PrecisionConfig& precision_config);
 
-  // Creates a dot op with operands 'lhs' and 'rhs' that contracts dimension 1
-  // of the LHS with dimension 0 of the RHS with no batch dimensions. Both LHS
-  // and the RHS must be of rank 2.
-  static std::unique_ptr<HloInstruction> CreateCanonicalDot(
-      const Shape& shape, HloInstruction* lhs, HloInstruction* rhs);
-
   // Creates a reduce-precision op, where operand is the data to reduce in
   // precision, and exponent_bits and mantissa_bits describe the precision to
   // reduce it to.
@@ -1101,15 +1095,6 @@ class HloInstruction {
   // instruction.
   void SetupDerivedInstruction(HloInstruction* derived_instruction) const;
 
-  // Returns data on the dimension numbers used for a dot operation.
-  const DotDimensionNumbers& dot_dimension_numbers() const {
-    CHECK(dot_dimension_numbers_ != nullptr);
-    return *dot_dimension_numbers_;
-  }
-
-  // Returns the dump string of the dot dimension numbers.
-  string DotDimensionNumbersToString() const;
-
   // Returns the dump string of the precision configuration.
   string PrecisionConfigToString() const;
 
@@ -1508,6 +1493,9 @@ class HloInstruction {
   // Delegates to HloScatterInstruction::scatter_dimension_numbers().
   const ScatterDimensionNumbers& scatter_dimension_numbers() const;
 
+  // Delegates to HloDotInstruction::dot_dimension_numbers().
+  const DotDimensionNumbers& dot_dimension_numbers() const;
+
   // Old methods kept for smooth subclassing transition END.
 
  protected:
@@ -1647,12 +1635,6 @@ class HloInstruction {
   // Result shape of this instruction.
   Shape shape_;
 
-  // Describes the dimension numbers used for a dot.
-  std::unique_ptr<DotDimensionNumbers> dot_dimension_numbers_;
-
-  // Used to tag kCopy instructions that are eligible for copy elision.
-  bool copy_elision_allowed_ = true;
-
   // The sharding, if one exists.
   // Uses std::shared_ptr to allow reuse of the same sharding object between
   // HloInstructions and other components as HloSharding can be very large for
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index ad87aa1123..4e3e0c055e 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -1663,6 +1663,7 @@ HloInstructionProto HloConvolutionInstruction::ToProto() const {
   *proto.mutable_convolution_dimension_numbers() =
       convolution_dimension_numbers_;
   proto.set_feature_group_count(feature_group_count_);
+  *proto.mutable_precision_config() = precision_config();
   return proto;
 }
 
@@ -2161,4 +2162,66 @@ std::unique_ptr<HloInstruction> HloIotaInstruction::CloneWithNewOperandsImpl(
   return absl::make_unique<HloIotaInstruction>(shape, iota_dimension());
 }
 
+HloDotInstruction::HloDotInstruction(
+    const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
+    const DotDimensionNumbers& dimension_numbers,
+    const PrecisionConfig& precision_config)
+    : HloInstruction(HloOpcode::kDot, shape),
+      dot_dimension_numbers_(dimension_numbers) {
+  AppendOperand(lhs);
+  AppendOperand(rhs);
+  set_precision_config(precision_config);
+}
+
+HloInstructionProto HloDotInstruction::ToProto() const {
+  HloInstructionProto proto = HloInstruction::ToProto();
+  *proto.mutable_dot_dimension_numbers() = dot_dimension_numbers_;
+  *proto.mutable_precision_config() = precision_config();
+  return proto;
+}
+
+std::vector<string> HloDotInstruction::ExtraAttributesToStringImpl(
+    const HloPrintOptions& options) const {
+  return {DotDimensionNumbersToString()};
+}
+
+bool HloDotInstruction::IdenticalSlowPath(
+    const HloInstruction& other,
+    const std::function<bool(const HloComputation*, const HloComputation*)>&
+        eq_computations) const {
+  const auto& casted_other = static_cast<const HloDotInstruction&>(other);
+  return protobuf_util::ProtobufEquals(dot_dimension_numbers(),
+                                       casted_other.dot_dimension_numbers());
+}
+
+std::unique_ptr<HloInstruction> HloDotInstruction::CloneWithNewOperandsImpl(
+    const Shape& shape, absl::Span<HloInstruction* const> new_operands,
+    HloCloneContext* context) const {
+  CHECK_EQ(new_operands.size(), 2);
+  return absl::make_unique<HloDotInstruction>(
+      shape, new_operands[0], new_operands[1], dot_dimension_numbers_,
+      precision_config());
+}
+
+string HloDotInstruction::DotDimensionNumbersToString() const {
+  std::vector<string> result;
+  const DotDimensionNumbers& dnums = dot_dimension_numbers_;
+  if (!dnums.lhs_batch_dimensions().empty()) {
+    result.push_back(StrCat("lhs_batch_dims={",
+                            StrJoin(dnums.lhs_batch_dimensions(), ","), "}"));
+  }
+  result.push_back(StrCat("lhs_contracting_dims={",
+                          StrJoin(dnums.lhs_contracting_dimensions(), ","),
+                          "}"));
+
+  if (!dnums.rhs_batch_dimensions().empty()) {
+    result.push_back(StrCat("rhs_batch_dims={",
+                            StrJoin(dnums.rhs_batch_dimensions(), ","), "}"));
+  }
+  result.push_back(StrCat("rhs_contracting_dims={",
+                          StrJoin(dnums.rhs_contracting_dimensions(), ","),
+                          "}"));
+
+  return StrJoin(result, ", ");
+}
+
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h
index e1215a7566..e72ddabff9 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.h
+++ b/tensorflow/compiler/xla/service/hlo_instructions.h
@@ -1271,6 +1271,41 @@ class HloIotaInstruction : public HloInstruction {
   const int64 iota_dimension_;
 };
 
+class HloDotInstruction : public HloInstruction {
+ public:
+  // Creates a dot op with operands 'lhs' and 'rhs' with contracting and batch
+  // dimensions specified in 'dimension_numbers'.
+  explicit HloDotInstruction(const Shape& shape, HloInstruction* lhs,
+                             HloInstruction* rhs,
+                             const DotDimensionNumbers& dimension_numbers,
+                             const PrecisionConfig& precision_config);
+
+  // Returns data on the dimension numbers used for a dot operation.
+  const DotDimensionNumbers& dot_dimension_numbers() const {
+    return dot_dimension_numbers_;
+  }
+
+  // Returns a serialized representation of this instruction.
+  HloInstructionProto ToProto() const override;
+
+ private:
+  std::vector<string> ExtraAttributesToStringImpl(
+      const HloPrintOptions& options) const override;
+  bool IdenticalSlowPath(
+      const HloInstruction& other,
+      const std::function<bool(const HloComputation*, const HloComputation*)>&
+          eq_computations) const override;
+  // Implementation for non-common logic of CloneWithNewOperands.
+  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
+      HloCloneContext* context) const override;
+  // Returns the dump string of the dot dimension numbers.
+  string DotDimensionNumbersToString() const;
+
+  // Describes the dimension numbers used for a dot.
+  DotDimensionNumbers dot_dimension_numbers_;
+};
+
 }  // namespace xla
 
 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 36b8fb2644..d0bda45cf8 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -75,7 +75,6 @@ cc_library(
         "//tensorflow/compiler/xla/service:hlo_verifier",
         "//tensorflow/compiler/xla/service:transfer_manager",
         "//tensorflow/core:lib",
-        "//tensorflow/core:stream_executor_headers_lib",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/types:span",
     ],
diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc
index c20a7c8fe4..3ae31191a0 100644
--- a/tensorflow/compiler/xla/tests/test_utils.cc
+++ b/tensorflow/compiler/xla/tests/test_utils.cc
@@ -417,4 +417,18 @@ Status VerifyHloModule(HloModule* const module, bool layout_sensitive,
       .status();
 }
 
+std::unique_ptr<HloDotInstruction> CreateCanonicalDot(const Shape& shape,
+                                                      HloInstruction* lhs,
+                                                      HloInstruction* rhs) {
+  CHECK_EQ(ShapeUtil::Rank(lhs->shape()), 2);
+  CHECK_EQ(ShapeUtil::Rank(rhs->shape()), 2);
+  PrecisionConfig precision_config;
+  precision_config.mutable_operand_precision()->Resize(
+      2, PrecisionConfig::DEFAULT);
+  DotDimensionNumbers dot_dimension_numbers;
+  dot_dimension_numbers.add_lhs_contracting_dimensions(1);
+  dot_dimension_numbers.add_rhs_contracting_dimensions(0);
+  return absl::make_unique<HloDotInstruction>(
+      shape, lhs, rhs, dot_dimension_numbers, precision_config);
+}
 }  // namespace xla
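For reference, the dimension numbers the relocated helper constructs describe a plain matrix multiply: contracting LHS dimension 1 against RHS dimension 0 with no batch dimensions gives result[i][j] = sum_k lhs[i][k] * rhs[k][j], which is why the helper CHECKs that both operands are rank 2. Building the same DotDimensionNumbers by hand:

    DotDimensionNumbers dnums;
    dnums.add_lhs_contracting_dimensions(1);  // contract lhs columns...
    dnums.add_rhs_contracting_dimensions(0);  // ...against rhs rows
    // lhs_batch_dimensions / rhs_batch_dimensions intentionally left empty.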
#include "absl/types/span.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/types.h" -#include "tensorflow/stream_executor/platform.h" namespace xla { @@ -98,6 +98,12 @@ StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments( Status VerifyHloModule(HloModule* const module, bool layout_sensitive, bool allow_mixed_precision); +// Creates a dot op with operands 'lhs' and 'rhs' that contracts dimension 1 of +// the LHS with dimension 0 of the RHS with no batch dimensions. +// Both LHS and the RHS must be of rank 2. +std::unique_ptr<HloDotInstruction> CreateCanonicalDot(const Shape& shape, + HloInstruction* lhs, + HloInstruction* rhs); } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_TESTS_TEST_UTILS_H_ |