diff options
author | Bjarke Hammersholt Roune <broune@google.com> | 2018-01-05 16:56:14 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-01-05 17:00:14 -0800 |
commit | 6ed75e60c192c487a955bad155d0bf478135e7a5 (patch) | |
tree | 4e7e0a4add0e05b2c9d1e7207fbb610a34921a9f /tensorflow/compiler/xla/tests/test_utils.cc | |
parent | 9f462b27e6df975b1e402c9926d740dfdb9f977a (diff) |
* Make fake argument generation see through ReducePrecision and Convert ops when
determining constraints.
* Generate both positive and negative numbers as a work-around for the CPU reduce
implementation having poor numerical stability.
* You can now VLOG hlo_runner to see the results.
PiperOrigin-RevId: 180993147
Diffstat (limited to 'tensorflow/compiler/xla/tests/test_utils.cc')
-rw-r--r-- | tensorflow/compiler/xla/tests/test_utils.cc | 9 |
1 files changed, 7 insertions, 2 deletions
diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc index e8a05cf2b8..bb215be8af 100644 --- a/tensorflow/compiler/xla/tests/test_utils.cc +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -29,7 +29,7 @@ void PopulateWithRandomFloatingPointData(Literal* literal) { CHECK_EQ(literal->shape().element_type(), primitive_util::NativeToPrimitiveType<FloatT>()); std::minstd_rand0 engine; - std::uniform_real_distribution<FloatT> generator(0.0f, 1.0f); + std::uniform_real_distribution<FloatT> generator(-0.9f, 1.0f); TF_CHECK_OK(literal->Populate<FloatT>( [&](tensorflow::gtl::ArraySlice<int64> /*indices*/) { return generator(engine); @@ -42,7 +42,7 @@ template <> void PopulateWithRandomFloatingPointData<bfloat16>(Literal* literal) { CHECK_EQ(literal->shape().element_type(), BF16); std::minstd_rand0 engine; - std::uniform_real_distribution<float> generator(0.0f, 1.0f); + std::uniform_real_distribution<float> generator(-0.9f, 1.0f); TF_CHECK_OK(literal->Populate<bfloat16>( [&](tensorflow::gtl::ArraySlice<int64> /*indices*/) { return static_cast<bfloat16>(generator(engine)); @@ -126,6 +126,11 @@ std::vector<HloInstruction*> FindConstrainedUses( fused_uses.end()); } else if (NeedsZeroInitValue(use)) { constrained_uses.push_back(instruction); + } else if (opcode == HloOpcode::kConvert || + opcode == HloOpcode::kReducePrecision) { + auto converted_uses = FindConstrainedUses(dataflow, *instruction); + constrained_uses.insert(constrained_uses.end(), converted_uses.begin(), + converted_uses.end()); } } } |