diff options
Diffstat (limited to 'tensorflow/compiler/xla/tests/test_utils.cc')
-rw-r--r-- | tensorflow/compiler/xla/tests/test_utils.cc | 74 |
1 files changed, 33 insertions, 41 deletions
diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc index 3ae31191a0..5155f0c652 100644 --- a/tensorflow/compiler/xla/tests/test_utils.cc +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -116,13 +116,14 @@ void PopulateWithRandomIntegralData(Literal* literal, std::minstd_rand0* engine, // array. This is uniqueness is best-effort only. Some types (half and bfloat16) // are not supported and uniqueness cannot be guaranteed if the number of // elements exceeds the number of different values supported by the type. -StatusOr<std::unique_ptr<Literal>> MakeFakeLiteralInternal( - const Shape& shape, std::minstd_rand0* engine, bool no_duplicates) { +StatusOr<Literal> MakeFakeLiteralInternal(const Shape& shape, + std::minstd_rand0* engine, + bool no_duplicates) { if (ShapeUtil::IsTuple(shape)) { - std::vector<std::unique_ptr<Literal>> elements; + std::vector<Literal> elements; for (const Shape& element_shape : shape.tuple_shapes()) { TF_ASSIGN_OR_RETURN( - std::unique_ptr<Literal> element, + Literal element, MakeFakeLiteralInternal(element_shape, engine, no_duplicates)); elements.push_back(std::move(element)); } @@ -131,60 +132,52 @@ StatusOr<std::unique_ptr<Literal>> MakeFakeLiteralInternal( if (engine == nullptr) { return Literal::CreateFromShape(shape); } - auto literal = absl::make_unique<Literal>(shape); + Literal literal(shape); switch (shape.element_type()) { case BF16: - PopulateWithRandomFloatingPointData<bfloat16>(literal.get(), engine, + PopulateWithRandomFloatingPointData<bfloat16>(&literal, engine, no_duplicates); break; case F16: - PopulateWithRandomFloatingPointData<half>(literal.get(), engine, + PopulateWithRandomFloatingPointData<half>(&literal, engine, no_duplicates); break; case F32: - PopulateWithRandomFloatingPointData<float>(literal.get(), engine, + PopulateWithRandomFloatingPointData<float>(&literal, engine, no_duplicates); break; case F64: - PopulateWithRandomFloatingPointData<double>(literal.get(), engine, + PopulateWithRandomFloatingPointData<double>(&literal, engine, no_duplicates); break; case S8: - PopulateWithRandomIntegralData<int8>(literal.get(), engine, - no_duplicates); + PopulateWithRandomIntegralData<int8>(&literal, engine, no_duplicates); break; case U8: - PopulateWithRandomIntegralData<uint8>(literal.get(), engine, - no_duplicates); + PopulateWithRandomIntegralData<uint8>(&literal, engine, no_duplicates); break; case S16: - PopulateWithRandomIntegralData<int16>(literal.get(), engine, - no_duplicates); + PopulateWithRandomIntegralData<int16>(&literal, engine, no_duplicates); break; case U16: - PopulateWithRandomIntegralData<uint16>(literal.get(), engine, - no_duplicates); + PopulateWithRandomIntegralData<uint16>(&literal, engine, no_duplicates); break; case S32: - PopulateWithRandomIntegralData<int32>(literal.get(), engine, - no_duplicates); + PopulateWithRandomIntegralData<int32>(&literal, engine, no_duplicates); break; case U32: - PopulateWithRandomIntegralData<uint32>(literal.get(), engine, - no_duplicates); + PopulateWithRandomIntegralData<uint32>(&literal, engine, no_duplicates); break; case S64: - PopulateWithRandomIntegralData<int64>(literal.get(), engine, - no_duplicates); + PopulateWithRandomIntegralData<int64>(&literal, engine, no_duplicates); break; case U64: - PopulateWithRandomIntegralData<uint64>(literal.get(), engine, - no_duplicates); + PopulateWithRandomIntegralData<uint64>(&literal, engine, no_duplicates); break; case PRED: { std::uniform_int_distribution<int> generator(0, 1); TF_CHECK_OK( - literal->Populate<bool>([&](absl::Span<const int64> /*indices*/) { + literal.Populate<bool>([&](absl::Span<const int64> /*indices*/) { return generator(*engine); })); break; @@ -236,8 +229,8 @@ bool NeedsInitValue(const HloUse& use) { // Generate random values that are constrained to the input_shape minus the // output_shape so as not to produce wrapping slices, for instance. -std::unique_ptr<Literal> MakeRandomIndex(absl::Span<const int64> index_space, - std::minstd_rand0* engine) { +Literal MakeRandomIndex(absl::Span<const int64> index_space, + std::minstd_rand0* engine) { std::vector<int32> start_indices(index_space.size()); if (engine != nullptr) { for (int i = 0; i < index_space.size(); ++i) { @@ -293,7 +286,7 @@ std::vector<HloInstruction*> FindConstrainedUses( // no constrained uses in the dataflow graph. If such constraints exist, // generate a constrained literal (either bounded in the case of indices, or // zero in the case of init_values for reductions). -StatusOr<std::unique_ptr<Literal>> CreateLiteralForConstrainedUses( +StatusOr<Literal> CreateLiteralForConstrainedUses( const absl::Span<HloInstruction* const> constrained_uses, const HloInstruction& param, std::minstd_rand0* engine) { std::vector<int64> index_space; @@ -358,9 +351,9 @@ StatusOr<std::unique_ptr<Literal>> CreateLiteralForConstrainedUses( } else if (needs_constant) { switch (constant_type) { case ConstantType::kZero: - return LiteralUtil::Zero(param.shape().element_type()).CloneToUnique(); + return LiteralUtil::Zero(param.shape().element_type()); case ConstantType::kOne: - return LiteralUtil::One(param.shape().element_type()).CloneToUnique(); + return LiteralUtil::One(param.shape().element_type()); case ConstantType::kUnknown: // We want the identity element for the computation, but we don't really // know what it is - so any value we generate will be just as wrong. @@ -374,34 +367,33 @@ StatusOr<std::unique_ptr<Literal>> CreateLiteralForConstrainedUses( // Given a module entry parameter, use the dataflow analysis to see if a // special case literal must be created, or if we can generate fake data. -StatusOr<std::unique_ptr<Literal>> MakeConstrainedArgument( - const HloDataflowAnalysis& dataflow, const HloInstruction& param, - std::minstd_rand0* engine) { +StatusOr<Literal> MakeConstrainedArgument(const HloDataflowAnalysis& dataflow, + const HloInstruction& param, + std::minstd_rand0* engine) { const auto constrained_uses = FindConstrainedUses(dataflow, param); return CreateLiteralForConstrainedUses(constrained_uses, param, engine); } } // namespace -StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape, - bool pseudo_random) { +StatusOr<Literal> MakeFakeLiteral(const Shape& shape, bool pseudo_random) { auto engine = pseudo_random ? absl::make_unique<std::minstd_rand0>() : nullptr; return MakeFakeLiteralInternal(shape, engine.get(), /*no_duplicates=*/false); } -StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments( - HloModule* const module, bool pseudo_random) { +StatusOr<std::vector<Literal>> MakeFakeArguments(HloModule* const module, + bool pseudo_random) { auto engine = pseudo_random ? absl::make_unique<std::minstd_rand0>() : nullptr; return MakeFakeArguments(module, engine.get()); } -StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments( - HloModule* const module, std::minstd_rand0* engine) { +StatusOr<std::vector<Literal>> MakeFakeArguments(HloModule* const module, + std::minstd_rand0* engine) { TF_ASSIGN_OR_RETURN(auto dataflow, HloDataflowAnalysis::Run(*module)); const auto params = module->entry_computation()->parameter_instructions(); - std::vector<std::unique_ptr<Literal>> arguments(params.size()); + std::vector<Literal> arguments(params.size()); for (int i = 0; i < params.size(); ++i) { arguments[i] = MakeConstrainedArgument(*dataflow, *params[i], engine).ValueOrDie(); |