aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/tests/test_utils.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/tests/test_utils.cc')
-rw-r--r--tensorflow/compiler/xla/tests/test_utils.cc74
1 files changed, 33 insertions, 41 deletions
diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc
index 3ae31191a0..5155f0c652 100644
--- a/tensorflow/compiler/xla/tests/test_utils.cc
+++ b/tensorflow/compiler/xla/tests/test_utils.cc
@@ -116,13 +116,14 @@ void PopulateWithRandomIntegralData(Literal* literal, std::minstd_rand0* engine,
// array. This is uniqueness is best-effort only. Some types (half and bfloat16)
// are not supported and uniqueness cannot be guaranteed if the number of
// elements exceeds the number of different values supported by the type.
-StatusOr<std::unique_ptr<Literal>> MakeFakeLiteralInternal(
- const Shape& shape, std::minstd_rand0* engine, bool no_duplicates) {
+StatusOr<Literal> MakeFakeLiteralInternal(const Shape& shape,
+ std::minstd_rand0* engine,
+ bool no_duplicates) {
if (ShapeUtil::IsTuple(shape)) {
- std::vector<std::unique_ptr<Literal>> elements;
+ std::vector<Literal> elements;
for (const Shape& element_shape : shape.tuple_shapes()) {
TF_ASSIGN_OR_RETURN(
- std::unique_ptr<Literal> element,
+ Literal element,
MakeFakeLiteralInternal(element_shape, engine, no_duplicates));
elements.push_back(std::move(element));
}
@@ -131,60 +132,52 @@ StatusOr<std::unique_ptr<Literal>> MakeFakeLiteralInternal(
if (engine == nullptr) {
return Literal::CreateFromShape(shape);
}
- auto literal = absl::make_unique<Literal>(shape);
+ Literal literal(shape);
switch (shape.element_type()) {
case BF16:
- PopulateWithRandomFloatingPointData<bfloat16>(literal.get(), engine,
+ PopulateWithRandomFloatingPointData<bfloat16>(&literal, engine,
no_duplicates);
break;
case F16:
- PopulateWithRandomFloatingPointData<half>(literal.get(), engine,
+ PopulateWithRandomFloatingPointData<half>(&literal, engine,
no_duplicates);
break;
case F32:
- PopulateWithRandomFloatingPointData<float>(literal.get(), engine,
+ PopulateWithRandomFloatingPointData<float>(&literal, engine,
no_duplicates);
break;
case F64:
- PopulateWithRandomFloatingPointData<double>(literal.get(), engine,
+ PopulateWithRandomFloatingPointData<double>(&literal, engine,
no_duplicates);
break;
case S8:
- PopulateWithRandomIntegralData<int8>(literal.get(), engine,
- no_duplicates);
+ PopulateWithRandomIntegralData<int8>(&literal, engine, no_duplicates);
break;
case U8:
- PopulateWithRandomIntegralData<uint8>(literal.get(), engine,
- no_duplicates);
+ PopulateWithRandomIntegralData<uint8>(&literal, engine, no_duplicates);
break;
case S16:
- PopulateWithRandomIntegralData<int16>(literal.get(), engine,
- no_duplicates);
+ PopulateWithRandomIntegralData<int16>(&literal, engine, no_duplicates);
break;
case U16:
- PopulateWithRandomIntegralData<uint16>(literal.get(), engine,
- no_duplicates);
+ PopulateWithRandomIntegralData<uint16>(&literal, engine, no_duplicates);
break;
case S32:
- PopulateWithRandomIntegralData<int32>(literal.get(), engine,
- no_duplicates);
+ PopulateWithRandomIntegralData<int32>(&literal, engine, no_duplicates);
break;
case U32:
- PopulateWithRandomIntegralData<uint32>(literal.get(), engine,
- no_duplicates);
+ PopulateWithRandomIntegralData<uint32>(&literal, engine, no_duplicates);
break;
case S64:
- PopulateWithRandomIntegralData<int64>(literal.get(), engine,
- no_duplicates);
+ PopulateWithRandomIntegralData<int64>(&literal, engine, no_duplicates);
break;
case U64:
- PopulateWithRandomIntegralData<uint64>(literal.get(), engine,
- no_duplicates);
+ PopulateWithRandomIntegralData<uint64>(&literal, engine, no_duplicates);
break;
case PRED: {
std::uniform_int_distribution<int> generator(0, 1);
TF_CHECK_OK(
- literal->Populate<bool>([&](absl::Span<const int64> /*indices*/) {
+ literal.Populate<bool>([&](absl::Span<const int64> /*indices*/) {
return generator(*engine);
}));
break;
@@ -236,8 +229,8 @@ bool NeedsInitValue(const HloUse& use) {
// Generate random values that are constrained to the input_shape minus the
// output_shape so as not to produce wrapping slices, for instance.
-std::unique_ptr<Literal> MakeRandomIndex(absl::Span<const int64> index_space,
- std::minstd_rand0* engine) {
+Literal MakeRandomIndex(absl::Span<const int64> index_space,
+ std::minstd_rand0* engine) {
std::vector<int32> start_indices(index_space.size());
if (engine != nullptr) {
for (int i = 0; i < index_space.size(); ++i) {
@@ -293,7 +286,7 @@ std::vector<HloInstruction*> FindConstrainedUses(
// no constrained uses in the dataflow graph. If such constraints exist,
// generate a constrained literal (either bounded in the case of indices, or
// zero in the case of init_values for reductions).
-StatusOr<std::unique_ptr<Literal>> CreateLiteralForConstrainedUses(
+StatusOr<Literal> CreateLiteralForConstrainedUses(
const absl::Span<HloInstruction* const> constrained_uses,
const HloInstruction& param, std::minstd_rand0* engine) {
std::vector<int64> index_space;
@@ -358,9 +351,9 @@ StatusOr<std::unique_ptr<Literal>> CreateLiteralForConstrainedUses(
} else if (needs_constant) {
switch (constant_type) {
case ConstantType::kZero:
- return LiteralUtil::Zero(param.shape().element_type()).CloneToUnique();
+ return LiteralUtil::Zero(param.shape().element_type());
case ConstantType::kOne:
- return LiteralUtil::One(param.shape().element_type()).CloneToUnique();
+ return LiteralUtil::One(param.shape().element_type());
case ConstantType::kUnknown:
// We want the identity element for the computation, but we don't really
// know what it is - so any value we generate will be just as wrong.
@@ -374,34 +367,33 @@ StatusOr<std::unique_ptr<Literal>> CreateLiteralForConstrainedUses(
// Given a module entry parameter, use the dataflow analysis to see if a
// special case literal must be created, or if we can generate fake data.
-StatusOr<std::unique_ptr<Literal>> MakeConstrainedArgument(
- const HloDataflowAnalysis& dataflow, const HloInstruction& param,
- std::minstd_rand0* engine) {
+StatusOr<Literal> MakeConstrainedArgument(const HloDataflowAnalysis& dataflow,
+ const HloInstruction& param,
+ std::minstd_rand0* engine) {
const auto constrained_uses = FindConstrainedUses(dataflow, param);
return CreateLiteralForConstrainedUses(constrained_uses, param, engine);
}
} // namespace
-StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape,
- bool pseudo_random) {
+StatusOr<Literal> MakeFakeLiteral(const Shape& shape, bool pseudo_random) {
auto engine =
pseudo_random ? absl::make_unique<std::minstd_rand0>() : nullptr;
return MakeFakeLiteralInternal(shape, engine.get(), /*no_duplicates=*/false);
}
-StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments(
- HloModule* const module, bool pseudo_random) {
+StatusOr<std::vector<Literal>> MakeFakeArguments(HloModule* const module,
+ bool pseudo_random) {
auto engine =
pseudo_random ? absl::make_unique<std::minstd_rand0>() : nullptr;
return MakeFakeArguments(module, engine.get());
}
-StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments(
- HloModule* const module, std::minstd_rand0* engine) {
+StatusOr<std::vector<Literal>> MakeFakeArguments(HloModule* const module,
+ std::minstd_rand0* engine) {
TF_ASSIGN_OR_RETURN(auto dataflow, HloDataflowAnalysis::Run(*module));
const auto params = module->entry_computation()->parameter_instructions();
- std::vector<std::unique_ptr<Literal>> arguments(params.size());
+ std::vector<Literal> arguments(params.size());
for (int i = 0; i < params.size(); ++i) {
arguments[i] =
MakeConstrainedArgument(*dataflow, *params[i], engine).ValueOrDie();