diff options
author | 2017-09-06 14:01:13 -0700 | |
---|---|---|
committer | 2017-09-06 14:05:01 -0700 | |
commit | 570147d34015cb8c36869a548b7eac0409416601 (patch) | |
tree | c49d836d9a997317f662eb9dc30c270b2ee96c80 /tensorflow/compiler/xla/service/hlo_evaluator_test.cc | |
parent | e089c5570468e751ee2842e3018ebe67b513e78c (diff) |
[TF:XLA] In Literal: correctly handle operands with zero elements in
Copy.
PiperOrigin-RevId: 167769308
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_evaluator_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_evaluator_test.cc | 47 |
1 files changed, 47 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index a826548349..9205f5dc4e 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -332,6 +332,53 @@ TEST_F(HloEvaluatorTest, DoesBroadcastScalar) { LiteralTestUtil::ExpectEqual(*result, *output_literal); } +TEST_F(HloEvaluatorTest, DoesConcatenateSimple) { + HloComputation::Builder b(TestName()); + + HloInstruction* operand1 = b.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR2<int64>({{-1, -2}, {100, 200}}))); + HloInstruction* operand2 = b.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR2<int64>({{-2, -3}, {-100, -200}}))); + + std::vector<HloInstruction*> operands = {operand1, operand2}; + + Shape shape = ShapeUtil::MakeShape(S64, {2, 2}); + b.AddInstruction(HloInstruction::CreateConcatenate(shape, operands, 0)); + + HloModule module(TestName()); + auto computation = module.AddEntryComputation(b.Build()); + + std::unique_ptr<Literal> result = + evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + + auto expected = + Literal::CreateR2<int64>({{-1, -2}, {100, 200}, {-2, -3}, {-100, -200}}); + LiteralTestUtil::ExpectEqual(*expected, *result); +} + +TEST_F(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) { + HloComputation::Builder b(TestName()); + + HloInstruction* operand1 = b.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1<int64>({100, 200}))); + HloInstruction* operand2 = b.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1<int64>({}))); + + std::vector<HloInstruction*> operands = {operand1, operand2}; + + Shape shape = ShapeUtil::MakeShape(S64, {2}); + b.AddInstruction(HloInstruction::CreateConcatenate(shape, operands, 0)); + + HloModule module(TestName()); + auto computation = module.AddEntryComputation(b.Build()); + + std::unique_ptr<Literal> result = + evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + + auto expected = Literal::CreateR1<int64>({100, 200}); + LiteralTestUtil::ExpectEqual(*expected, *result); +} + TEST_F(HloEvaluatorTest, ConvertWithSameLayout) { HloComputation::Builder b(TestName()); |