aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/tests/test_utils.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-06 20:09:38 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-06 20:14:24 -0700
commitac8cf2ad5d01010b978c5b41c2fac22ee69a90c4 (patch)
tree06840591db9d2a077b28fe28f73baae913065550 /tensorflow/compiler/xla/tests/test_utils.cc
parent1cc48be8da90c2d5d3a2ebdf6ed46be623fa0c03 (diff)
Split out HloDotInstruction as subclass from HloInstruction.
PiperOrigin-RevId: 211912785
Diffstat (limited to 'tensorflow/compiler/xla/tests/test_utils.cc')
-rw-r--r--tensorflow/compiler/xla/tests/test_utils.cc14
1 files changed, 14 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc
index c20a7c8fe4..3ae31191a0 100644
--- a/tensorflow/compiler/xla/tests/test_utils.cc
+++ b/tensorflow/compiler/xla/tests/test_utils.cc
@@ -417,4 +417,18 @@ Status VerifyHloModule(HloModule* const module, bool layout_sensitive,
.status();
}
+std::unique_ptr<HloDotInstruction> CreateCanonicalDot(const Shape& shape,
+ HloInstruction* lhs,
+ HloInstruction* rhs) {
+ CHECK_EQ(ShapeUtil::Rank(lhs->shape()), 2);
+ CHECK_EQ(ShapeUtil::Rank(rhs->shape()), 2);
+ PrecisionConfig precision_config;
+ precision_config.mutable_operand_precision()->Resize(
+ 2, PrecisionConfig::DEFAULT);
+ DotDimensionNumbers dot_dimension_numbers;
+ dot_dimension_numbers.add_lhs_contracting_dimensions(1);
+ dot_dimension_numbers.add_rhs_contracting_dimensions(0);
+ return absl::make_unique<HloDotInstruction>(
+ shape, lhs, rhs, dot_dimension_numbers, precision_config);
+}
} // namespace xla