author     Mark Heffernan <meheff@google.com>              2018-08-16 18:14:50 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org> 2018-08-16 18:18:41 -0700
commit     86d2a1ef43f883bafe9276bcba14eac2bb0cb637 (patch)
tree       525295bc21bf47d9726eb4ce55e655d6385a6c49 /tensorflow/compiler/xla/tests/test_utils.cc
parent     bbb3ae0790f042d2bc5f6cce434c75c698d4a978 (diff)
Various improvements to MakeFakeArguments.
Add an overload which takes a random number generator to enable generation of different random values in sequential calls to MakeFakeArguments. Add a mechanism for generating arrays of unique values (no duplicates) for the key inputs to key/value kSorts. Remove some sorcery in generating float arrays and replace it with a uniform distribution; the underlying reason for using that strange distribution no longer exists.

PiperOrigin-RevId: 209083904
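The new overload takes the random engine explicitly, so repeated calls can share one engine and produce different argument sets. A minimal usage sketch based on the declarations in this change; the helper function and the pre-built HloModule are hypothetical, only MakeFakeArguments and std::minstd_rand0 come from the patch:

#include <random>

#include "tensorflow/compiler/xla/tests/test_utils.h"

// Hypothetical helper: builds two argument sets for the same module. Because
// both calls share `engine`, the second set differs from the first.
void MakeTwoArgumentSets(xla::HloModule* module) {
  std::minstd_rand0 engine(/*seed=*/42);
  auto args1 = xla::MakeFakeArguments(module, &engine).ValueOrDie();
  auto args2 = xla::MakeFakeArguments(module, &engine).ValueOrDie();
  // args1 and args2 are std::vector<std::unique_ptr<xla::Literal>>.
}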
Diffstat (limited to 'tensorflow/compiler/xla/tests/test_utils.cc')
-rw-r--r--  tensorflow/compiler/xla/tests/test_utils.cc  192
1 file changed, 118 insertions, 74 deletions
diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc
index faeec657b6..f05421f8e1 100644
--- a/tensorflow/compiler/xla/tests/test_utils.cc
+++ b/tensorflow/compiler/xla/tests/test_utils.cc
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include <cmath>
+
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
@@ -26,89 +28,101 @@ namespace {
template <typename FloatT, typename GeneratorT>
void PopulateWithRandomFloatingPointDataImpl(Literal* literal,
- std::minstd_rand0* engine) {
+ std::minstd_rand0* engine,
+ bool no_duplicates) {
CHECK(engine != nullptr);
CHECK_EQ(literal->shape().element_type(),
primitive_util::NativeToPrimitiveType<FloatT>());
- // Create uniform numbers between 1 and 1.125 to avoid creating denormal
- // numbers.
- std::uniform_real_distribution<GeneratorT> generator(1.0f, 1.125f);
- const bool should_index_bias = ShapeUtil::ElementsIn(literal->shape()) > 1000;
- TF_CHECK_OK(literal->Populate<FloatT>(
- [&](tensorflow::gtl::ArraySlice<int64> indices) {
- // Generate a random uniform number from -0.0625 and 0.0625 and bias it
- // with a position dependent number with mean 0.037109375. These number
- // should allow for long chains of accumulation without being too close
- // to zero or too large to accumulate all numbers accurately. Only do
- // this for large literals where the number of elements is much greater
- // than 47 otherwise only negative values are produced.
- //
- // The value is positionally biased using a product of the indices. Add
- // one to each index value to avoid collapsing to zero if any of the
- // indices are zero.
- int64 index_product = 1;
- for (int64 i : indices) {
- index_product *= (1 + i);
- }
- const int64 negative_bias = should_index_bias ? 47 : 0;
- FloatT index_bias =
- static_cast<FloatT>(index_product % 113 - negative_bias) /
- static_cast<FloatT>(256.0f);
- return static_cast<FloatT>(generator(*engine) - 1.0625f) + index_bias;
- }));
+ if (no_duplicates) {
+ // Duplicates may be generated if the number of elements in the literal
+ // exceeds the number of positive values supported by the type.
+ FloatT next_value = std::numeric_limits<FloatT>::min();
+ for (FloatT& value : literal->data<FloatT>()) {
+ value = next_value;
+ next_value =
+ std::nextafter(next_value, std::numeric_limits<FloatT>::max());
+ }
+ std::shuffle(literal->data<FloatT>().begin(), literal->data<FloatT>().end(),
+ *engine);
+ } else {
+ std::uniform_real_distribution<GeneratorT> generator(-0.1f, 0.2f);
+ for (FloatT& value : literal->data<FloatT>()) {
+ value = static_cast<FloatT>(generator(*engine));
+ }
+ }
}
template <typename FloatT>
void PopulateWithRandomFloatingPointData(Literal* literal,
- std::minstd_rand0* engine) {
+ std::minstd_rand0* engine,
+ bool no_duplicates) {
CHECK(engine != nullptr);
- PopulateWithRandomFloatingPointDataImpl<FloatT, FloatT>(literal, engine);
+ PopulateWithRandomFloatingPointDataImpl<FloatT, FloatT>(literal, engine,
+ no_duplicates);
}
template <>
void PopulateWithRandomFloatingPointData<half>(Literal* literal,
- std::minstd_rand0* engine) {
+ std::minstd_rand0* engine,
+ bool no_duplicates) {
+ // no_duplicates is ignored for half types. Unique values can only be
+ // generated for arrays with fewer than ~2**16 elements and no_duplicates is
+ // best-effort anyway.
CHECK(engine != nullptr);
- PopulateWithRandomFloatingPointDataImpl<half, float>(literal, engine);
+ std::uniform_real_distribution<float> generator(-0.1f, 0.2f);
+ for (half& value : literal->data<half>()) {
+ value = static_cast<half>(generator(*engine));
+ }
}
-// The standard library does not have a case for bfloat16, unsurprisingly, so we
-// handle that one specially.
template <>
void PopulateWithRandomFloatingPointData<bfloat16>(Literal* literal,
- std::minstd_rand0* engine) {
+ std::minstd_rand0* engine,
+ bool no_duplicates) {
+ // no_duplicates is ignored for bfloat types. Unique values can only be
+ // generated for arrays with fewer than ~2**16 elements and no_duplicates is
+ // best-effort anyway.
CHECK(engine != nullptr);
- CHECK_EQ(literal->shape().element_type(), BF16);
- std::uniform_real_distribution<float> generator(-0.9f, 1.0f);
- TF_CHECK_OK(literal->Populate<bfloat16>(
- [&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
- return static_cast<bfloat16>(generator(*engine));
- }));
+ std::uniform_real_distribution<float> generator(-0.1f, 0.2f);
+ for (bfloat16& value : literal->data<bfloat16>()) {
+ value = static_cast<bfloat16>(generator(*engine));
+ }
}
template <typename IntT>
-void PopulateWithRandomIntegralData(Literal* literal,
- std::minstd_rand0* engine) {
+void PopulateWithRandomIntegralData(Literal* literal, std::minstd_rand0* engine,
+ bool no_duplicates) {
CHECK(engine != nullptr);
CHECK_EQ(literal->shape().element_type(),
primitive_util::NativeToPrimitiveType<IntT>());
- std::uniform_int_distribution<IntT> generator(
- std::numeric_limits<IntT>::lowest(), std::numeric_limits<IntT>::max());
- TF_CHECK_OK(literal->Populate<IntT>(
- [&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
- return generator(*engine);
- }));
+ if (no_duplicates && ShapeUtil::ElementsIn(literal->shape()) <
+ std::numeric_limits<IntT>::max()) {
+ std::iota(literal->data<IntT>().begin(), literal->data<IntT>().end(), 0);
+ std::shuffle(literal->data<IntT>().begin(), literal->data<IntT>().end(),
+ *engine);
+ } else {
+ std::uniform_int_distribution<IntT> generator(
+ std::numeric_limits<IntT>::lowest(), std::numeric_limits<IntT>::max());
+ for (IntT& value : literal->data<IntT>()) {
+ value = generator(*engine);
+ }
+ }
}
// Similar to MakeFakeLiteral but takes a random number generator engine to
-// enable reusing the engine across randomly generated literals.
+// enable reusing the engine across randomly generated literals. 'no_duplicates'
+// indicates that there should be no duplicate values in each generated
+// array. This uniqueness is best-effort only. Some types (half and bfloat16)
+// are not supported and uniqueness cannot be guaranteed if the number of
+// elements exceeds the number of different values supported by the type.
StatusOr<std::unique_ptr<Literal>> MakeFakeLiteralInternal(
- const Shape& shape, std::minstd_rand0* engine) {
+ const Shape& shape, std::minstd_rand0* engine, bool no_duplicates) {
if (ShapeUtil::IsTuple(shape)) {
std::vector<std::unique_ptr<Literal>> elements;
for (const Shape& element_shape : shape.tuple_shapes()) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> element,
- MakeFakeLiteralInternal(element_shape, engine));
+ TF_ASSIGN_OR_RETURN(
+ std::unique_ptr<Literal> element,
+ MakeFakeLiteralInternal(element_shape, engine, no_duplicates));
elements.push_back(std::move(element));
}
return LiteralUtil::MakeTupleOwned(std::move(elements));
@@ -119,40 +133,52 @@ StatusOr<std::unique_ptr<Literal>> MakeFakeLiteralInternal(
auto literal = MakeUnique<Literal>(shape);
switch (shape.element_type()) {
case BF16:
- PopulateWithRandomFloatingPointData<bfloat16>(literal.get(), engine);
+ PopulateWithRandomFloatingPointData<bfloat16>(literal.get(), engine,
+ no_duplicates);
break;
case F16:
- PopulateWithRandomFloatingPointData<half>(literal.get(), engine);
+ PopulateWithRandomFloatingPointData<half>(literal.get(), engine,
+ no_duplicates);
break;
case F32:
- PopulateWithRandomFloatingPointData<float>(literal.get(), engine);
+ PopulateWithRandomFloatingPointData<float>(literal.get(), engine,
+ no_duplicates);
break;
case F64:
- PopulateWithRandomFloatingPointData<double>(literal.get(), engine);
+ PopulateWithRandomFloatingPointData<double>(literal.get(), engine,
+ no_duplicates);
break;
case S8:
- PopulateWithRandomIntegralData<int8>(literal.get(), engine);
+ PopulateWithRandomIntegralData<int8>(literal.get(), engine,
+ no_duplicates);
break;
case U8:
- PopulateWithRandomIntegralData<uint8>(literal.get(), engine);
+ PopulateWithRandomIntegralData<uint8>(literal.get(), engine,
+ no_duplicates);
break;
case S16:
- PopulateWithRandomIntegralData<int16>(literal.get(), engine);
+ PopulateWithRandomIntegralData<int16>(literal.get(), engine,
+ no_duplicates);
break;
case U16:
- PopulateWithRandomIntegralData<uint16>(literal.get(), engine);
+ PopulateWithRandomIntegralData<uint16>(literal.get(), engine,
+ no_duplicates);
break;
case S32:
- PopulateWithRandomIntegralData<int32>(literal.get(), engine);
+ PopulateWithRandomIntegralData<int32>(literal.get(), engine,
+ no_duplicates);
break;
case U32:
- PopulateWithRandomIntegralData<uint32>(literal.get(), engine);
+ PopulateWithRandomIntegralData<uint32>(literal.get(), engine,
+ no_duplicates);
break;
case S64:
- PopulateWithRandomIntegralData<int64>(literal.get(), engine);
+ PopulateWithRandomIntegralData<int64>(literal.get(), engine,
+ no_duplicates);
break;
case U64:
- PopulateWithRandomIntegralData<uint64>(literal.get(), engine);
+ PopulateWithRandomIntegralData<uint64>(literal.get(), engine,
+ no_duplicates);
break;
case PRED: {
std::uniform_int_distribution<int> generator(0, 1);
@@ -250,6 +276,11 @@ std::vector<HloInstruction*> FindConstrainedUses(
auto converted_uses = FindConstrainedUses(dataflow, *instruction);
constrained_uses.insert(constrained_uses.end(), converted_uses.begin(),
converted_uses.end());
+ } else if (opcode == HloOpcode::kSort &&
+ instruction->operand_count() == 2 && op_num == 0) {
+ // Operand 0 of sort is the array of keys used for key/value
+ // (two-operand) kSort instructions.
+ constrained_uses.push_back(instruction);
}
}
}
@@ -264,6 +295,7 @@ StatusOr<std::unique_ptr<Literal>> CreateLiteralForConstrainedUses(
const tensorflow::gtl::ArraySlice<HloInstruction*> constrained_uses,
const HloInstruction& param, std::minstd_rand0* engine) {
std::vector<int64> index_space;
+ bool no_duplicates = false;
bool needs_constant = false;
ConstantType constant_type = ConstantType::kUnknown;
for (HloInstruction* use : constrained_uses) {
@@ -302,16 +334,22 @@ StatusOr<std::unique_ptr<Literal>> CreateLiteralForConstrainedUses(
constant_type = GetInitValue(*use->scatter());
break;
+ case HloOpcode::kSort:
+ no_duplicates = true;
+ break;
+
default:
return Unimplemented(
"Constrained operand generation not implemented for %s.",
use->ToString().c_str());
}
}
- if (!index_space.empty() && needs_constant) {
- return Unimplemented(
- "Conflicting operand generation constraints. Dynamically indexes a "
- "shape and is the init value of a reduction.");
+ int constraint_count = 0;
+ constraint_count += no_duplicates ? 1 : 0;
+ constraint_count += !index_space.empty() ? 1 : 0;
+ constraint_count += needs_constant ? 1 : 0;
+ if (constraint_count > 1) {
+ return Unimplemented("Conflicting operand generation constraints.");
}
if (!index_space.empty()) {
return MakeRandomIndex(index_space, engine);
@@ -324,10 +362,11 @@ StatusOr<std::unique_ptr<Literal>> CreateLiteralForConstrainedUses(
case ConstantType::kUnknown:
// We want the identity element for the computation, but we don't really
// know what it is - so any value we generate will be just as wrong.
- return MakeFakeLiteralInternal(param.shape(), engine);
+ return MakeFakeLiteralInternal(param.shape(), engine,
+ /*no_duplicates=*/false);
}
} else {
- return MakeFakeLiteralInternal(param.shape(), engine);
+ return MakeFakeLiteralInternal(param.shape(), engine, no_duplicates);
}
}
@@ -345,18 +384,23 @@ StatusOr<std::unique_ptr<Literal>> MakeConstrainedArgument(
StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape,
bool pseudo_random) {
auto engine = pseudo_random ? MakeUnique<std::minstd_rand0>() : nullptr;
- return MakeFakeLiteralInternal(shape, engine.get());
+ return MakeFakeLiteralInternal(shape, engine.get(), /*no_duplicates=*/false);
}
StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments(
HloModule* const module, bool pseudo_random) {
+ auto engine = pseudo_random ? MakeUnique<std::minstd_rand0>() : nullptr;
+ return MakeFakeArguments(module, engine.get());
+}
+
+StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments(
+ HloModule* const module, std::minstd_rand0* engine) {
TF_ASSIGN_OR_RETURN(auto dataflow, HloDataflowAnalysis::Run(*module));
const auto params = module->entry_computation()->parameter_instructions();
- auto engine = pseudo_random ? MakeUnique<std::minstd_rand0>() : nullptr;
std::vector<std::unique_ptr<Literal>> arguments(params.size());
for (int i = 0; i < params.size(); ++i) {
- arguments[i] = MakeConstrainedArgument(*dataflow, *params[i], engine.get())
- .ValueOrDie();
+ arguments[i] =
+ MakeConstrainedArgument(*dataflow, *params[i], engine).ValueOrDie();
}
return std::move(arguments);
}
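For reference, the duplicate-free fill used above in PopulateWithRandomIntegralData and PopulateWithRandomFloatingPointDataImpl boils down to enumerating strictly increasing values and then shuffling them. A standalone sketch of that technique using plain standard-library containers rather than XLA literals (array sizes and seed are arbitrary):

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>
#include <numeric>
#include <random>
#include <vector>

int main() {
  std::minstd_rand0 engine(0);

  // Integers: fill with 0, 1, 2, ... and shuffle, as in
  // PopulateWithRandomIntegralData when no_duplicates is set.
  std::vector<int32_t> int_keys(16);
  std::iota(int_keys.begin(), int_keys.end(), 0);
  std::shuffle(int_keys.begin(), int_keys.end(), engine);

  // Floats: walk nextafter() upward from the smallest positive normal value
  // and shuffle, as in PopulateWithRandomFloatingPointDataImpl.
  std::vector<float> float_keys(16);
  float next_value = std::numeric_limits<float>::min();
  for (float& value : float_keys) {
    value = next_value;
    next_value = std::nextafter(next_value, std::numeric_limits<float>::max());
  }
  std::shuffle(float_keys.begin(), float_keys.end(), engine);
  return 0;
}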