diff options
author | 2018-08-16 18:14:50 -0700 | |
---|---|---|
committer | 2018-08-16 18:18:41 -0700 | |
commit | 86d2a1ef43f883bafe9276bcba14eac2bb0cb637 (patch) | |
tree | 525295bc21bf47d9726eb4ce55e655d6385a6c49 /tensorflow/compiler/xla/tests/test_utils.cc | |
parent | bbb3ae0790f042d2bc5f6cce434c75c698d4a978 (diff) |
Various improvements to MakeFakeArguments.
Add an overload which takes a random number generator to enable generation of different random values in sequential calls to MakeFakeArguments. Add a mechanism for generating arrays of unique values (no duplicates) for the key inputs to key/value kSorts. Remove some sorcery in generating float arrays and replace with a uniform distribution. The underlying reason for using this strange distribution no longer exists.
PiperOrigin-RevId: 209083904
Diffstat (limited to 'tensorflow/compiler/xla/tests/test_utils.cc')
-rw-r--r-- | tensorflow/compiler/xla/tests/test_utils.cc | 192 |
1 files changed, 118 insertions, 74 deletions
diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc index faeec657b6..f05421f8e1 100644 --- a/tensorflow/compiler/xla/tests/test_utils.cc +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include <cmath> + #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/primitive_util.h" @@ -26,89 +28,101 @@ namespace { template <typename FloatT, typename GeneratorT> void PopulateWithRandomFloatingPointDataImpl(Literal* literal, - std::minstd_rand0* engine) { + std::minstd_rand0* engine, + bool no_duplicates) { CHECK(engine != nullptr); CHECK_EQ(literal->shape().element_type(), primitive_util::NativeToPrimitiveType<FloatT>()); - // Create uniform numbers between 1 and 1.125 to avoid creating denormal - // numbers. - std::uniform_real_distribution<GeneratorT> generator(1.0f, 1.125f); - const bool should_index_bias = ShapeUtil::ElementsIn(literal->shape()) > 1000; - TF_CHECK_OK(literal->Populate<FloatT>( - [&](tensorflow::gtl::ArraySlice<int64> indices) { - // Generate a random uniform number from -0.0625 and 0.0625 and bias it - // with a position dependent number with mean 0.037109375. These number - // should allow for long chains of accumulation without being too close - // to zero or too large to accumulate all numbers accurately. Only do - // this for large literals where the number of elements is much greater - // than 47 otherwise only negative values are produced. - // - // The value is positionally biased using a product of the indices. Add - // one to each index value to avoid collapsing to zero if any of the - // indices are zero. 
- int64 index_product = 1; - for (int64 i : indices) { - index_product *= (1 + i); - } - const int64 negative_bias = should_index_bias ? 47 : 0; - FloatT index_bias = - static_cast<FloatT>(index_product % 113 - negative_bias) / - static_cast<FloatT>(256.0f); - return static_cast<FloatT>(generator(*engine) - 1.0625f) + index_bias; - })); + if (no_duplicates) { + // Duplicates may be generated if the number of elements in the literal + // exceeds the number of positive values supported by the type. + FloatT next_value = std::numeric_limits<FloatT>::min(); + for (FloatT& value : literal->data<FloatT>()) { + value = next_value; + next_value = + std::nextafter(next_value, std::numeric_limits<FloatT>::max()); + } + std::shuffle(literal->data<FloatT>().begin(), literal->data<FloatT>().end(), + *engine); + } else { + std::uniform_real_distribution<GeneratorT> generator(-0.1f, 0.2f); + for (FloatT& value : literal->data<FloatT>()) { + value = static_cast<FloatT>(generator(*engine)); + } + } } template <typename FloatT> void PopulateWithRandomFloatingPointData(Literal* literal, - std::minstd_rand0* engine) { + std::minstd_rand0* engine, + bool no_duplicates) { CHECK(engine != nullptr); - PopulateWithRandomFloatingPointDataImpl<FloatT, FloatT>(literal, engine); + PopulateWithRandomFloatingPointDataImpl<FloatT, FloatT>(literal, engine, + no_duplicates); } template <> void PopulateWithRandomFloatingPointData<half>(Literal* literal, - std::minstd_rand0* engine) { + std::minstd_rand0* engine, + bool no_duplicates) { + // no_duplicates is ignored for half types. Unique values can only be + // generated for arrays with fewer than ~2**16 elements and no_duplicates is + // best-effort anyway. 
CHECK(engine != nullptr); - PopulateWithRandomFloatingPointDataImpl<half, float>(literal, engine); + std::uniform_real_distribution<float> generator(-0.1f, 0.2f); + for (half& value : literal->data<half>()) { + value = static_cast<half>(generator(*engine)); + } } -// The standard library does not have a case for bfloat16, unsurprisingly, so we -// handle that one specially. template <> void PopulateWithRandomFloatingPointData<bfloat16>(Literal* literal, - std::minstd_rand0* engine) { + std::minstd_rand0* engine, + bool no_duplicates) { + // no_duplicates is ignored for bfloat types. Unique values can only be + // generated for arrays with fewer than ~2**16 elements and no_duplicates is + // best-effort anyway. CHECK(engine != nullptr); - CHECK_EQ(literal->shape().element_type(), BF16); - std::uniform_real_distribution<float> generator(-0.9f, 1.0f); - TF_CHECK_OK(literal->Populate<bfloat16>( - [&](tensorflow::gtl::ArraySlice<int64> /*indices*/) { - return static_cast<bfloat16>(generator(*engine)); - })); + std::uniform_real_distribution<float> generator(-0.1f, 0.2f); + for (bfloat16& value : literal->data<bfloat16>()) { + value = static_cast<bfloat16>(generator(*engine)); + } } template <typename IntT> -void PopulateWithRandomIntegralData(Literal* literal, - std::minstd_rand0* engine) { +void PopulateWithRandomIntegralData(Literal* literal, std::minstd_rand0* engine, + bool no_duplicates) { CHECK(engine != nullptr); CHECK_EQ(literal->shape().element_type(), primitive_util::NativeToPrimitiveType<IntT>()); - std::uniform_int_distribution<IntT> generator( - std::numeric_limits<IntT>::lowest(), std::numeric_limits<IntT>::max()); - TF_CHECK_OK(literal->Populate<IntT>( - [&](tensorflow::gtl::ArraySlice<int64> /*indices*/) { - return generator(*engine); - })); + if (no_duplicates && ShapeUtil::ElementsIn(literal->shape()) < + std::numeric_limits<IntT>::max()) { + std::iota(literal->data<IntT>().begin(), literal->data<IntT>().end(), 0); + 
std::shuffle(literal->data<IntT>().begin(), literal->data<IntT>().end(), + *engine); + } else { + std::uniform_int_distribution<IntT> generator( + std::numeric_limits<IntT>::lowest(), std::numeric_limits<IntT>::max()); + for (IntT& value : literal->data<IntT>()) { + value = generator(*engine); + } + } } // Similar to MakeFakeLiteral but takes a random number generator engine to -// enable reusing the engine across randomly generated literals. +// enable reusing the engine across randomly generated literals. 'no_duplicates' +// indicates that there should be no duplicate values in each generated +// array. This is uniqueness is best-effort only. Some types (half and bfloat16) +// are not supported and uniqueness cannot be guaranteed if the number of +// elements exceeds the number of different values supported by the type. StatusOr<std::unique_ptr<Literal>> MakeFakeLiteralInternal( - const Shape& shape, std::minstd_rand0* engine) { + const Shape& shape, std::minstd_rand0* engine, bool no_duplicates) { if (ShapeUtil::IsTuple(shape)) { std::vector<std::unique_ptr<Literal>> elements; for (const Shape& element_shape : shape.tuple_shapes()) { - TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> element, - MakeFakeLiteralInternal(element_shape, engine)); + TF_ASSIGN_OR_RETURN( + std::unique_ptr<Literal> element, + MakeFakeLiteralInternal(element_shape, engine, no_duplicates)); elements.push_back(std::move(element)); } return LiteralUtil::MakeTupleOwned(std::move(elements)); @@ -119,40 +133,52 @@ StatusOr<std::unique_ptr<Literal>> MakeFakeLiteralInternal( auto literal = MakeUnique<Literal>(shape); switch (shape.element_type()) { case BF16: - PopulateWithRandomFloatingPointData<bfloat16>(literal.get(), engine); + PopulateWithRandomFloatingPointData<bfloat16>(literal.get(), engine, + no_duplicates); break; case F16: - PopulateWithRandomFloatingPointData<half>(literal.get(), engine); + PopulateWithRandomFloatingPointData<half>(literal.get(), engine, + no_duplicates); break; case 
F32: - PopulateWithRandomFloatingPointData<float>(literal.get(), engine); + PopulateWithRandomFloatingPointData<float>(literal.get(), engine, + no_duplicates); break; case F64: - PopulateWithRandomFloatingPointData<double>(literal.get(), engine); + PopulateWithRandomFloatingPointData<double>(literal.get(), engine, + no_duplicates); break; case S8: - PopulateWithRandomIntegralData<int8>(literal.get(), engine); + PopulateWithRandomIntegralData<int8>(literal.get(), engine, + no_duplicates); break; case U8: - PopulateWithRandomIntegralData<uint8>(literal.get(), engine); + PopulateWithRandomIntegralData<uint8>(literal.get(), engine, + no_duplicates); break; case S16: - PopulateWithRandomIntegralData<int16>(literal.get(), engine); + PopulateWithRandomIntegralData<int16>(literal.get(), engine, + no_duplicates); break; case U16: - PopulateWithRandomIntegralData<uint16>(literal.get(), engine); + PopulateWithRandomIntegralData<uint16>(literal.get(), engine, + no_duplicates); break; case S32: - PopulateWithRandomIntegralData<int32>(literal.get(), engine); + PopulateWithRandomIntegralData<int32>(literal.get(), engine, + no_duplicates); break; case U32: - PopulateWithRandomIntegralData<uint32>(literal.get(), engine); + PopulateWithRandomIntegralData<uint32>(literal.get(), engine, + no_duplicates); break; case S64: - PopulateWithRandomIntegralData<int64>(literal.get(), engine); + PopulateWithRandomIntegralData<int64>(literal.get(), engine, + no_duplicates); break; case U64: - PopulateWithRandomIntegralData<uint64>(literal.get(), engine); + PopulateWithRandomIntegralData<uint64>(literal.get(), engine, + no_duplicates); break; case PRED: { std::uniform_int_distribution<int> generator(0, 1); @@ -250,6 +276,11 @@ std::vector<HloInstruction*> FindConstrainedUses( auto converted_uses = FindConstrainedUses(dataflow, *instruction); constrained_uses.insert(constrained_uses.end(), converted_uses.begin(), converted_uses.end()); + } else if (opcode == HloOpcode::kSort && + 
instruction->operand_count() == 2 && op_num == 0) { + // Operand 0 of sort is the array of keys used for key/value + // (two-operand) kSort instructions. + constrained_uses.push_back(instruction); } } } @@ -264,6 +295,7 @@ StatusOr<std::unique_ptr<Literal>> CreateLiteralForConstrainedUses( const tensorflow::gtl::ArraySlice<HloInstruction*> constrained_uses, const HloInstruction& param, std::minstd_rand0* engine) { std::vector<int64> index_space; + bool no_duplicates = false; bool needs_constant = false; ConstantType constant_type = ConstantType::kUnknown; for (HloInstruction* use : constrained_uses) { @@ -302,16 +334,22 @@ StatusOr<std::unique_ptr<Literal>> CreateLiteralForConstrainedUses( constant_type = GetInitValue(*use->scatter()); break; + case HloOpcode::kSort: + no_duplicates = true; + break; + default: return Unimplemented( "Constrained operand generation not implemented for %s.", use->ToString().c_str()); } } - if (!index_space.empty() && needs_constant) { - return Unimplemented( - "Conflicting operand generation constraints. Dynamically indexes a " - "shape and is the init value of a reduction."); + int constraint_count = 0; + constraint_count += no_duplicates ? 1 : 0; + constraint_count += !index_space.empty() ? 1 : 0; + constraint_count += needs_constant ? 1 : 0; + if (constraint_count > 1) { + return Unimplemented("Conflicting operand generation constraints."); } if (!index_space.empty()) { return MakeRandomIndex(index_space, engine); @@ -324,10 +362,11 @@ StatusOr<std::unique_ptr<Literal>> CreateLiteralForConstrainedUses( case ConstantType::kUnknown: // We want the identity element for the computation, but we don't really // know what it is - so any value we generate will be just as wrong. 
- return MakeFakeLiteralInternal(param.shape(), engine); + return MakeFakeLiteralInternal(param.shape(), engine, + /*no_duplicates=*/false); } } else { - return MakeFakeLiteralInternal(param.shape(), engine); + return MakeFakeLiteralInternal(param.shape(), engine, no_duplicates); } } @@ -345,18 +384,23 @@ StatusOr<std::unique_ptr<Literal>> MakeConstrainedArgument( StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape, bool pseudo_random) { auto engine = pseudo_random ? MakeUnique<std::minstd_rand0>() : nullptr; - return MakeFakeLiteralInternal(shape, engine.get()); + return MakeFakeLiteralInternal(shape, engine.get(), /*no_duplicates=*/false); } StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments( HloModule* const module, bool pseudo_random) { + auto engine = pseudo_random ? MakeUnique<std::minstd_rand0>() : nullptr; + return MakeFakeArguments(module, engine.get()); +} + +StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments( + HloModule* const module, std::minstd_rand0* engine) { TF_ASSIGN_OR_RETURN(auto dataflow, HloDataflowAnalysis::Run(*module)); const auto params = module->entry_computation()->parameter_instructions(); - auto engine = pseudo_random ? MakeUnique<std::minstd_rand0>() : nullptr; std::vector<std::unique_ptr<Literal>> arguments(params.size()); for (int i = 0; i < params.size(); ++i) { - arguments[i] = MakeConstrainedArgument(*dataflow, *params[i], engine.get()) - .ValueOrDie(); + arguments[i] = + MakeConstrainedArgument(*dataflow, *params[i], engine).ValueOrDie(); } return std::move(arguments); } |