aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
diff options
context:
space:
mode:
authorGravatar Kay Zhu <kayzhu@google.com>2017-09-06 14:01:13 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-09-06 14:05:01 -0700
commit570147d34015cb8c36869a548b7eac0409416601 (patch)
treec49d836d9a997317f662eb9dc30c270b2ee96c80 /tensorflow/compiler/xla/service/hlo_evaluator_test.cc
parente089c5570468e751ee2842e3018ebe67b513e78c (diff)
[TF:XLA] In Literal: correctly handle operands with zero elements in
Copy. PiperOrigin-RevId: 167769308
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_evaluator_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator_test.cc47
1 files changed, 47 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
index a826548349..9205f5dc4e 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
@@ -332,6 +332,53 @@ TEST_F(HloEvaluatorTest, DoesBroadcastScalar) {
LiteralTestUtil::ExpectEqual(*result, *output_literal);
}
+TEST_F(HloEvaluatorTest, DoesConcatenateSimple) {
+ HloComputation::Builder b(TestName());
+
+ HloInstruction* operand1 = b.AddInstruction(HloInstruction::CreateConstant(
+ Literal::CreateR2<int64>({{-1, -2}, {100, 200}})));
+ HloInstruction* operand2 = b.AddInstruction(HloInstruction::CreateConstant(
+ Literal::CreateR2<int64>({{-2, -3}, {-100, -200}})));
+
+ std::vector<HloInstruction*> operands = {operand1, operand2};
+
+ Shape shape = ShapeUtil::MakeShape(S64, {2, 2});
+ b.AddInstruction(HloInstruction::CreateConcatenate(shape, operands, 0));
+
+ HloModule module(TestName());
+ auto computation = module.AddEntryComputation(b.Build());
+
+ std::unique_ptr<Literal> result =
+ evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie();
+
+ auto expected =
+ Literal::CreateR2<int64>({{-1, -2}, {100, 200}, {-2, -3}, {-100, -200}});
+ LiteralTestUtil::ExpectEqual(*expected, *result);
+}
+
+TEST_F(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) {
+ HloComputation::Builder b(TestName());
+
+ HloInstruction* operand1 = b.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR1<int64>({100, 200})));
+ HloInstruction* operand2 = b.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR1<int64>({})));
+
+ std::vector<HloInstruction*> operands = {operand1, operand2};
+
+ Shape shape = ShapeUtil::MakeShape(S64, {2});
+ b.AddInstruction(HloInstruction::CreateConcatenate(shape, operands, 0));
+
+ HloModule module(TestName());
+ auto computation = module.AddEntryComputation(b.Build());
+
+ std::unique_ptr<Literal> result =
+ evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie();
+
+ auto expected = Literal::CreateR1<int64>({100, 200});
+ LiteralTestUtil::ExpectEqual(*expected, *result);
+}
+
TEST_F(HloEvaluatorTest, ConvertWithSameLayout) {
HloComputation::Builder b(TestName());