aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_computation_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_computation_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation_test.cc84
1 files changed, 44 insertions, 40 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_computation_test.cc b/tensorflow/compiler/xla/service/hlo_computation_test.cc
index a8f3f0e9c2..e4c5470331 100644
--- a/tensorflow/compiler/xla/service/hlo_computation_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation_test.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include <set>
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
@@ -118,7 +118,7 @@ TEST_F(HloComputationTest, PostOrderSingleton) {
// Test GetInstructionPostOrder for a computation with one instruction.
auto builder = HloComputation::Builder(TestName());
auto constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->MakeInstructionPostOrder(), ElementsAre(constant));
@@ -129,7 +129,7 @@ TEST_F(HloComputationTest, PostOrderSimple) {
// instructions.
auto builder = HloComputation::Builder(TestName());
auto constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
auto negate1 = builder.AddInstruction(
HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant));
auto negate2 = builder.AddInstruction(
@@ -144,7 +144,7 @@ TEST_F(HloComputationTest, PostOrderTrace) {
// Test GetInstructionPostOrder for a computation with a trace instruction.
auto builder = HloComputation::Builder(TestName());
auto constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
auto negate1 = builder.AddInstruction(
HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant));
auto trace =
@@ -163,13 +163,13 @@ TEST_F(HloComputationTest, PostOrderDisconnectedInstructions) {
// which are not connected.
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
auto constant3 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
auto constant4 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->MakeInstructionPostOrder(),
@@ -181,11 +181,11 @@ TEST_F(HloComputationTest, PostOrderWithMultipleRoots) {
// which are not connected.
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
auto constant3 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
r0f32_, HloOpcode::kAdd, constant1, constant2));
auto add2 = builder.AddInstruction(HloInstruction::CreateBinary(
@@ -205,11 +205,11 @@ TEST_F(HloComputationTest, VisitWithMultipleRoots) {
// computation has multiple roots (dead code).
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
auto constant3 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
// Add three disconnected add expressions.
builder.AddInstruction(HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd,
constant1, constant2));
@@ -256,7 +256,7 @@ TEST_F(HloComputationTest, DeepCopyArray) {
// Test that DeepCopyInstruction properly copies an array.
auto builder = HloComputation::Builder(TestName());
auto constant = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({1.0, 2.0, 3.0})));
+ LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0})));
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
auto copy = computation->DeepCopyInstruction(constant).ValueOrDie();
@@ -268,9 +268,9 @@ TEST_F(HloComputationTest, DeepCopyTuple) {
// Test that DeepCopyInstruction properly copies a tuple.
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({1.0, 2.0, 3.0})));
+ LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0})));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
auto tuple = builder.AddInstruction(
HloInstruction::CreateTuple({constant1, constant2}));
@@ -289,7 +289,7 @@ TEST_F(HloComputationTest, DeepCopyArrayAtIndices) {
// copy are specified.
auto builder = HloComputation::Builder(TestName());
auto constant = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({1.0, 2.0, 3.0})));
+ LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0})));
auto computation = builder.Build();
{
@@ -314,9 +314,9 @@ TEST_F(HloComputationTest, DeepCopyTupleAtIndices) {
// specified by the given indices.
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({1.0, 2.0, 3.0})));
+ LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0})));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
auto tuple = builder.AddInstruction(
HloInstruction::CreateTuple({constant1, constant2}));
auto computation = builder.Build();
@@ -375,7 +375,7 @@ TEST_F(HloComputationTest, DeepCopyToken) {
// Test that DeepCopyInstruction properly handles tokens which should not be
// copied.
auto builder = HloComputation::Builder(TestName());
- auto token = builder.AddInstruction(HloInstruction::CreateAfterAll({}));
+ auto token = builder.AddInstruction(HloInstruction::CreateToken());
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
auto copy = computation->DeepCopyInstruction(token).ValueOrDie();
@@ -388,9 +388,9 @@ TEST_F(HloComputationTest, DeepCopyTokenTuple) {
// Test that DeepCopyInstruction properly handles tokens which should not be
// copied.
auto builder = HloComputation::Builder(TestName());
- auto token = builder.AddInstruction(HloInstruction::CreateAfterAll({}));
+ auto token = builder.AddInstruction(HloInstruction::CreateToken());
auto constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
auto tuple =
builder.AddInstruction(HloInstruction::CreateTuple({token, constant}));
auto module = CreateNewModule();
@@ -407,7 +407,7 @@ TEST_F(HloComputationTest, CycleDetection) {
// Test whether the visitor can detect cycles in the graph.
auto builder = HloComputation::Builder(TestName());
auto constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
auto negate = builder.AddInstruction(
HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant));
auto add = builder.AddInstruction(
@@ -433,7 +433,7 @@ TEST_F(HloComputationTest, RemoveInstructionWithDuplicateOperand) {
// twice.
auto builder = HloComputation::Builder(TestName());
auto constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
auto dead_negate = builder.AddInstruction(
HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant));
auto dead_add = builder.AddInstruction(HloInstruction::CreateBinary(
@@ -456,9 +456,9 @@ TEST_F(HloComputationTest, RemoveInstructionWithDuplicateOperand) {
TEST_F(HloComputationTest, CloneWithControlDependency) {
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0f)));
auto add = builder.AddInstruction(HloInstruction::CreateBinary(
r0f32_, HloOpcode::kAdd, constant1, constant2));
@@ -502,9 +502,9 @@ TEST_F(HloComputationTest, Reachability) {
// There is a control dependency from 'add' to 'exp'.
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0f)));
auto add = builder.AddInstruction(HloInstruction::CreateBinary(
r0f32_, HloOpcode::kAdd, constant1, constant2));
auto negate = builder.AddInstruction(
@@ -607,13 +607,14 @@ TEST_F(HloComputationTest, Stringification) {
auto* computation = module->AddEntryComputation(builder.Build());
auto options = HloPrintOptions().set_print_metadata(false);
- EXPECT_EQ(computation->ToString(options),
- R"(%TransposeDot (x: f32[5,10], y: f32[20,10]) -> f32[5,20] {
+ const string expected_computation =
+ R"(%TransposeDot (x: f32[5,10], y: f32[20,10]) -> f32[5,20] {
%x = f32[5,10]{1,0} parameter(0)
%y = f32[20,10]{1,0} parameter(1)
%transpose = f32[10,20]{1,0} transpose(f32[20,10]{1,0} %y), dimensions={1,0}
ROOT %dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} %transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-})");
+})";
+ EXPECT_EQ(computation->ToString(options), expected_computation);
}
TEST_F(HloComputationTest, StringificationIndent) {
@@ -639,13 +640,14 @@ TEST_F(HloComputationTest, StringificationIndent) {
auto options =
HloPrintOptions().set_print_metadata(false).set_indent_amount(2);
- EXPECT_EQ(computation->ToString(options),
- R"( %TransposeDot (x: f32[5,10], y: f32[20,10]) -> f32[5,20] {
+ const string expected_computation =
+ R"( %TransposeDot (x: f32[5,10], y: f32[20,10]) -> f32[5,20] {
%x = f32[5,10]{1,0} parameter(0)
%y = f32[20,10]{1,0} parameter(1)
%transpose = f32[10,20]{1,0} transpose(f32[20,10]{1,0} %y), dimensions={1,0}
ROOT %dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} %transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- })");
+ })";
+ EXPECT_EQ(computation->ToString(options), expected_computation);
}
TEST_F(HloComputationTest, StringificationCanonical) {
@@ -670,21 +672,23 @@ TEST_F(HloComputationTest, StringificationCanonical) {
auto* computation = module->AddEntryComputation(builder.Build());
auto options = HloPrintOptions().set_print_metadata(false);
- EXPECT_EQ(computation->ToString(options),
- R"(%TransposeDot (x: f32[5,10], y: f32[20,10]) -> f32[5,20] {
+ const string expected_computation1 =
+ R"(%TransposeDot (x: f32[5,10], y: f32[20,10]) -> f32[5,20] {
%x = f32[5,10]{1,0} parameter(0)
%y = f32[20,10]{1,0} parameter(1)
%transpose = f32[10,20]{1,0} transpose(f32[20,10]{1,0} %y), dimensions={1,0}
ROOT %dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} %transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-})");
+})";
+ EXPECT_EQ(computation->ToString(options), expected_computation1);
options = HloPrintOptions().Canonical();
- EXPECT_EQ(computation->ToString(options), R"(TransposeDot {
+ const string expected_computation2 = R"(TransposeDot {
tmp_0 = f32[5,10]{1,0} parameter(0)
tmp_1 = f32[20,10]{1,0} parameter(1)
tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0}
ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-})");
+})";
+ EXPECT_EQ(computation->ToString(options), expected_computation2);
}
} // namespace