aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/tests/test_utils.cc
diff options
context:
space:
mode:
authorGravatar Bjarke Hammersholt Roune <broune@google.com>2018-05-29 21:10:43 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-29 21:13:10 -0700
commit3f2ba2edf62dc394cfcb4b2606f1638389aa92e2 (patch)
treeb40e55e676278c07bcb51f593dd4c787f02f0db7 /tensorflow/compiler/xla/tests/test_utils.cc
parenta364bc51405c0dbebe97c723fba8f877696205cc (diff)
Add features to HloRunner for running while leaving buffers on the device and add option to test_utils for generating more-boring data much faster.
PiperOrigin-RevId: 198502753
Diffstat (limited to 'tensorflow/compiler/xla/tests/test_utils.cc')
-rw-r--r--tensorflow/compiler/xla/tests/test_utils.cc35
1 files changed, 23 insertions, 12 deletions
diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc
index de18651388..dd7c541733 100644
--- a/tensorflow/compiler/xla/tests/test_utils.cc
+++ b/tensorflow/compiler/xla/tests/test_utils.cc
@@ -26,6 +26,7 @@ namespace {
template <typename FloatT, typename GeneratorT>
void PopulateWithRandomFloatingPointDataImpl(Literal* literal,
std::minstd_rand0* engine) {
+ CHECK(engine != nullptr);
CHECK_EQ(literal->shape().element_type(),
primitive_util::NativeToPrimitiveType<FloatT>());
// Create uniform numbers between 1 and 1.125 to avoid creating denormal
@@ -59,12 +60,14 @@ void PopulateWithRandomFloatingPointDataImpl(Literal* literal,
template <typename FloatT>
void PopulateWithRandomFloatingPointData(Literal* literal,
std::minstd_rand0* engine) {
+ CHECK(engine != nullptr);
PopulateWithRandomFloatingPointDataImpl<FloatT, FloatT>(literal, engine);
}
template <>
void PopulateWithRandomFloatingPointData<half>(Literal* literal,
std::minstd_rand0* engine) {
+ CHECK(engine != nullptr);
PopulateWithRandomFloatingPointDataImpl<half, float>(literal, engine);
}
@@ -73,6 +76,7 @@ void PopulateWithRandomFloatingPointData<half>(Literal* literal,
template <>
void PopulateWithRandomFloatingPointData<bfloat16>(Literal* literal,
std::minstd_rand0* engine) {
+ CHECK(engine != nullptr);
CHECK_EQ(literal->shape().element_type(), BF16);
std::uniform_real_distribution<float> generator(-0.9f, 1.0f);
TF_CHECK_OK(literal->Populate<bfloat16>(
@@ -84,6 +88,7 @@ void PopulateWithRandomFloatingPointData<bfloat16>(Literal* literal,
template <typename IntT>
void PopulateWithRandomIntegralData(Literal* literal,
std::minstd_rand0* engine) {
+ CHECK(engine != nullptr);
CHECK_EQ(literal->shape().element_type(),
primitive_util::NativeToPrimitiveType<IntT>());
std::uniform_int_distribution<IntT> generator(
@@ -107,6 +112,9 @@ StatusOr<std::unique_ptr<Literal>> MakeFakeLiteralInternal(
}
return Literal::MakeTupleOwned(std::move(elements));
}
+ if (engine == nullptr) {
+ return Literal::CreateFromShape(shape);
+ }
auto literal = MakeUnique<Literal>(shape);
switch (shape.element_type()) {
case BF16:
@@ -201,11 +209,13 @@ std::unique_ptr<Literal> MakeRandomNonwrappingSliceIndex(
std::minstd_rand0* engine) {
const int64 rank = ShapeUtil::Rank(input_shape);
std::vector<int32> start_indices(rank);
- for (int i = 0; i < rank; ++i) {
- const int32 upper_bound = ShapeUtil::GetDimension(input_shape, i) -
- ShapeUtil::GetDimension(slice_shape, i);
- std::uniform_int_distribution<int32> generator(0, upper_bound);
- start_indices[i] = generator(*engine);
+ if (engine != nullptr) {
+ for (int i = 0; i < rank; ++i) {
+ const int32 upper_bound = ShapeUtil::GetDimension(input_shape, i) -
+ ShapeUtil::GetDimension(slice_shape, i);
+ std::uniform_int_distribution<int32> generator(0, upper_bound);
+ start_indices[i] = generator(*engine);
+ }
}
return Literal::CreateR1<int32>(start_indices);
}
@@ -321,20 +331,21 @@ StatusOr<std::unique_ptr<Literal>> MakeConstrainedArgument(
} // namespace
-StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape) {
- std::minstd_rand0 engine;
- return MakeFakeLiteralInternal(shape, &engine);
+StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape,
+ bool pseudo_random) {
+ auto engine = pseudo_random ? MakeUnique<std::minstd_rand0>() : nullptr;
+ return MakeFakeLiteralInternal(shape, engine.get());
}
StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments(
- HloModule* const module) {
+ HloModule* const module, bool pseudo_random) {
TF_ASSIGN_OR_RETURN(auto dataflow, HloDataflowAnalysis::Run(*module));
const auto params = module->entry_computation()->parameter_instructions();
- std::minstd_rand0 engine;
+ auto engine = pseudo_random ? MakeUnique<std::minstd_rand0>() : nullptr;
std::vector<std::unique_ptr<Literal>> arguments(params.size());
for (int i = 0; i < params.size(); ++i) {
- TF_ASSIGN_OR_RETURN(
- arguments[i], MakeConstrainedArgument(*dataflow, *params[i], &engine));
+ TF_ASSIGN_OR_RETURN(arguments[i], MakeConstrainedArgument(
+ *dataflow, *params[i], engine.get()));
}
return std::move(arguments);
}