aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/tests/test_utils.cc
diff options
context:
space:
mode:
authorGravatar Bjarke Hammersholt Roune <broune@google.com>2018-01-05 16:56:14 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-05 17:00:14 -0800
commit6ed75e60c192c487a955bad155d0bf478135e7a5 (patch)
tree4e7e0a4add0e05b2c9d1e7207fbb610a34921a9f /tensorflow/compiler/xla/tests/test_utils.cc
parent9f462b27e6df975b1e402c9926d740dfdb9f977a (diff)
* Make fake argument generation see through ReducePrecision and Convert ops when
determining constraints. * Generate both positive and negative numbers as a work-around for the CPU reduce implementation having poor numerical stability. * You can now VLOG hlo_runner to see the results. PiperOrigin-RevId: 180993147
Diffstat (limited to 'tensorflow/compiler/xla/tests/test_utils.cc')
-rw-r--r--tensorflow/compiler/xla/tests/test_utils.cc9
1 files changed, 7 insertions, 2 deletions
diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc
index e8a05cf2b8..bb215be8af 100644
--- a/tensorflow/compiler/xla/tests/test_utils.cc
+++ b/tensorflow/compiler/xla/tests/test_utils.cc
@@ -29,7 +29,7 @@ void PopulateWithRandomFloatingPointData(Literal* literal) {
CHECK_EQ(literal->shape().element_type(),
primitive_util::NativeToPrimitiveType<FloatT>());
std::minstd_rand0 engine;
- std::uniform_real_distribution<FloatT> generator(0.0f, 1.0f);
+ std::uniform_real_distribution<FloatT> generator(-0.9f, 1.0f);
TF_CHECK_OK(literal->Populate<FloatT>(
[&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
return generator(engine);
@@ -42,7 +42,7 @@ template <>
void PopulateWithRandomFloatingPointData<bfloat16>(Literal* literal) {
CHECK_EQ(literal->shape().element_type(), BF16);
std::minstd_rand0 engine;
- std::uniform_real_distribution<float> generator(0.0f, 1.0f);
+ std::uniform_real_distribution<float> generator(-0.9f, 1.0f);
TF_CHECK_OK(literal->Populate<bfloat16>(
[&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
return static_cast<bfloat16>(generator(engine));
@@ -126,6 +126,11 @@ std::vector<HloInstruction*> FindConstrainedUses(
fused_uses.end());
} else if (NeedsZeroInitValue(use)) {
constrained_uses.push_back(instruction);
+ } else if (opcode == HloOpcode::kConvert ||
+ opcode == HloOpcode::kReducePrecision) {
+ auto converted_uses = FindConstrainedUses(dataflow, *instruction);
+ constrained_uses.insert(constrained_uses.end(), converted_uses.begin(),
+ converted_uses.end());
}
}
}