diff options
author | 2018-05-29 21:10:43 -0700 | |
---|---|---|
committer | 2018-05-29 21:13:10 -0700 | |
commit | 3f2ba2edf62dc394cfcb4b2606f1638389aa92e2 (patch) | |
tree | b40e55e676278c07bcb51f593dd4c787f02f0db7 /tensorflow/compiler/xla/tests/test_utils.cc | |
parent | a364bc51405c0dbebe97c723fba8f877696205cc (diff) |
Add features to HloRunner for running while leaving buffers on the device and add option to test_utils for generating more-boring data much faster.
PiperOrigin-RevId: 198502753
Diffstat (limited to 'tensorflow/compiler/xla/tests/test_utils.cc')
-rw-r--r-- | tensorflow/compiler/xla/tests/test_utils.cc | 35 |
1 files changed, 23 insertions, 12 deletions
diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc index de18651388..dd7c541733 100644 --- a/tensorflow/compiler/xla/tests/test_utils.cc +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -26,6 +26,7 @@ namespace { template <typename FloatT, typename GeneratorT> void PopulateWithRandomFloatingPointDataImpl(Literal* literal, std::minstd_rand0* engine) { + CHECK(engine != nullptr); CHECK_EQ(literal->shape().element_type(), primitive_util::NativeToPrimitiveType<FloatT>()); // Create uniform numbers between 1 and 1.125 to avoid creating denormal @@ -59,12 +60,14 @@ void PopulateWithRandomFloatingPointDataImpl(Literal* literal, template <typename FloatT> void PopulateWithRandomFloatingPointData(Literal* literal, std::minstd_rand0* engine) { + CHECK(engine != nullptr); PopulateWithRandomFloatingPointDataImpl<FloatT, FloatT>(literal, engine); } template <> void PopulateWithRandomFloatingPointData<half>(Literal* literal, std::minstd_rand0* engine) { + CHECK(engine != nullptr); PopulateWithRandomFloatingPointDataImpl<half, float>(literal, engine); } @@ -73,6 +76,7 @@ void PopulateWithRandomFloatingPointData<half>(Literal* literal, template <> void PopulateWithRandomFloatingPointData<bfloat16>(Literal* literal, std::minstd_rand0* engine) { + CHECK(engine != nullptr); CHECK_EQ(literal->shape().element_type(), BF16); std::uniform_real_distribution<float> generator(-0.9f, 1.0f); TF_CHECK_OK(literal->Populate<bfloat16>( @@ -84,6 +88,7 @@ void PopulateWithRandomFloatingPointData<bfloat16>(Literal* literal, template <typename IntT> void PopulateWithRandomIntegralData(Literal* literal, std::minstd_rand0* engine) { + CHECK(engine != nullptr); CHECK_EQ(literal->shape().element_type(), primitive_util::NativeToPrimitiveType<IntT>()); std::uniform_int_distribution<IntT> generator( @@ -107,6 +112,9 @@ StatusOr<std::unique_ptr<Literal>> MakeFakeLiteralInternal( } return Literal::MakeTupleOwned(std::move(elements)); } + if (engine == nullptr) { + return Literal::CreateFromShape(shape); + } auto literal = MakeUnique<Literal>(shape); switch (shape.element_type()) { case BF16: @@ -201,11 +209,13 @@ std::unique_ptr<Literal> MakeRandomNonwrappingSliceIndex( std::minstd_rand0* engine) { const int64 rank = ShapeUtil::Rank(input_shape); std::vector<int32> start_indices(rank); - for (int i = 0; i < rank; ++i) { - const int32 upper_bound = ShapeUtil::GetDimension(input_shape, i) - - ShapeUtil::GetDimension(slice_shape, i); - std::uniform_int_distribution<int32> generator(0, upper_bound); - start_indices[i] = generator(*engine); + if (engine != nullptr) { + for (int i = 0; i < rank; ++i) { + const int32 upper_bound = ShapeUtil::GetDimension(input_shape, i) - + ShapeUtil::GetDimension(slice_shape, i); + std::uniform_int_distribution<int32> generator(0, upper_bound); + start_indices[i] = generator(*engine); + } } return Literal::CreateR1<int32>(start_indices); } @@ -321,20 +331,21 @@ StatusOr<std::unique_ptr<Literal>> MakeConstrainedArgument( } // namespace -StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape) { - std::minstd_rand0 engine; - return MakeFakeLiteralInternal(shape, &engine); +StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape, + bool pseudo_random) { + auto engine = pseudo_random ? MakeUnique<std::minstd_rand0>() : nullptr; + return MakeFakeLiteralInternal(shape, engine.get()); } StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments( - HloModule* const module) { + HloModule* const module, bool pseudo_random) { TF_ASSIGN_OR_RETURN(auto dataflow, HloDataflowAnalysis::Run(*module)); const auto params = module->entry_computation()->parameter_instructions(); - std::minstd_rand0 engine; + auto engine = pseudo_random ? MakeUnique<std::minstd_rand0>() : nullptr; std::vector<std::unique_ptr<Literal>> arguments(params.size()); for (int i = 0; i < params.size(); ++i) { - TF_ASSIGN_OR_RETURN( - arguments[i], MakeConstrainedArgument(*dataflow, *params[i], &engine)); + TF_ASSIGN_OR_RETURN(arguments[i], MakeConstrainedArgument( + *dataflow, *params[i], engine.get())); } return std::move(arguments); } |