diff options
author | 2018-08-29 17:48:09 -0700 | |
---|---|---|
committer | 2018-08-29 17:51:38 -0700 | |
commit | 7cda8c3a8ad528f2e11fc47b0abf08e01f97af45 (patch) | |
tree | 4abceb5e1c3ca6692f41d53d71f8e4de1c4108fb /tensorflow | |
parent | e528493c8cde468451ba1b1995e649ebe9c29b02 (diff) |
[XLA] Switch to using kIota from TF
We were using a broadcast of a constant instead of the kIota HLO.
To make switching to kIota practical, we need to do a few things first:
- Don't constant fold kIota.
- Don't hoist kIota from loops without good cause.
PiperOrigin-RevId: 210825834
Diffstat (limited to 'tensorflow')
-rw-r--r-- | tensorflow/compiler/xla/client/lib/numeric.cc | 47 | ||||
-rw-r--r-- | tensorflow/compiler/xla/client/lib/numeric_test.cc | 10 | ||||
-rw-r--r-- | tensorflow/compiler/xla/client/xla_builder.cc | 14 | ||||
-rw-r--r-- | tensorflow/compiler/xla/client/xla_builder.h | 16 | ||||
-rw-r--r-- | tensorflow/compiler/xla/literal.cc | 55 | ||||
-rw-r--r-- | tensorflow/compiler/xla/literal.h | 3 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/algebraic_simplifier.cc | 8 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/algebraic_simplifier_test.cc | 17 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/elemental_ir_emitter.cc | 31 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_constant_folding.cc | 3 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_element_type_converter.cc | 6 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc | 1 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tests/iota_test.cc | 6 |
13 files changed, 128 insertions(+), 89 deletions(-)
diff --git a/tensorflow/compiler/xla/client/lib/numeric.cc b/tensorflow/compiler/xla/client/lib/numeric.cc index 02bed80162..38e440c68d 100644 --- a/tensorflow/compiler/xla/client/lib/numeric.cc +++ b/tensorflow/compiler/xla/client/lib/numeric.cc @@ -23,53 +23,6 @@ limitations under the License. namespace xla { -namespace { - -template <typename T> -XlaOp MakeIota(XlaBuilder* builder, int64 size) { - std::vector<T> values(size); - for (int64 i = 0; i < size; ++i) { - values[i] = static_cast<T>(i); - } - return ConstantR1<T>(builder, values); -} - -} // namespace - -XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64 size) { - switch (type) { - case S8: - return MakeIota<int8>(builder, size); - case S16: - return MakeIota<int16>(builder, size); - case S32: - return MakeIota<int32>(builder, size); - case S64: - return MakeIota<int64>(builder, size); - case U8: - return MakeIota<uint8>(builder, size); - case U16: - return MakeIota<uint16>(builder, size); - case U32: - return MakeIota<uint32>(builder, size); - case U64: - return MakeIota<uint64>(builder, size); - case BF16: - return MakeIota<bfloat16>(builder, size); - case F16: - return MakeIota<half>(builder, size); - case F32: - return MakeIota<float>(builder, size); - case F64: - return MakeIota<double>(builder, size); - case C64: - return MakeIota<complex64>(builder, size); - default: - return builder->ReportError(InvalidArgument( - "Unimplemented type for Iota: %s.", PrimitiveType_Name(type))); - } -} - XlaOp IdentityMatrix(XlaBuilder* builder, PrimitiveType type, int64 m, int64 n) { auto a = Iota(builder, type, m); diff --git a/tensorflow/compiler/xla/client/lib/numeric_test.cc b/tensorflow/compiler/xla/client/lib/numeric_test.cc index 8a96ec68d2..7d6aedd494 100644 --- a/tensorflow/compiler/xla/client/lib/numeric_test.cc +++ b/tensorflow/compiler/xla/client/lib/numeric_test.cc @@ -30,16 +30,6 @@ class NumericTest : public ClientLibraryTestBase { void TestMatrixDiagonal(); }; -// TODO(b/64798317): Delete 
this test case once xla::IotaGen is converted to -// xla::Iota. This test is already implemented for xla::IotaGen in -// xla/tests/iota_test.cc. -XLA_TEST_F(NumericTest, Iota) { - XlaBuilder builder(TestName()); - Iota(&builder, S32, 10); - - ComputeAndCompareR1<int32>(&builder, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, {}); -} - XLA_TEST_F(NumericTest, Triangle) { XlaBuilder builder(TestName()); Array3D<int32> input(2, 3, 4); diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index ea53287068..531b8dd66b 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -466,7 +466,7 @@ XlaOp XlaBuilder::ConstantLiteral(const LiteralSlice& literal) { }); } -XlaOp XlaBuilder::IotaGen(const Shape& shape, int64 iota_dimension) { +XlaOp XlaBuilder::Iota(const Shape& shape, int64 iota_dimension) { return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { HloInstructionProto instr; *instr.mutable_shape() = shape; @@ -475,8 +475,8 @@ XlaOp XlaBuilder::IotaGen(const Shape& shape, int64 iota_dimension) { }); } -XlaOp XlaBuilder::IotaGen(PrimitiveType type, int64 size) { - return IotaGen(ShapeUtil::MakeShape(type, {size}), /*iota_dimension=*/0); +XlaOp XlaBuilder::Iota(PrimitiveType type, int64 size) { + return Iota(ShapeUtil::MakeShape(type, {size}), /*iota_dimension=*/0); } XlaOp XlaBuilder::Call(const XlaComputation& computation, @@ -3063,12 +3063,12 @@ XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale, grad_output, epsilon, feature_index); } -XlaOp IotaGen(XlaBuilder* builder, PrimitiveType type, int64 size) { - return builder->IotaGen(type, size); +XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64 size) { + return builder->Iota(type, size); } -XlaOp IotaGen(XlaBuilder* builder, const Shape& shape, int64 iota_dimension) { - return builder->IotaGen(shape, iota_dimension); +XlaOp Iota(XlaBuilder* builder, const Shape& shape, int64 iota_dimension) { + return 
builder->Iota(shape, iota_dimension); } } // namespace xla diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index 9b82cc03b3..b9e651f2ae 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -808,10 +808,10 @@ class XlaBuilder { XlaOp IsFinite(const XlaOp& operand); // Enqueues an iota operation onto the computation. - XlaOp IotaGen(const Shape& shape, int64 iota_dimension); + XlaOp Iota(const Shape& shape, int64 iota_dimension); // Enqueues a rank-1 iota operation onto the computation. - XlaOp IotaGen(PrimitiveType type, int64 size); + XlaOp Iota(PrimitiveType type, int64 size); // Enqueues a convert instruction onto the computation that changes the // element type of the operand array to primitive_type. @@ -1320,11 +1320,9 @@ class XlaBuilder { friend XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice<int64> broadcast_dimensions); friend XlaOp IsFinite(const XlaOp& operand); - // TODO(b/64798317): Finish CPU & GPU implementation, then replace xla::Iota - // in xla/client/lib/numeric.h with this (renamed to xla::Iota). - friend XlaOp IotaGen(XlaBuilder* builder, const Shape& shape, - int64 iota_dimension); - friend XlaOp IotaGen(XlaBuilder* builder, PrimitiveType type, int64 size); + friend XlaOp Iota(XlaBuilder* builder, const Shape& shape, + int64 iota_dimension); + friend XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64 size); friend XlaOp ConvertElementType(const XlaOp& operand, PrimitiveType new_element_type); friend XlaOp BitcastConvertType(const XlaOp& operand, @@ -1988,10 +1986,10 @@ XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs, XlaOp IsFinite(const XlaOp& operand); // Enqueues an iota operation onto the computation. 
-XlaOp IotaGen(XlaBuilder* builder, const Shape& shape, int64 iota_dimension); +XlaOp Iota(XlaBuilder* builder, const Shape& shape, int64 iota_dimension); // Enqueues a rank-1 iota operation onto the computation. -XlaOp IotaGen(XlaBuilder* builder, PrimitiveType type, int64 size); +XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64 size); // Enqueues a convert instruction onto the computation that changes the // element type of the operand array to primitive_type. diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc index 93e808469a..3dd0abee79 100644 --- a/tensorflow/compiler/xla/literal.cc +++ b/tensorflow/compiler/xla/literal.cc @@ -1687,6 +1687,61 @@ bool LiteralBase::IsAllFirst() const { }); } +bool LiteralBase::IsR1Iota() const { + if (!ShapeUtil::IsArray(shape())) { + return false; + } + + if (ShapeUtil::Rank(shape()) != 1) { + return false; + } + + auto is_iota_at_idx = [&](const int64 idx) { + switch (shape().element_type()) { + case U8: + return Get<uint8>({idx}) == idx; + case U16: + return Get<uint16>({idx}) == idx; + case U32: + return Get<uint32>({idx}) == idx; + case U64: + return Get<uint64>({idx}) == idx; + case S8: + return Get<int8>({idx}) == idx; + case S16: + return Get<int16>({idx}) == idx; + case S32: + return Get<int32>({idx}) == idx; + case S64: + return Get<int64>({idx}) == idx; + case F32: + return Get<float>({idx}) == idx; + case F64: + return Get<double>({idx}) == idx; + case F16: + return Get<half>({idx}) == static_cast<half>(idx); + case BF16: + return Get<bfloat16>({idx}) == static_cast<bfloat16>(idx); + case C64: + return Get<complex64>({idx}) == complex64(idx, 0.0f); + case PRED: + return Get<bool>({idx}) == idx; + // token, opaque, tuple, etc. are all not iota. 
+ default: + return false; + } + }; + + const int64 elements = ShapeUtil::ElementsIn(shape()); + for (int64 idx = 0; idx < elements; ++idx) { + if (!is_iota_at_idx(idx)) { + return false; + } + } + + return true; +} + bool LiteralBase::IsZero(tensorflow::gtl::ArraySlice<int64> indices) const { CHECK(ShapeUtil::IsArray(shape())); switch (shape().element_type()) { diff --git a/tensorflow/compiler/xla/literal.h b/tensorflow/compiler/xla/literal.h index aad435ed5b..8370043da1 100644 --- a/tensorflow/compiler/xla/literal.h +++ b/tensorflow/compiler/xla/literal.h @@ -195,6 +195,9 @@ class LiteralBase { // Literal consists entirely of the first element of the literal. bool IsAllFirst() const; + // Literal consists entirely of an iota. + bool IsR1Iota() const; + // Returns whether this literal is zero at the specified index. This literal // must be an array with a dense layout. bool IsZero(tensorflow::gtl::ArraySlice<int64> indices) const; diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 19bb4da9a6..196865f333 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -553,6 +553,14 @@ Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant) { constant, HloInstruction::CreateBroadcast(constant->shape(), scalar, {})); } + + // If a literal is an increasing sequence from zero, replace it with an iota. 
+ if (ShapeUtil::Rank(constant->shape()) == 1 && + ShapeUtil::ElementsIn(constant->shape()) > 1 && + constant->literal().IsR1Iota()) { + return ReplaceWithNewInstruction( + constant, HloInstruction::CreateIota(constant->shape(), 0)); + } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 1900a05750..917ed86b69 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -298,6 +298,21 @@ TEST_F(AlgebraicSimplifierTest, ConstantNotToBroadcast) { EXPECT_THAT(root, op::Constant()); } +TEST_F(AlgebraicSimplifierTest, IotaToBroadcast) { + HloComputation::Builder builder(TestName()); + builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f}))); + + auto computation = module().AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_THAT(root, op::Constant()); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_THAT(root, op::Iota()); +} + // Test that A - 0 is simplified to A TEST_F(AlgebraicSimplifierTest, SubZero) { Shape r0f32 = ShapeUtil::MakeShape(F32, {}); @@ -521,7 +536,7 @@ TEST_F(AlgebraicSimplifierTest, DivideByConstant) { HloInstruction::CreateParameter(0, r1f32, "param0")); HloInstruction* constant = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1<float>({0.f, 1.f, 2.f}))); + LiteralUtil::CreateR1<float>({1.f, 2.f, 3.f}))); builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kDivide, param0, constant)); diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index 813e93fafa..def42f9c77 100644 --- 
a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -2117,29 +2117,40 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( iota->shape().dimensions(iota->iota_dimension())}; elem_index_linear = elem_index.Linearize(iota_bound, b_); } - if (ShapeUtil::ElementIsIntegral(iota->shape())) { - return b_->CreateIntCast( + Shape component_shape = + ShapeUtil::ElementIsComplex(iota->shape()) + ? ShapeUtil::ComplexComponentShape(iota->shape()) + : iota->shape(); + PrimitiveType component_element_type = component_shape.element_type(); + llvm::Value* iota_result; + if (ShapeUtil::ElementIsIntegral(component_shape)) { + iota_result = b_->CreateIntCast( elem_index_linear, - llvm_ir::PrimitiveTypeToIrType(element_type, module_), + llvm_ir::PrimitiveTypeToIrType(component_element_type, module_), /*isSigned=*/false); } else { - TF_RET_CHECK(ShapeUtil::ElementIsFloating(iota->shape())) - << element_type; + TF_RET_CHECK(ShapeUtil::ElementIsFloating(component_shape)) + << component_element_type; llvm::Type* float_ir_type; - if (element_type == BF16) { + if (component_element_type == BF16) { float_ir_type = llvm_ir::PrimitiveTypeToIrType(F32, module_); } else { float_ir_type = - llvm_ir::PrimitiveTypeToIrType(element_type, module_); + llvm_ir::PrimitiveTypeToIrType(component_element_type, module_); } llvm::Value* float_val = b_->CreateUIToFP(elem_index_linear, float_ir_type); - if (element_type == BF16) { - return EmitF32ToBF16(float_val, b_); + if (component_element_type == BF16) { + iota_result = EmitF32ToBF16(float_val, b_); } else { - return float_val; + iota_result = float_val; } } + if (ShapeUtil::ElementIsComplex(iota->shape())) { + return EmitComposeComplex(iota, iota_result, nullptr); + } else { + return iota_result; + } }; case HloOpcode::kSlice: return [this, hlo, &operand_to_generator]( diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.cc 
b/tensorflow/compiler/xla/service/hlo_constant_folding.cc index 2ed645c3ae..8a45939c61 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding.cc @@ -71,7 +71,8 @@ StatusOr<bool> HloConstantFolding::Run(HloModule* module) { // Broadcasts dramatically increase the size of constants, which is often // detrimental to performance and memory capacity, so do not fold // broadcasts. - if (instruction->opcode() == HloOpcode::kBroadcast) { + if (instruction->opcode() == HloOpcode::kBroadcast || + instruction->opcode() == HloOpcode::kIota) { continue; } diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc index b9244b8e9e..72006e17e7 100644 --- a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc +++ b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc @@ -151,7 +151,11 @@ StatusOr<bool> HloElementTypeConverter::Run(HloModule* module) { } TF_RET_CHECK(hlo->called_computations().empty()) << hlo->ToString(); - if (!HasOperandType(hlo, eliminate_type_)) { + bool nullary = hlo->operands().empty(); + bool wrong_element_type = hlo->shape().element_type() == eliminate_type_; + bool should_eliminate_type = (nullary && wrong_element_type) || + HasOperandType(hlo, eliminate_type_); + if (!should_eliminate_type) { // If this CHECK fires, then this was an instruction that does not take // the elimination type as an operand but it does return it. 
This pass // does not have a feature to change the output type in that case, so diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc index f4098f28b3..e8fe33e626 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc @@ -110,6 +110,7 @@ bool WhileLoopInvariantCodeMotion::NotWorthHoistingIndividually( case HloOpcode::kBitcast: case HloOpcode::kBroadcast: + case HloOpcode::kIota: case HloOpcode::kReshape: case HloOpcode::kReverse: case HloOpcode::kSlice: diff --git a/tensorflow/compiler/xla/tests/iota_test.cc b/tensorflow/compiler/xla/tests/iota_test.cc index 07c3c6b878..310f349592 100644 --- a/tensorflow/compiler/xla/tests/iota_test.cc +++ b/tensorflow/compiler/xla/tests/iota_test.cc @@ -39,7 +39,7 @@ TEST_P(IotaR1Test, DoIt) { const auto element_type = std::get<0>(spec); const int64 num_elements = std::get<1>(spec); XlaBuilder builder(TestName() + "_" + PrimitiveType_Name(element_type)); - IotaGen(&builder, element_type, num_elements); + Iota(&builder, element_type, num_elements); if (element_type == F32) { ComputeAndCompareR1<float>(&builder, GetR1Expected<float>(num_elements), {}, ErrorSpec{0.0001}); @@ -71,7 +71,7 @@ TEST_P(IotaR2Test, DoIt) { XlaBuilder builder(TestName() + "_" + PrimitiveType_Name(element_type)); std::vector<int64> dimensions = {42}; dimensions.insert(dimensions.begin() + iota_dim, num_elements); - IotaGen(&builder, ShapeUtil::MakeShape(element_type, dimensions), iota_dim); + Iota(&builder, ShapeUtil::MakeShape(element_type, dimensions), iota_dim); if (primitive_util::IsFloatingPointType(element_type)) { ComputeAndCompare(&builder, {}, ErrorSpec{0.0001}); } else { @@ -98,7 +98,7 @@ TEST_P(IotaR3Test, DoIt) { XlaBuilder builder(TestName() + "_" + PrimitiveType_Name(element_type)); std::vector<int64> dimensions = {42, 19}; 
dimensions.insert(dimensions.begin() + iota_dim, num_elements); - IotaGen(&builder, ShapeUtil::MakeShape(element_type, dimensions), iota_dim); + Iota(&builder, ShapeUtil::MakeShape(element_type, dimensions), iota_dim); if (primitive_util::IsFloatingPointType(element_type)) { ComputeAndCompare(&builder, {}, ErrorSpec{0.0001}); } else { |