aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/client/lib/testing.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/client/lib/testing.cc')
-rw-r--r--tensorflow/compiler/xla/client/lib/testing.cc57
1 files changed, 56 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/client/lib/testing.cc b/tensorflow/compiler/xla/client/lib/testing.cc
index d936bd870b..e6645e4941 100644
--- a/tensorflow/compiler/xla/client/lib/testing.cc
+++ b/tensorflow/compiler/xla/client/lib/testing.cc
@@ -21,7 +21,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
-#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/strings/strcat.h"
@@ -49,6 +48,62 @@ std::unique_ptr<GlobalData> MakeFakeDataViaDeviceOrDie(const Shape& shape,
} // namespace
+StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape) {
+ if (ShapeUtil::IsTuple(shape)) {
+ std::vector<std::unique_ptr<Literal>> elements;
+ for (const Shape& element_shape : shape.tuple_shapes()) {
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> element,
+ MakeFakeLiteral(element_shape));
+ elements.push_back(std::move(element));
+ }
+ return Literal::MakeTupleOwned(std::move(elements));
+ }
+ std::unique_ptr<Literal> literal = Literal::CreateFromShape(shape);
+ std::minstd_rand0 engine;
+ switch (shape.element_type()) {
+ case F32: {
+ std::uniform_real_distribution<float> generator(0.0f, 1.0f);
+ TF_CHECK_OK(literal->Populate<float>(
+ [&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
+ return generator(engine);
+ }));
+ break;
+ }
+ case S32: {
+ std::uniform_int_distribution<int32> generator(
+ std::numeric_limits<int32>::lowest(),
+ std::numeric_limits<int32>::max());
+ TF_CHECK_OK(literal->Populate<int32>(
+ [&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
+ return generator(engine);
+ }));
+ break;
+ }
+ case S64: {
+ std::uniform_int_distribution<int64> generator(
+ std::numeric_limits<int64>::lowest(),
+ std::numeric_limits<int64>::max());
+ TF_CHECK_OK(literal->Populate<int64>(
+ [&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
+ return generator(engine);
+ }));
+ break;
+ }
+ case PRED: {
+ std::uniform_int_distribution<int> generator(0, 1);
+ TF_CHECK_OK(literal->Populate<bool>(
+ [&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
+ return generator(engine);
+ }));
+ break;
+ }
+ default:
+ return Unimplemented("Unsupported type for fake literal generation: %s",
+ ShapeUtil::HumanString(shape).c_str());
+ }
+ return std::move(literal);
+}
+
std::unique_ptr<GlobalData> MakeFakeDataOrDie(const Shape& shape,
Client* client) {
if (ShapeUtil::ByteSizeOf(shape) < (1LL << 30)) {