diff options
Diffstat (limited to 'tensorflow/compiler/xla/client/lib/testing.cc')
-rw-r--r-- | tensorflow/compiler/xla/client/lib/testing.cc | 57 |
1 files changed, 56 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/client/lib/testing.cc b/tensorflow/compiler/xla/client/lib/testing.cc index d936bd870b..e6645e4941 100644 --- a/tensorflow/compiler/xla/client/lib/testing.cc +++ b/tensorflow/compiler/xla/client/lib/testing.cc @@ -21,7 +21,6 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/strings/strcat.h" @@ -49,6 +48,62 @@ std::unique_ptr<GlobalData> MakeFakeDataViaDeviceOrDie(const Shape& shape, } // namespace +StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape) { + if (ShapeUtil::IsTuple(shape)) { + std::vector<std::unique_ptr<Literal>> elements; + for (const Shape& element_shape : shape.tuple_shapes()) { + TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> element, + MakeFakeLiteral(element_shape)); + elements.push_back(std::move(element)); + } + return Literal::MakeTupleOwned(std::move(elements)); + } + std::unique_ptr<Literal> literal = Literal::CreateFromShape(shape); + std::minstd_rand0 engine; + switch (shape.element_type()) { + case F32: { + std::uniform_real_distribution<float> generator(0.0f, 1.0f); + TF_CHECK_OK(literal->Populate<float>( + [&](tensorflow::gtl::ArraySlice<int64> /*indices*/) { + return generator(engine); + })); + break; + } + case S32: { + std::uniform_int_distribution<int32> generator( + std::numeric_limits<int32>::lowest(), + std::numeric_limits<int32>::max()); + TF_CHECK_OK(literal->Populate<int32>( + [&](tensorflow::gtl::ArraySlice<int64> /*indices*/) { + return generator(engine); + })); + break; + } + case S64: { + std::uniform_int_distribution<int64> generator( + std::numeric_limits<int64>::lowest(), + std::numeric_limits<int64>::max()); + TF_CHECK_OK(literal->Populate<int64>( + [&](tensorflow::gtl::ArraySlice<int64> /*indices*/) { + return generator(engine); + })); + break; + } + case PRED: { + std::uniform_int_distribution<int> generator(0, 1); + TF_CHECK_OK(literal->Populate<bool>( + [&](tensorflow::gtl::ArraySlice<int64> /*indices*/) { + return generator(engine); + })); + break; + } + default: + return Unimplemented("Unsupported type for fake literal generation: %s", + ShapeUtil::HumanString(shape).c_str()); + } + return std::move(literal); +} + std::unique_ptr<GlobalData> MakeFakeDataOrDie(const Shape& shape, Client* client) { if (ShapeUtil::ByteSizeOf(shape) < (1LL << 30)) { |