diff options
18 files changed, 276 insertions, 77 deletions
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index 5e92df2d63..819d324927 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -466,6 +466,19 @@ XlaOp XlaBuilder::ConstantLiteral(const LiteralSlice& literal) { }); } +XlaOp XlaBuilder::IotaGen(const Shape& shape, int64 iota_dimension) { + return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { + HloInstructionProto instr; + *instr.mutable_shape() = shape; + instr.add_dimensions(iota_dimension); + return AddInstruction(std::move(instr), HloOpcode::kIota); + }); +} + +XlaOp XlaBuilder::IotaGen(PrimitiveType type, int64 size) { + return IotaGen(ShapeUtil::MakeShape(type, {size}), /*iota_dimension=*/0); +} + XlaOp XlaBuilder::Call(const XlaComputation& computation, tensorflow::gtl::ArraySlice<XlaOp> operands) { return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { @@ -3023,10 +3036,11 @@ XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale, } XlaOp IotaGen(XlaBuilder* builder, PrimitiveType type, int64 size) { - HloInstructionProto instr; - *instr.mutable_shape() = ShapeUtil::MakeShape(type, {size}); - return builder->ReportErrorOrReturn( - builder->AddInstruction(std::move(instr), HloOpcode::kIota)); + return builder->IotaGen(type, size); +} + +XlaOp IotaGen(XlaBuilder* builder, const Shape& shape, int64 iota_dimension) { + return builder->IotaGen(shape, iota_dimension); } } // namespace xla diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index e9d5d3943c..193d8ed071 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -800,6 +800,12 @@ class XlaBuilder { // entry was NaN. XlaOp IsFinite(const XlaOp& operand); + // Enqueues an iota operation onto the computation. + XlaOp IotaGen(const Shape& shape, int64 iota_dimension); + + // Enqueues a rank-1 iota operation onto the computation. + XlaOp IotaGen(PrimitiveType type, int64 size); + // Enqueues a convert instruction onto the computation that changes the // element type of the operand array to primitive_type. XlaOp ConvertElementType(const XlaOp& operand, @@ -1304,6 +1310,8 @@ class XlaBuilder { friend XlaOp IsFinite(const XlaOp& operand); // TODO(b/64798317): Finish CPU & GPU implementation, then replace xla::Iota // in xla/client/lib/numeric.h with this (renamed to xla::Iota). + friend XlaOp IotaGen(XlaBuilder* builder, const Shape& shape, + int64 iota_dimension); friend XlaOp IotaGen(XlaBuilder* builder, PrimitiveType type, int64 size); friend XlaOp ConvertElementType(const XlaOp& operand, PrimitiveType new_element_type); @@ -1960,6 +1968,12 @@ XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs, // entry was NaN. XlaOp IsFinite(const XlaOp& operand); +// Enqueues an iota operation onto the computation. +XlaOp IotaGen(XlaBuilder* builder, const Shape& shape, int64 iota_dimension); + +// Enqueues a rank-1 iota operation onto the computation. +XlaOp IotaGen(XlaBuilder* builder, PrimitiveType type, int64 size); + // Enqueues a convert instruction onto the computation that changes the // element type of the operand array to primitive_type. XlaOp ConvertElementType(const XlaOp& operand, PrimitiveType new_element_type); diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 716c75da39..b68785949c 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -230,6 +230,7 @@ cc_library( hdrs = ["hlo_evaluator.h"], deps = [ ":hlo", + ":hlo_casting_utils", ":hlo_query", ":shape_inference", "//tensorflow/compiler/xla:literal", @@ -2290,6 +2291,7 @@ cc_library( ":hlo_pass", ":shape_inference", "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", @@ -2682,6 +2684,7 @@ cc_library( hdrs = ["elemental_ir_emitter.h"], deps = [ ":hlo", + ":hlo_casting_utils", ":hlo_module_config", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 903e73f606..460363e18f 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -2437,11 +2437,6 @@ Status IrEmitter::HandleAfterAll(HloInstruction* gen_token) { return Status::OK(); } -Status IrEmitter::HandleIota(HloInstruction* iota) { - // TODO(b/64798317): implement iota on CPU. - return Unimplemented("Iota is not implemented on CPU."); -} - Status IrEmitter::HandleRng(HloInstruction* rng) { ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator; for (const HloInstruction* operand : rng->operands()) { diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index ec68710d3f..f98891246b 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -157,7 +157,6 @@ class IrEmitter : public DfsHloVisitorWithDefault, Status HandleConditional(HloInstruction* conditional) override; Status HandleScatter(HloInstruction* scatter) override; Status HandleAfterAll(HloInstruction* gen_token) override; - Status HandleIota(HloInstruction* iota) override; Status HandleRng(HloInstruction* rng) override; Status FinishVisit(HloInstruction* root) override; diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index 61f6055fc9..813e93fafa 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -28,6 +28,8 @@ limitations under the License. #include "llvm/IR/Intrinsics.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" @@ -2095,6 +2097,50 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( target_index.SourceIndexOfBroadcast(hlo->shape(), operand->shape(), hlo->dimensions(), b_)); }; + case HloOpcode::kIota: + return [this, hlo]( + const IrArray::Index& target_index) -> StatusOr<llvm::Value*> { + auto* iota = Cast<HloIotaInstruction>(hlo); + PrimitiveType element_type = iota->shape().element_type(); + IrArray::Index elem_index = + ShapeUtil::Rank(iota->shape()) > 1 + ? target_index.SourceIndexOfBroadcast( + iota->shape(), + ShapeUtil::MakeShapeWithDescendingLayout( + element_type, + {iota->shape().dimensions(iota->iota_dimension())}), + {iota->iota_dimension()}, b_) + : target_index; + llvm::Value* elem_index_linear = elem_index.linear(); + if (elem_index_linear == nullptr) { + std::vector<int64> iota_bound = { + iota->shape().dimensions(iota->iota_dimension())}; + elem_index_linear = elem_index.Linearize(iota_bound, b_); + } + if (ShapeUtil::ElementIsIntegral(iota->shape())) { + return b_->CreateIntCast( + elem_index_linear, + llvm_ir::PrimitiveTypeToIrType(element_type, module_), + /*isSigned=*/false); + } else { + TF_RET_CHECK(ShapeUtil::ElementIsFloating(iota->shape())) + << element_type; + llvm::Type* float_ir_type; + if (element_type == BF16) { + float_ir_type = llvm_ir::PrimitiveTypeToIrType(F32, module_); + } else { + float_ir_type = + llvm_ir::PrimitiveTypeToIrType(element_type, module_); + } + llvm::Value* float_val = + b_->CreateUIToFP(elem_index_linear, float_ir_type); + if (element_type == BF16) { + return EmitF32ToBF16(float_val, b_); + } else { + return float_val; + } + } + }; case HloOpcode::kSlice: return [this, hlo, &operand_to_generator]( const IrArray::Index& index) -> StatusOr<llvm::Value*> { diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index a620cebe04..bdf6aadde6 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -746,11 +746,6 @@ Status IrEmitter::HandleBatchNormGrad(HloInstruction*) { "to a cudnn CustomCall using CudnnBatchNormRewriter."); } -Status IrEmitter::HandleIota(HloInstruction*) { - // TODO(b/64798317): implement iota on GPU. - return Unimplemented("Iota is not implemented on GPU."); -} - StatusOr<llvm::Value*> IrEmitter::ComputeNestedElement( const HloComputation& computation, tensorflow::gtl::ArraySlice<llvm::Value*> parameter_elements) { diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h index e096a07704..3673b9f58d 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h @@ -97,7 +97,6 @@ class IrEmitter : public DfsHloVisitorWithDefault, Status HandleBatchNormInference(HloInstruction* batch_norm) override; Status HandleBatchNormTraining(HloInstruction* batch_norm) override; Status HandleBatchNormGrad(HloInstruction* batch_norm) override; - Status HandleIota(HloInstruction* iota) override; Status FinishVisit(HloInstruction* root) override { return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index b6566ebefe..f682e69ee9 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -21,7 +21,9 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/types/optional.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_evaluator.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/core/lib/core/casts.h" @@ -2493,11 +2495,21 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { std::is_same<NativeT, float>::value || std::is_same<NativeT, int32>::value || std::is_same<NativeT, uint32>::value>::type* = nullptr> - Status HandleIota(HloInstruction* iota) { - auto result = absl::make_unique<Literal>(iota->shape()); - auto data = result->data<ReturnT>(); + Status HandleIota(HloInstruction* instruction) { + auto* iota = Cast<HloIotaInstruction>(instruction); + std::vector<NativeT> data(iota->shape().dimensions(iota->iota_dimension())); std::iota(data.begin(), data.end(), 0); - parent_->evaluated_[iota] = std::move(result); + auto result = LiteralUtil::CreateR1<NativeT>(data); + + if (ShapeUtil::Rank(iota->shape()) > 1) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[iota], + result->Broadcast(iota->shape(), {iota->iota_dimension()})); + } else { + TF_RET_CHECK(ShapeUtil::Rank(iota->shape()) == 1); + parent_->evaluated_[iota] = std::move(result); + } + return Status::OK(); } template <typename NativeT, diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index c77699a06f..ed4e159910 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -428,6 +428,12 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( computations(0), *scatter_dimension_numbers); break; } + case HloOpcode::kIota: + TF_RET_CHECK(proto.dimensions_size() <= 1) + << "Iota instruction should have at most 1 dimension but sees " + << proto.dimensions_size(); + instruction = CreateIota(proto.shape(), proto.dimensions(0)); + break; default: { instruction = absl::WrapUnique(new HloInstruction(opcode, proto.shape())); for (const int64 operand_id : proto.operand_ids()) { @@ -490,8 +496,8 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( } /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateIota( - const Shape& shape) { - return absl::WrapUnique(new HloInstruction(HloOpcode::kIota, shape)); + const Shape& shape, int64 iota_dimension) { + return absl::make_unique<HloIotaInstruction>(shape, iota_dimension); } /* static */ std::unique_ptr<HloInstruction> @@ -2053,13 +2059,12 @@ std::vector<string> HloInstruction::ExtraAttributesToString( extra.push_back( StrCat("to_apply=", PrintName(to_apply()->name(), options))); } else if (!called_computations().empty()) { - extra.push_back( - StrCat("calls=", - StrJoin(called_computations(), ", ", - [&](string* out, const HloComputation* computation) { - StrAppend(out, - PrintName(computation->name(), options)); - }))); + extra.push_back(StrCat( + "calls=", + StrJoin(called_computations(), ", ", + [&](string* out, const HloComputation* computation) { + StrAppend(out, PrintName(computation->name(), options)); + }))); } } else if (options.print_subcomputation_mode() == HloPrintOptions::PrintSubcomputationMode::kFullBodies) { @@ -3218,12 +3223,12 @@ HloInstruction::source_target_pairs() const { } string HloInstruction::cross_replica_sum_barrier() const { - return Cast<HloAllReduceInstruction>(this)->cross_replica_sum_barrier(); + return Cast<HloAllReduceInstruction>(this)->cross_replica_sum_barrier(); } void HloInstruction::set_cross_replica_sum_barrier(const string& barrier) { - return Cast<HloAllReduceInstruction>(this)->set_cross_replica_sum_barrier( - barrier); + return Cast<HloAllReduceInstruction>(this)->set_cross_replica_sum_barrier( + barrier); } absl::optional<int64> HloInstruction::all_reduce_id() const { diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index b393635e9d..4a424cebc0 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -350,7 +350,8 @@ class HloInstruction { std::unique_ptr<Literal> literal); // Creates an Iota instruction. - static std::unique_ptr<HloInstruction> CreateIota(const Shape& shape); + static std::unique_ptr<HloInstruction> CreateIota(const Shape& shape, + int64 iota_dimension); // Creates a get tuple element instruction. static std::unique_ptr<HloInstruction> CreateGetTupleElement( diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index b93c758937..ffc74cfedd 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -2155,4 +2155,34 @@ std::unique_ptr<HloInstruction> HloScatterInstruction::CloneWithNewOperandsImpl( scatter_dimension_numbers()); } +HloIotaInstruction::HloIotaInstruction(const Shape& shape, int64 iota_dimension) + : HloInstruction(HloOpcode::kIota, shape), + iota_dimension_(iota_dimension) {} + +HloInstructionProto HloIotaInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + proto.add_dimensions(iota_dimension()); + return proto; +} + +std::vector<string> HloIotaInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("iota_dimension=", iota_dimension())}; +} + +bool HloIotaInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function<bool(const HloComputation*, const HloComputation*)>& + eq_computations) const { + const auto& casted_other = static_cast<const HloIotaInstruction&>(other); + return iota_dimension() == casted_other.iota_dimension(); +} + +std::unique_ptr<HloInstruction> HloIotaInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, + HloCloneContext* context) const { + return absl::make_unique<HloIotaInstruction>(shape, iota_dimension()); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index 29b187300d..ee6e337b6a 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -1279,6 +1279,30 @@ class HloScatterInstruction : public HloInstruction { std::unique_ptr<ScatterDimensionNumbers> scatter_dimension_numbers_; }; +class HloIotaInstruction : public HloInstruction { + public: + explicit HloIotaInstruction(const Shape& shape, int64 iota_dimension); + // Returns the dimension sizes or numbers associated with this instruction. + int64 iota_dimension() const { return iota_dimension_; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector<string> ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function<bool(const HloComputation*, const HloComputation*)>& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, + HloCloneContext* context) const override; + + const int64 iota_dimension_; +}; + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_ diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index c7a766f4e0..eae4508b24 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -562,11 +562,15 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kIota: { + optional<tensorflow::int64> iota_dimension; + attrs["iota_dimension"] = {/*required=*/true, AttrTy::kInt64, + &iota_dimension}; if (!ParseOperands(&operands, /*expected_size=*/0) || !ParseAttributes(attrs)) { return false; } - instruction = builder->AddInstruction(HloInstruction::CreateIota(shape)); + instruction = builder->AddInstruction( + HloInstruction::CreateIota(shape, *iota_dimension)); break; } // Unary ops. diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index b1ef288b8e..ba07ec432e 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -1116,7 +1116,7 @@ ENTRY CollectivePermute { R"(HloModule iota ENTRY Iota { - ROOT iota = f32[100]{0} iota() + ROOT iota = f32[100]{0} iota(), iota_dimension=0 } )" diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 0ed2c3b449..f1b29c2559 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/flatmap.h" @@ -265,10 +266,18 @@ Status ShapeVerifier::HandleConstant(HloInstruction* constant) { return CheckShape(constant, constant->literal().shape()); } -Status ShapeVerifier::HandleIota(HloInstruction* iota) { - return ShapeUtil::Rank(iota->shape()) == 1 - ? Status::OK() - : InternalError("Iota only supports arrays of rank 1."); +Status ShapeVerifier::HandleIota(HloInstruction* instruction) { + auto* iota = Cast<HloIotaInstruction>(instruction); + const int64 rank = ShapeUtil::Rank(iota->shape()); + if (rank == 0) { + return InternalError("Iota does not support scalars."); + } + int64 iota_dimension = iota->iota_dimension(); + if (iota_dimension >= rank) { + return InternalError( + "The iota dimension cannot go beyond the operation rank."); + } + return Status::OK(); } Status ShapeVerifier::HandleGetTupleElement(HloInstruction* get_tuple_element) { diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index d5e3b747e7..a0829b0d02 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -2131,19 +2131,13 @@ xla_test( xla_test( name = "iota_test", srcs = ["iota_test.cc"], - blacklisted_backends = [ - "cpu", - "gpu", - ], + shard_count = 30, tags = [ "enable_for_xla_interpreter", ], deps = [ ":client_library_test_base", - ":literal_test_util", ":xla_internal_test_main", - "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/core:lib", - "//tensorflow/core:test", ], ) diff --git a/tensorflow/compiler/xla/tests/iota_test.cc b/tensorflow/compiler/xla/tests/iota_test.cc index 17ac95ae01..07c3c6b878 100644 --- a/tensorflow/compiler/xla/tests/iota_test.cc +++ b/tensorflow/compiler/xla/tests/iota_test.cc @@ -23,40 +23,95 @@ limitations under the License. namespace xla { namespace { -class IotaTest : public ClientLibraryTestBase { - public: - explicit IotaTest(se::Platform* platform = nullptr) - : ClientLibraryTestBase(platform) {} - template <typename T> - std::vector<T> GetExpected(const int64 num_elements) { - std::vector<T> result(num_elements); - std::iota(result.begin(), result.end(), 0); - return result; +template <typename T> +std::vector<T> GetR1Expected(const int64 num_elements) { + std::vector<T> result(num_elements); + std::iota(result.begin(), result.end(), 0); + return result; +} + +class IotaR1Test + : public ClientLibraryTestBase, + public ::testing::WithParamInterface<std::tuple<PrimitiveType, int>> {}; + +TEST_P(IotaR1Test, DoIt) { + const auto& spec = GetParam(); + const auto element_type = std::get<0>(spec); + const int64 num_elements = std::get<1>(spec); + XlaBuilder builder(TestName() + "_" + PrimitiveType_Name(element_type)); + IotaGen(&builder, element_type, num_elements); + if (element_type == F32) { + ComputeAndCompareR1<float>(&builder, GetR1Expected<float>(num_elements), {}, + ErrorSpec{0.0001}); + } else if (element_type == U32) { + ComputeAndCompareR1<uint32>(&builder, GetR1Expected<uint32>(num_elements), + {}); + } else { + CHECK_EQ(element_type, S32); + ComputeAndCompareR1<int32>(&builder, GetR1Expected<int32>(num_elements), + {}); } -}; - -XLA_TEST_F(IotaTest, SimpleR1) { - for (int num_elements = 1; num_elements < 10000001; num_elements *= 10) { - { - XlaBuilder builder(TestName() + "_f32"); - IotaGen(&builder, F32, num_elements); - ComputeAndCompareR1<float>(&builder, GetExpected<float>(num_elements), {}, - ErrorSpec{0.0001}); - } - { - XlaBuilder builder(TestName() + "_u32"); - IotaGen(&builder, U32, num_elements); - ComputeAndCompareR1<uint32>(&builder, GetExpected<uint32>(num_elements), - {}); - } - { - XlaBuilder builder(TestName() + "_s32"); - IotaGen(&builder, S32, num_elements); - ComputeAndCompareR1<int32>(&builder, GetExpected<int32>(num_elements), - {}); - } +} + +INSTANTIATE_TEST_CASE_P(IotaR1TestInstantiation, IotaR1Test, + ::testing::Combine(::testing::Values(F32, U32, S32), + ::testing::Range(/*start=*/10, + /*end=*/10001, + /*step=*/10))); + +class IotaR2Test : public ClientLibraryTestBase, + public ::testing::WithParamInterface< + std::tuple<PrimitiveType, int, int>> {}; + +TEST_P(IotaR2Test, DoIt) { + const auto& spec = GetParam(); + const auto element_type = std::get<0>(spec); + const int64 num_elements = std::get<1>(spec); + const int64 iota_dim = std::get<2>(spec); + XlaBuilder builder(TestName() + "_" + PrimitiveType_Name(element_type)); + std::vector<int64> dimensions = {42}; + dimensions.insert(dimensions.begin() + iota_dim, num_elements); + IotaGen(&builder, ShapeUtil::MakeShape(element_type, dimensions), iota_dim); + if (primitive_util::IsFloatingPointType(element_type)) { + ComputeAndCompare(&builder, {}, ErrorSpec{0.0001}); + } else { + ComputeAndCompare(&builder, {}); } } +INSTANTIATE_TEST_CASE_P(IotaR2TestInstantiation, IotaR2Test, + ::testing::Combine(::testing::Values(F32, S32), + ::testing::Range(/*start=*/10, + /*end=*/1001, + /*step=*/10), + ::testing::Values(0, 1))); + +class IotaR3Test : public ClientLibraryTestBase, + public ::testing::WithParamInterface< + std::tuple<PrimitiveType, int, int>> {}; + +TEST_P(IotaR3Test, DoIt) { + const auto& spec = GetParam(); + const auto element_type = std::get<0>(spec); + const int64 num_elements = std::get<1>(spec); + const int64 iota_dim = std::get<2>(spec); + XlaBuilder builder(TestName() + "_" + PrimitiveType_Name(element_type)); + std::vector<int64> dimensions = {42, 19}; + dimensions.insert(dimensions.begin() + iota_dim, num_elements); + IotaGen(&builder, ShapeUtil::MakeShape(element_type, dimensions), iota_dim); + if (primitive_util::IsFloatingPointType(element_type)) { + ComputeAndCompare(&builder, {}, ErrorSpec{0.0001}); + } else { + ComputeAndCompare(&builder, {}); + } +} + +INSTANTIATE_TEST_CASE_P(IotaR3TestInstantiation, IotaR3Test, + ::testing::Combine(::testing::Values(F32, S32), + ::testing::Range(/*start=*/10, + /*end=*/1001, + /*step=*/10), + ::testing::Values(0, 1, 2))); + } // namespace } // namespace xla |