diff options
author | 2018-01-26 17:50:03 -0800 | |
---|---|---|
committer | 2018-01-26 17:53:49 -0800 | |
commit | 704361ad3650ebc891167adc41c459ca93392060 (patch) | |
tree | 5152b71fb9ed92168cf439d60827671f2c77e910 /tensorflow/compiler/xla/tests/test_utils.cc | |
parent | ca3ac2a464b92f4c0498dfde875f99102a0d410c (diff) |
Create different data for each Literal when creating fake data.
Thread a random number generator through the functions that create fake
arguments so that the same generator can be reused, which avoids repeating the
same data patterns for each argument generated.
Also tweak the position-dependent biasing heuristic to create both positive and
negative numbers for small literals.
PiperOrigin-RevId: 183473588
Diffstat (limited to 'tensorflow/compiler/xla/tests/test_utils.cc')
-rw-r--r-- | tensorflow/compiler/xla/tests/test_utils.cc | 188 |
1 file changed, 105 insertions, 83 deletions
diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc index 8b10aef5b8..b060fb13b1 100644 --- a/tensorflow/compiler/xla/tests/test_utils.cc +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -24,51 +24,127 @@ namespace xla { namespace { template <typename FloatT> -void PopulateWithRandomFloatingPointData(Literal* literal) { +void PopulateWithRandomFloatingPointData(Literal* literal, + std::minstd_rand0* engine) { CHECK_EQ(literal->shape().element_type(), primitive_util::NativeToPrimitiveType<FloatT>()); - std::minstd_rand0 engine; - // Create uniform numbers between 1 and 1.125 ot avoid creating denormal + // Create uniform numbers between 1 and 1.125 to avoid creating denormal // numbers. std::uniform_real_distribution<FloatT> generator(1.0f, 1.125f); + const bool should_index_bias = ShapeUtil::ElementsIn(literal->shape()) > 1000; TF_CHECK_OK(literal->Populate<FloatT>( [&](tensorflow::gtl::ArraySlice<int64> indices) { - // Generate a random uniforma number from -0.0625 and 0.0625 and bias it - // with a position dependent nubmer with mean 0.037109375. These number + // Generate a random uniform number from -0.0625 and 0.0625 and bias it + // with a position dependent number with mean 0.037109375. These number // should allow for long chains of accumulation without being too close - // to zero or to large to accumulate all numbers accurately. - return (generator(engine) - 1.0625) + - static_cast<FloatT>(Product(indices) % 113 - 47) / - static_cast<FloatT>(256.0f); + // to zero or too large to accumulate all numbers accurately. Only do + // this for large literals where the number of elements is much greater + // than 47 otherwise only negative values are produced. + // + // The value is positionally biased using a product of the indices. Add + // one to each index value to avoid collapsing to zero if any of the + // indices are zero. 
+ int64 index_product = 1; + for (int64 i : indices) { + index_product *= (1 + i); + } + const int64 negative_bias = should_index_bias ? 47 : 0; + FloatT index_bias = + static_cast<FloatT>(index_product % 113 - negative_bias) / + static_cast<FloatT>(256.0f); + return (generator(*engine) - 1.0625) + index_bias; })); } // The standard library does not have a case for bfloat16, unsurprisingly, so we // handle that one specially. template <> -void PopulateWithRandomFloatingPointData<bfloat16>(Literal* literal) { +void PopulateWithRandomFloatingPointData<bfloat16>(Literal* literal, + std::minstd_rand0* engine) { CHECK_EQ(literal->shape().element_type(), BF16); - std::minstd_rand0 engine; std::uniform_real_distribution<float> generator(-0.9f, 1.0f); TF_CHECK_OK(literal->Populate<bfloat16>( [&](tensorflow::gtl::ArraySlice<int64> /*indices*/) { - return static_cast<bfloat16>(generator(engine)); + return static_cast<bfloat16>(generator(*engine)); })); } template <typename IntT> -void PopulateWithRandomIntegralData(Literal* literal) { +void PopulateWithRandomIntegralData(Literal* literal, + std::minstd_rand0* engine) { CHECK_EQ(literal->shape().element_type(), primitive_util::NativeToPrimitiveType<IntT>()); - std::minstd_rand0 engine; std::uniform_int_distribution<IntT> generator( std::numeric_limits<IntT>::lowest(), std::numeric_limits<IntT>::max()); TF_CHECK_OK(literal->Populate<IntT>( [&](tensorflow::gtl::ArraySlice<int64> /*indices*/) { - return generator(engine); + return generator(*engine); })); } +// Similar to MakeFakeLiteral but takes a random number generator engine to +// enable reusing the engine across randomly generated literals. 
+StatusOr<std::unique_ptr<Literal>> MakeFakeLiteralInternal( + const Shape& shape, std::minstd_rand0* engine) { + if (ShapeUtil::IsTuple(shape)) { + std::vector<std::unique_ptr<Literal>> elements; + for (const Shape& element_shape : shape.tuple_shapes()) { + TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> element, + MakeFakeLiteralInternal(element_shape, engine)); + elements.push_back(std::move(element)); + } + return Literal::MakeTupleOwned(std::move(elements)); + } + std::unique_ptr<Literal> literal = Literal::CreateFromShape(shape); + switch (shape.element_type()) { + case BF16: + PopulateWithRandomFloatingPointData<bfloat16>(literal.get(), engine); + break; + case F32: + PopulateWithRandomFloatingPointData<float>(literal.get(), engine); + break; + case F64: + PopulateWithRandomFloatingPointData<double>(literal.get(), engine); + break; + case S8: + PopulateWithRandomIntegralData<int8>(literal.get(), engine); + break; + case U8: + PopulateWithRandomIntegralData<uint8>(literal.get(), engine); + break; + case S16: + PopulateWithRandomIntegralData<int16>(literal.get(), engine); + break; + case U16: + PopulateWithRandomIntegralData<uint16>(literal.get(), engine); + break; + case S32: + PopulateWithRandomIntegralData<int32>(literal.get(), engine); + break; + case U32: + PopulateWithRandomIntegralData<uint32>(literal.get(), engine); + break; + case S64: + PopulateWithRandomIntegralData<int64>(literal.get(), engine); + break; + case U64: + PopulateWithRandomIntegralData<uint64>(literal.get(), engine); + break; + case PRED: { + std::uniform_int_distribution<int> generator(0, 1); + TF_CHECK_OK(literal->Populate<bool>( + [&](tensorflow::gtl::ArraySlice<int64> /*indices*/) { + return generator(*engine); + })); + break; + } + default: + return Unimplemented("Unsupported type for fake literal generation: %s", + ShapeUtil::HumanString(shape).c_str()); + } + return std::move(literal); +} + // Matches binary addition computations. 
bool LooksLikeSum(const HloComputation& computation) { const HloInstruction* const root = computation.root_instruction(); @@ -95,15 +171,15 @@ bool NeedsZeroInitValue(const HloUse& use) { // Generate random values that are constrained to the input_shape minus the // output_shape so as not to produce wrapping slices, for instance. std::unique_ptr<Literal> MakeRandomNonwrappingSliceIndex( - const Shape& input_shape, const Shape& slice_shape) { + const Shape& input_shape, const Shape& slice_shape, + std::minstd_rand0* engine) { const int64 rank = ShapeUtil::Rank(input_shape); std::vector<int32> start_indices(rank); - std::minstd_rand0 engine; for (int i = 0; i < rank; ++i) { const int32 upper_bound = ShapeUtil::GetDimension(input_shape, i) - ShapeUtil::GetDimension(slice_shape, i); std::uniform_int_distribution<int32> generator(0, upper_bound); - start_indices[i] = generator(engine); + start_indices[i] = generator(*engine); } return Literal::CreateR1<int32>(start_indices); } @@ -150,7 +226,7 @@ std::vector<HloInstruction*> FindConstrainedUses( // zero in the case of init_values for reductions). 
StatusOr<std::unique_ptr<Literal>> CreateLiteralForConstrainedUses( const tensorflow::gtl::ArraySlice<HloInstruction*> constrained_uses, - const HloInstruction& param) { + const HloInstruction& param, std::minstd_rand0* engine) { HloInstruction* needs_index = nullptr; HloInstruction* needs_zero = nullptr; for (HloInstruction* use : constrained_uses) { @@ -185,93 +261,39 @@ StatusOr<std::unique_ptr<Literal>> CreateLiteralForConstrainedUses( } if (needs_index != nullptr) { return MakeRandomNonwrappingSliceIndex(needs_index->operand(0)->shape(), - needs_index->shape()); + needs_index->shape(), engine); } else if (needs_zero != nullptr) { return Literal::CreateFromShape(param.shape()); } else { - return MakeFakeLiteral(param.shape()); + return MakeFakeLiteralInternal(param.shape(), engine); } } // Given a module entry parameter, use the dataflow analysis to see if a // special case literal must be created, or if we can generate fake data. StatusOr<std::unique_ptr<Literal>> MakeConstrainedArgument( - const HloDataflowAnalysis& dataflow, const HloInstruction& param) { + const HloDataflowAnalysis& dataflow, const HloInstruction& param, + std::minstd_rand0* engine) { const auto constrained_uses = FindConstrainedUses(dataflow, param); - return CreateLiteralForConstrainedUses(constrained_uses, param); + return CreateLiteralForConstrainedUses(constrained_uses, param, engine); } } // namespace StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape) { - if (ShapeUtil::IsTuple(shape)) { - std::vector<std::unique_ptr<Literal>> elements; - for (const Shape& element_shape : shape.tuple_shapes()) { - TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> element, - MakeFakeLiteral(element_shape)); - elements.push_back(std::move(element)); - } - return Literal::MakeTupleOwned(std::move(elements)); - } - std::unique_ptr<Literal> literal = Literal::CreateFromShape(shape); - switch (shape.element_type()) { - case BF16: - PopulateWithRandomFloatingPointData<bfloat16>(literal.get()); 
- break; - case F32: - PopulateWithRandomFloatingPointData<float>(literal.get()); - break; - case F64: - PopulateWithRandomFloatingPointData<double>(literal.get()); - break; - case S8: - PopulateWithRandomIntegralData<int8>(literal.get()); - break; - case U8: - PopulateWithRandomIntegralData<uint8>(literal.get()); - break; - case S16: - PopulateWithRandomIntegralData<int16>(literal.get()); - break; - case U16: - PopulateWithRandomIntegralData<uint16>(literal.get()); - break; - case S32: - PopulateWithRandomIntegralData<int32>(literal.get()); - break; - case U32: - PopulateWithRandomIntegralData<uint32>(literal.get()); - break; - case S64: - PopulateWithRandomIntegralData<int64>(literal.get()); - break; - case U64: - PopulateWithRandomIntegralData<uint64>(literal.get()); - break; - case PRED: { - std::uniform_int_distribution<int> generator(0, 1); - std::minstd_rand0 engine; - TF_CHECK_OK(literal->Populate<bool>( - [&](tensorflow::gtl::ArraySlice<int64> /*indices*/) { - return generator(engine); - })); - break; - } - default: - return Unimplemented("Unsupported type for fake literal generation: %s", - ShapeUtil::HumanString(shape).c_str()); - } - return std::move(literal); + std::minstd_rand0 engine; + return MakeFakeLiteralInternal(shape, &engine); } StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments( HloModule* const module) { TF_ASSIGN_OR_RETURN(auto dataflow, HloDataflowAnalysis::Run(module)); const auto params = module->entry_computation()->parameter_instructions(); + std::minstd_rand0 engine; std::vector<std::unique_ptr<Literal>> arguments(params.size()); for (int i = 0; i < params.size(); ++i) { - TF_ASSIGN_OR_RETURN(arguments[i], - MakeConstrainedArgument(*dataflow, *params[i])); + TF_ASSIGN_OR_RETURN( + arguments[i], MakeConstrainedArgument(*dataflow, *params[i], &engine)); } return std::move(arguments); } |