about summary refs log tree commit diff homepage
path: root/tensorflow/compiler/xla/tests/test_utils.cc
diff options
context:
space:
mode:
authorGravatar Mark Heffernan <meheff@google.com>2018-01-26 17:50:03 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-26 17:53:49 -0800
commit704361ad3650ebc891167adc41c459ca93392060 (patch)
tree5152b71fb9ed92168cf439d60827671f2c77e910 /tensorflow/compiler/xla/tests/test_utils.cc
parentca3ac2a464b92f4c0498dfde875f99102a0d410c (diff)
Create different data for each Literal when creating fake data.
Thread a generator through the functions for creating fake arguments so the same generator can be reused which avoids repeating the same data patterns for each argument generated. Also tweak the position-dependent biasing heuristic to create both positive and negative numbers for small literals. PiperOrigin-RevId: 183473588
Diffstat (limited to 'tensorflow/compiler/xla/tests/test_utils.cc')
-rw-r--r-- tensorflow/compiler/xla/tests/test_utils.cc | 188
1 files changed, 105 insertions, 83 deletions
diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc
index 8b10aef5b8..b060fb13b1 100644
--- a/tensorflow/compiler/xla/tests/test_utils.cc
+++ b/tensorflow/compiler/xla/tests/test_utils.cc
@@ -24,51 +24,127 @@ namespace xla {
namespace {
template <typename FloatT>
-void PopulateWithRandomFloatingPointData(Literal* literal) {
+void PopulateWithRandomFloatingPointData(Literal* literal,
+ std::minstd_rand0* engine) {
CHECK_EQ(literal->shape().element_type(),
primitive_util::NativeToPrimitiveType<FloatT>());
- std::minstd_rand0 engine;
- // Create uniform numbers between 1 and 1.125 ot avoid creating denormal
+ // Create uniform numbers between 1 and 1.125 to avoid creating denormal
// numbers.
std::uniform_real_distribution<FloatT> generator(1.0f, 1.125f);
+ const bool should_index_bias = ShapeUtil::ElementsIn(literal->shape()) > 1000;
TF_CHECK_OK(literal->Populate<FloatT>(
[&](tensorflow::gtl::ArraySlice<int64> indices) {
- // Generate a random uniforma number from -0.0625 and 0.0625 and bias it
- // with a position dependent nubmer with mean 0.037109375. These number
+ // Generate a random uniform number from -0.0625 and 0.0625 and bias it
+ // with a position dependent number with mean 0.037109375. These number
// should allow for long chains of accumulation without being too close
- // to zero or to large to accumulate all numbers accurately.
- return (generator(engine) - 1.0625) +
- static_cast<FloatT>(Product(indices) % 113 - 47) /
- static_cast<FloatT>(256.0f);
+ // to zero or too large to accumulate all numbers accurately. Only do
+ // this for large literals where the number of elements is much greater
+ // than 47 otherwise only negative values are produced.
+ //
+ // The value is positionally biased using a product of the indices. Add
+ // one to each index value to avoid collapsing to zero if any of the
+ // indices are zero.
+ int64 index_product = 1;
+ for (int64 i : indices) {
+ index_product *= (1 + i);
+ }
+ const int64 negative_bias = should_index_bias ? 47 : 0;
+ FloatT index_bias =
+ static_cast<FloatT>(index_product % 113 - negative_bias) /
+ static_cast<FloatT>(256.0f);
+ return (generator(*engine) - 1.0625) + index_bias;
}));
}
// The standard library does not have a case for bfloat16, unsurprisingly, so we
// handle that one specially.
template <>
-void PopulateWithRandomFloatingPointData<bfloat16>(Literal* literal) {
+void PopulateWithRandomFloatingPointData<bfloat16>(Literal* literal,
+ std::minstd_rand0* engine) {
CHECK_EQ(literal->shape().element_type(), BF16);
- std::minstd_rand0 engine;
std::uniform_real_distribution<float> generator(-0.9f, 1.0f);
TF_CHECK_OK(literal->Populate<bfloat16>(
[&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
- return static_cast<bfloat16>(generator(engine));
+ return static_cast<bfloat16>(generator(*engine));
}));
}
template <typename IntT>
-void PopulateWithRandomIntegralData(Literal* literal) {
+void PopulateWithRandomIntegralData(Literal* literal,
+ std::minstd_rand0* engine) {
CHECK_EQ(literal->shape().element_type(),
primitive_util::NativeToPrimitiveType<IntT>());
- std::minstd_rand0 engine;
std::uniform_int_distribution<IntT> generator(
std::numeric_limits<IntT>::lowest(), std::numeric_limits<IntT>::max());
TF_CHECK_OK(literal->Populate<IntT>(
[&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
- return generator(engine);
+ return generator(*engine);
}));
}
+// Similar to MakeFakeLiteral but takes a random number generator engine to
+// enable reusing the engine across randomly generated literals.
+StatusOr<std::unique_ptr<Literal>> MakeFakeLiteralInternal(
+ const Shape& shape, std::minstd_rand0* engine) {
+ if (ShapeUtil::IsTuple(shape)) {
+ std::vector<std::unique_ptr<Literal>> elements;
+ for (const Shape& element_shape : shape.tuple_shapes()) {
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> element,
+ MakeFakeLiteralInternal(element_shape, engine));
+ elements.push_back(std::move(element));
+ }
+ return Literal::MakeTupleOwned(std::move(elements));
+ }
+ std::unique_ptr<Literal> literal = Literal::CreateFromShape(shape);
+ switch (shape.element_type()) {
+ case BF16:
+ PopulateWithRandomFloatingPointData<bfloat16>(literal.get(), engine);
+ break;
+ case F32:
+ PopulateWithRandomFloatingPointData<float>(literal.get(), engine);
+ break;
+ case F64:
+ PopulateWithRandomFloatingPointData<double>(literal.get(), engine);
+ break;
+ case S8:
+ PopulateWithRandomIntegralData<int8>(literal.get(), engine);
+ break;
+ case U8:
+ PopulateWithRandomIntegralData<uint8>(literal.get(), engine);
+ break;
+ case S16:
+ PopulateWithRandomIntegralData<int16>(literal.get(), engine);
+ break;
+ case U16:
+ PopulateWithRandomIntegralData<uint16>(literal.get(), engine);
+ break;
+ case S32:
+ PopulateWithRandomIntegralData<int32>(literal.get(), engine);
+ break;
+ case U32:
+ PopulateWithRandomIntegralData<uint32>(literal.get(), engine);
+ break;
+ case S64:
+ PopulateWithRandomIntegralData<int64>(literal.get(), engine);
+ break;
+ case U64:
+ PopulateWithRandomIntegralData<uint64>(literal.get(), engine);
+ break;
+ case PRED: {
+ std::uniform_int_distribution<int> generator(0, 1);
+ TF_CHECK_OK(literal->Populate<bool>(
+ [&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
+ return generator(*engine);
+ }));
+ break;
+ }
+ default:
+ return Unimplemented("Unsupported type for fake literal generation: %s",
+ ShapeUtil::HumanString(shape).c_str());
+ }
+ return std::move(literal);
+}
+
// Matches binary addition computations.
bool LooksLikeSum(const HloComputation& computation) {
const HloInstruction* const root = computation.root_instruction();
@@ -95,15 +171,15 @@ bool NeedsZeroInitValue(const HloUse& use) {
// Generate random values that are constrained to the input_shape minus the
// output_shape so as not to produce wrapping slices, for instance.
std::unique_ptr<Literal> MakeRandomNonwrappingSliceIndex(
- const Shape& input_shape, const Shape& slice_shape) {
+ const Shape& input_shape, const Shape& slice_shape,
+ std::minstd_rand0* engine) {
const int64 rank = ShapeUtil::Rank(input_shape);
std::vector<int32> start_indices(rank);
- std::minstd_rand0 engine;
for (int i = 0; i < rank; ++i) {
const int32 upper_bound = ShapeUtil::GetDimension(input_shape, i) -
ShapeUtil::GetDimension(slice_shape, i);
std::uniform_int_distribution<int32> generator(0, upper_bound);
- start_indices[i] = generator(engine);
+ start_indices[i] = generator(*engine);
}
return Literal::CreateR1<int32>(start_indices);
}
@@ -150,7 +226,7 @@ std::vector<HloInstruction*> FindConstrainedUses(
// zero in the case of init_values for reductions).
StatusOr<std::unique_ptr<Literal>> CreateLiteralForConstrainedUses(
const tensorflow::gtl::ArraySlice<HloInstruction*> constrained_uses,
- const HloInstruction& param) {
+ const HloInstruction& param, std::minstd_rand0* engine) {
HloInstruction* needs_index = nullptr;
HloInstruction* needs_zero = nullptr;
for (HloInstruction* use : constrained_uses) {
@@ -185,93 +261,39 @@ StatusOr<std::unique_ptr<Literal>> CreateLiteralForConstrainedUses(
}
if (needs_index != nullptr) {
return MakeRandomNonwrappingSliceIndex(needs_index->operand(0)->shape(),
- needs_index->shape());
+ needs_index->shape(), engine);
} else if (needs_zero != nullptr) {
return Literal::CreateFromShape(param.shape());
} else {
- return MakeFakeLiteral(param.shape());
+ return MakeFakeLiteralInternal(param.shape(), engine);
}
}
// Given a module entry parameter, use the dataflow analysis to see if a
// special case literal must be created, or if we can generate fake data.
StatusOr<std::unique_ptr<Literal>> MakeConstrainedArgument(
- const HloDataflowAnalysis& dataflow, const HloInstruction& param) {
+ const HloDataflowAnalysis& dataflow, const HloInstruction& param,
+ std::minstd_rand0* engine) {
const auto constrained_uses = FindConstrainedUses(dataflow, param);
- return CreateLiteralForConstrainedUses(constrained_uses, param);
+ return CreateLiteralForConstrainedUses(constrained_uses, param, engine);
}
} // namespace
StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape) {
- if (ShapeUtil::IsTuple(shape)) {
- std::vector<std::unique_ptr<Literal>> elements;
- for (const Shape& element_shape : shape.tuple_shapes()) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> element,
- MakeFakeLiteral(element_shape));
- elements.push_back(std::move(element));
- }
- return Literal::MakeTupleOwned(std::move(elements));
- }
- std::unique_ptr<Literal> literal = Literal::CreateFromShape(shape);
- switch (shape.element_type()) {
- case BF16:
- PopulateWithRandomFloatingPointData<bfloat16>(literal.get());
- break;
- case F32:
- PopulateWithRandomFloatingPointData<float>(literal.get());
- break;
- case F64:
- PopulateWithRandomFloatingPointData<double>(literal.get());
- break;
- case S8:
- PopulateWithRandomIntegralData<int8>(literal.get());
- break;
- case U8:
- PopulateWithRandomIntegralData<uint8>(literal.get());
- break;
- case S16:
- PopulateWithRandomIntegralData<int16>(literal.get());
- break;
- case U16:
- PopulateWithRandomIntegralData<uint16>(literal.get());
- break;
- case S32:
- PopulateWithRandomIntegralData<int32>(literal.get());
- break;
- case U32:
- PopulateWithRandomIntegralData<uint32>(literal.get());
- break;
- case S64:
- PopulateWithRandomIntegralData<int64>(literal.get());
- break;
- case U64:
- PopulateWithRandomIntegralData<uint64>(literal.get());
- break;
- case PRED: {
- std::uniform_int_distribution<int> generator(0, 1);
- std::minstd_rand0 engine;
- TF_CHECK_OK(literal->Populate<bool>(
- [&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
- return generator(engine);
- }));
- break;
- }
- default:
- return Unimplemented("Unsupported type for fake literal generation: %s",
- ShapeUtil::HumanString(shape).c_str());
- }
- return std::move(literal);
+ std::minstd_rand0 engine;
+ return MakeFakeLiteralInternal(shape, &engine);
}
StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments(
HloModule* const module) {
TF_ASSIGN_OR_RETURN(auto dataflow, HloDataflowAnalysis::Run(module));
const auto params = module->entry_computation()->parameter_instructions();
+ std::minstd_rand0 engine;
std::vector<std::unique_ptr<Literal>> arguments(params.size());
for (int i = 0; i < params.size(); ++i) {
- TF_ASSIGN_OR_RETURN(arguments[i],
- MakeConstrainedArgument(*dataflow, *params[i]));
+ TF_ASSIGN_OR_RETURN(
+ arguments[i], MakeConstrainedArgument(*dataflow, *params[i], &engine));
}
return std::move(arguments);
}