aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-07-03 14:07:02 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-03 14:12:38 -0700
commitbdd84aa59d3bdedc42647711e401229f489c7d25 (patch)
tree695399ae3fed6bc65177f38493dee35f5f74e116 /tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
parenta6471888cc9dfe9c18d121149bc0516a3f423fbb (diff)
[TF:XLA] Split select HLO into array- and tuple-select.
Array select and tuple-select already are handled separately in all backends and HLO passes: Array select is an elementwise operation. The shapes of the to operands have the same dimensions. Tuple select does not define its own output, but instead forwards the true- or false- operand based on a scalar predicate operand. This CL reflects this by adding a new kTupleSelect HLO. The XLA builder interface stays the same and dispatches based on the operand shapes. No change in the operation semantics. This CL just splits the existing select operation into two opcodes and preserves the existing semantics. HLO cost analysis is fixed to handle the two ops appropriately. PiperOrigin-RevId: 203180342
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc21
1 files changed, 10 insertions, 11 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
index 343f5e7b39..f176473366 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
@@ -860,8 +860,7 @@ TEST_P(HloDataflowAnalysisTest, ArraySelect) {
}
TEST_P(HloDataflowAnalysisTest, TupleSelect) {
- // Test a kSelect of a tuple value. Non-top-level element flow through the
- // instruction.
+ // Test a kTupleSelect. Non-top-level element flow through the instruction.
auto builder = HloComputation::Builder(TestName());
auto pred = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
@@ -883,20 +882,20 @@ TEST_P(HloDataflowAnalysisTest, TupleSelect) {
builder.AddInstruction(HloInstruction::CreateTuple({constant4}));
const Shape tuple_shape = tuple1->shape();
auto select11 = builder.AddInstruction(HloInstruction::CreateTernary(
- tuple_shape, HloOpcode::kSelect, pred, tuple1, tuple1));
+ tuple_shape, HloOpcode::kTupleSelect, pred, tuple1, tuple1));
auto select12 = builder.AddInstruction(HloInstruction::CreateTernary(
- tuple_shape, HloOpcode::kSelect, pred, tuple1, tuple2));
+ tuple_shape, HloOpcode::kTupleSelect, pred, tuple1, tuple2));
auto select34 = builder.AddInstruction(HloInstruction::CreateTernary(
- tuple_shape, HloOpcode::kSelect, pred, tuple3, tuple4));
+ tuple_shape, HloOpcode::kTupleSelect, pred, tuple3, tuple4));
auto select1234 = builder.AddInstruction(HloInstruction::CreateTernary(
- tuple_shape, HloOpcode::kSelect, pred, select12, select34));
+ tuple_shape, HloOpcode::kTupleSelect, pred, select12, select34));
module_->AddEntryComputation(builder.Build());
bool ssa_form = GetParam();
const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
- // Top-level value is always defined by a kSelect.
+ // Top-level value is always defined by a kTupleSelect.
EXPECT_TRUE(analysis.ValueIsDefinedAt(select11));
EXPECT_TRUE(analysis.ValueIsDefinedAt(select12));
EXPECT_TRUE(analysis.ValueIsDefinedAt(select34));
@@ -937,7 +936,7 @@ TEST_P(HloDataflowAnalysisTest, TupleSelect) {
}
TEST_P(HloDataflowAnalysisTest, NestedTupleSelect) {
- // Test kSelect of a nested tuple.
+ // Test kTupleSelect of a nested tuple.
auto builder = HloComputation::Builder(TestName());
auto pred = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
@@ -960,7 +959,7 @@ TEST_P(HloDataflowAnalysisTest, NestedTupleSelect) {
auto tuple2 = builder.AddInstruction(
HloInstruction::CreateTuple({constant4, inner_tuple2}));
auto select = builder.AddInstruction(HloInstruction::CreateTernary(
- tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2));
+ tuple1->shape(), HloOpcode::kTupleSelect, pred, tuple1, tuple2));
module_->AddEntryComputation(builder.Build());
@@ -983,7 +982,7 @@ TEST_P(HloDataflowAnalysisTest, NestedTupleSelect) {
}
TEST_P(HloDataflowAnalysisTest, TupleSelectToWhile) {
- // Test a tuple-shaped kSelect feeding a kWhile instruction. HLO:
+ // Test a tuple-shaped kTupleSelect feeding a kWhile instruction. HLO:
//
// body((F32[], F32[]) %tuple_param):
// %add = Add(%tuple_param{0}, %tuple_param{1})
@@ -1043,7 +1042,7 @@ TEST_P(HloDataflowAnalysisTest, TupleSelectToWhile) {
auto tuple2 =
builder.AddInstruction(HloInstruction::CreateTuple({constant2}));
auto select = builder.AddInstruction(HloInstruction::CreateTernary(
- tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2));
+ tuple1->shape(), HloOpcode::kTupleSelect, pred, tuple1, tuple2));
auto gte = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(scalar_shape_, select, 0));
auto tuple =