diff options
author | 2018-09-06 20:09:38 -0700 | |
---|---|---|
committer | 2018-09-06 20:14:24 -0700 | |
commit | ac8cf2ad5d01010b978c5b41c2fac22ee69a90c4 (patch) | |
tree | 06840591db9d2a077b28fe28f73baae913065550 /tensorflow/compiler/xla/tests/test_utils.cc | |
parent | 1cc48be8da90c2d5d3a2ebdf6ed46be623fa0c03 (diff) |
Split out HloDotInstruction as subclass from HloInstruction.
PiperOrigin-RevId: 211912785
Diffstat (limited to 'tensorflow/compiler/xla/tests/test_utils.cc')
-rw-r--r-- | tensorflow/compiler/xla/tests/test_utils.cc | 14 |
1 files changed, 14 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc index c20a7c8fe4..3ae31191a0 100644 --- a/tensorflow/compiler/xla/tests/test_utils.cc +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -417,4 +417,18 @@ Status VerifyHloModule(HloModule* const module, bool layout_sensitive, .status(); } +std::unique_ptr<HloDotInstruction> CreateCanonicalDot(const Shape& shape, + HloInstruction* lhs, + HloInstruction* rhs) { + CHECK_EQ(ShapeUtil::Rank(lhs->shape()), 2); + CHECK_EQ(ShapeUtil::Rank(rhs->shape()), 2); + PrecisionConfig precision_config; + precision_config.mutable_operand_precision()->Resize( + 2, PrecisionConfig::DEFAULT); + DotDimensionNumbers dot_dimension_numbers; + dot_dimension_numbers.add_lhs_contracting_dimensions(1); + dot_dimension_numbers.add_rhs_contracting_dimensions(0); + return absl::make_unique<HloDotInstruction>( + shape, lhs, rhs, dot_dimension_numbers, precision_config); +} } // namespace xla |