diff options
author | Mark Heffernan <meheff@google.com> | 2017-11-08 15:35:27 -0800 |
---|---|---|
committer | Andrew Selle <aselle@andyselle.com> | 2017-11-10 16:14:37 -0800 |
commit | 64d2636e2946772d4b1531ec91b389110a2787b7 (patch) | |
tree | 71fdb3d248e9b1f24ab649e1d21eac490e6782bd /tensorflow/compiler/xla/tests/test_utils.cc | |
parent | 2ba34173fad0d5b7d986baeb8171bdc6afdcd7bb (diff) |
Move MakeFakeLiteral from client/lib/testing.h to tests/test_utils.h. Also remove superfluous literal creation methods in that file, and replace them with the existing ones in the Literal class.
Also, optionally print layout in Literal::ToString.
PiperOrigin-RevId: 175076277
Diffstat (limited to 'tensorflow/compiler/xla/tests/test_utils.cc')
-rw-r--r-- | tensorflow/compiler/xla/tests/test_utils.cc | 120 |
1 files changed, 120 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc new file mode 100644 index 0000000000..cdd3d66bbb --- /dev/null +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -0,0 +1,120 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/tests/test_utils.h" + +#include "tensorflow/compiler/xla/primitive_util.h" + +namespace xla { + +namespace { + +template <typename FloatT> +void PopulateWithRandomFloatingPointData(Literal* literal) { + CHECK_EQ(literal->shape().element_type(), + primitive_util::NativeToPrimitiveType<FloatT>()); + std::minstd_rand0 engine; + std::uniform_real_distribution<FloatT> generator(0.0f, 1.0f); + TF_CHECK_OK(literal->Populate<FloatT>( + [&](tensorflow::gtl::ArraySlice<int64> /*indices*/) { + return generator(engine); + })); +} + +template <typename IntT> +void PopulateWithRandomIntegralData(Literal* literal) { + CHECK_EQ(literal->shape().element_type(), + primitive_util::NativeToPrimitiveType<IntT>()); + std::minstd_rand0 engine; + std::uniform_int_distribution<IntT> generator( + std::numeric_limits<IntT>::lowest(), std::numeric_limits<IntT>::max()); + TF_CHECK_OK(literal->Populate<IntT>( + [&](tensorflow::gtl::ArraySlice<int64> /*indices*/) { + return generator(engine); + })); +} + +} // namespace + +StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape) { + if (ShapeUtil::IsTuple(shape)) { + std::vector<std::unique_ptr<Literal>> elements; + for (const Shape& element_shape : shape.tuple_shapes()) { + TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> element, + MakeFakeLiteral(element_shape)); + elements.push_back(std::move(element)); + } + return Literal::MakeTupleOwned(std::move(elements)); + } + std::unique_ptr<Literal> literal = Literal::CreateFromShape(shape); + switch (shape.element_type()) { + case F32: + PopulateWithRandomFloatingPointData<float>(literal.get()); + break; + case F64: + PopulateWithRandomFloatingPointData<double>(literal.get()); + break; + case S8: + PopulateWithRandomIntegralData<int8>(literal.get()); + break; + case U8: + PopulateWithRandomIntegralData<uint8>(literal.get()); + break; + case S16: + PopulateWithRandomIntegralData<int16>(literal.get()); + break; + case U16: + PopulateWithRandomIntegralData<uint16>(literal.get()); + break; + case S32: + PopulateWithRandomIntegralData<int32>(literal.get()); + break; + case U32: + PopulateWithRandomIntegralData<uint32>(literal.get()); + break; + case S64: + PopulateWithRandomIntegralData<int64>(literal.get()); + break; + case U64: + PopulateWithRandomIntegralData<uint64>(literal.get()); + break; + case PRED: { + std::uniform_int_distribution<int> generator(0, 1); + std::minstd_rand0 engine; + TF_CHECK_OK(literal->Populate<bool>( + [&](tensorflow::gtl::ArraySlice<int64> /*indices*/) { + return generator(engine); + })); + break; + } + default: + return Unimplemented("Unsupported type for fake literal generation: %s", + ShapeUtil::HumanString(shape).c_str()); + } + return std::move(literal); +} + +StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments( + const HloModule& module) { + std::vector<std::unique_ptr<Literal>> arguments; + for (const ShapeLayout& shape_layout : + module.config().entry_computation_layout().parameter_layouts()) { + TF_ASSIGN_OR_RETURN(auto literal, MakeFakeLiteral(shape_layout.shape())); + arguments.push_back(std::move(literal)); + } + return std::move(arguments); +} + +} // namespace xla |