diff options
Diffstat (limited to 'tensorflow/core/framework/tensor_testutil.h')
-rw-r--r-- | tensorflow/core/framework/tensor_testutil.h | 189 |
1 files changed, 189 insertions, 0 deletions
diff --git a/tensorflow/core/framework/tensor_testutil.h b/tensorflow/core/framework/tensor_testutil.h new file mode 100644 index 0000000000..53d6da0fb2 --- /dev/null +++ b/tensorflow/core/framework/tensor_testutil.h @@ -0,0 +1,189 @@ +#ifndef TENSORFLOW_FRAMEWORK_TENSOR_TESTUTIL_H_ +#define TENSORFLOW_FRAMEWORK_TENSOR_TESTUTIL_H_ + +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/public/tensor.h" +#include <gtest/gtest.h> + +namespace tensorflow { +namespace test { + +// Constructs a scalar tensor with 'val'. +template <typename T> +Tensor AsScalar(const T& val) { + Tensor ret(DataTypeToEnum<T>::value, {}); + ret.scalar<T>()() = val; + return ret; +} + +// Constructs a flat tensor with 'vals'. +template <typename T> +Tensor AsTensor(gtl::ArraySlice<T> vals) { + Tensor ret(DataTypeToEnum<T>::value, {static_cast<int64>(vals.size())}); + std::copy_n(vals.data(), vals.size(), ret.flat<T>().data()); + return ret; +} + +// Constructs a tensor of "shape" with values "vals". +template <typename T> +Tensor AsTensor(gtl::ArraySlice<T> vals, const TensorShape& shape) { + Tensor ret; + CHECK(ret.CopyFrom(AsTensor(vals), shape)); + return ret; +} + +// Fills in '*tensor' with 'vals'. E.g., +// Tensor x(&alloc, DT_FLOAT, TensorShape({2, 2})); +// test::FillValues<float>(&x, {11, 21, 21, 22}); +template <typename T> +void FillValues(Tensor* tensor, gtl::ArraySlice<T> vals) { + auto flat = tensor->flat<T>(); + CHECK_EQ(flat.size(), vals.size()); + if (flat.size() > 0) { + std::copy_n(vals.data(), vals.size(), flat.data()); + } +} + +// Fills in '*tensor' with a sequence of value of val, val+1, val+2, ... +// Tensor x(&alloc, DT_FLOAT, TensorShape({2, 2})); +// test::FillIota<float>(&x, 1.0); +template <typename T> +void FillIota(Tensor* tensor, const T& val) { + auto flat = tensor->flat<T>(); + std::iota(flat.data(), flat.data() + flat.size(), val); +} + +// Fills in '*tensor' with a sequence of value of fn(0), fn(1), ... +// Tensor x(&alloc, DT_FLOAT, TensorShape({2, 2})); +// test::FillFn<float>(&x, [](int i)->float { return i*i; }); +template <typename T> +void FillFn(Tensor* tensor, std::function<T(int)> fn) { + auto flat = tensor->flat<T>(); + for (int i = 0; i < flat.size(); ++i) flat(i) = fn(i); +} + +// Expects "x" and "y" are tensors of the same type, same shape, and +// identical values. +template <typename T> +void ExpectTensorEqual(const Tensor& x, const Tensor& y); + +// Expects "x" and "y" are tensors of the same type, same shape, and +// approxmiate equal values, each within "abs_err". +template <typename T> +void ExpectTensorNear(const Tensor& x, const Tensor& y, const T& abs_err); + +// Expects "x" and "y" are tensors of the same type (float or double), +// same shape and element-wise difference between x and y is no more +// than atol + rtol * abs(x). +void ExpectClose(const Tensor& x, const Tensor& y, double atol = 1e-6, + double rtol = 1e-6); + +// Implementation details. + +namespace internal { + +template <typename T> +struct is_floating_point_type { + static const bool value = std::is_same<T, float>::value || + std::is_same<T, double>::value || + std::is_same<T, std::complex<float> >::value || + std::is_same<T, std::complex<double> >::value; +}; + +template <typename T> +static void ExpectEqual(const T& a, const T& b) { + EXPECT_EQ(a, b); +} + +template <> +void ExpectEqual<float>(const float& a, const float& b) { + EXPECT_FLOAT_EQ(a, b); +} + +template <> +void ExpectEqual<double>(const double& a, const double& b) { + EXPECT_DOUBLE_EQ(a, b); +} + +template <> +void ExpectEqual<complex64>(const complex64& a, const complex64& b) { + EXPECT_FLOAT_EQ(a.real(), b.real()) << a << " vs. " << b; + EXPECT_FLOAT_EQ(a.imag(), b.imag()) << a << " vs. " << b; +} + +inline void AssertSameTypeDims(const Tensor& x, const Tensor& y) { + ASSERT_EQ(x.dtype(), y.dtype()); + ASSERT_TRUE(x.IsSameSize(y)) + << "x.shape [" << x.shape().DebugString() << "] vs " + << "y.shape [ " << y.shape().DebugString() << "]"; +} + +template <typename T, bool is_fp = is_floating_point_type<T>::value> +struct Expector; + +template <typename T> +struct Expector<T, false> { + static void Equal(const T& a, const T& b) { ExpectEqual(a, b); } + + static void Equal(const Tensor& x, const Tensor& y) { + ASSERT_EQ(x.dtype(), DataTypeToEnum<T>::v()); + AssertSameTypeDims(x, y); + auto a = x.flat<T>(); + auto b = y.flat<T>(); + for (int i = 0; i < a.size(); ++i) { + ExpectEqual(a(i), b(i)); + } + } +}; + +// Partial specialization for float and double. +template <typename T> +struct Expector<T, true> { + static void Equal(const T& a, const T& b) { ExpectEqual(a, b); } + + static void Equal(const Tensor& x, const Tensor& y) { + ASSERT_EQ(x.dtype(), DataTypeToEnum<T>::v()); + AssertSameTypeDims(x, y); + auto a = x.flat<T>(); + auto b = y.flat<T>(); + for (int i = 0; i < a.size(); ++i) { + ExpectEqual(a(i), b(i)); + } + } + + static void Near(const T& a, const T& b, const double abs_err) { + if (a != b) { // Takes care of inf. + EXPECT_LE(std::abs(a - b), abs_err) << "a = " << a << " b = " << b; + } + } + + static void Near(const Tensor& x, const Tensor& y, const double abs_err) { + ASSERT_EQ(x.dtype(), DataTypeToEnum<T>::v()); + AssertSameTypeDims(x, y); + auto a = x.flat<T>(); + auto b = y.flat<T>(); + for (int i = 0; i < a.size(); ++i) { + Near(a(i), b(i), abs_err); + } + } +}; + +} // namespace internal + +template <typename T> +void ExpectTensorEqual(const Tensor& x, const Tensor& y) { + internal::Expector<T>::Equal(x, y); +} + +template <typename T> +void ExpectTensorNear(const Tensor& x, const Tensor& y, const double abs_err) { + static_assert(internal::is_floating_point_type<T>::value, + "T is not a floating point types."); + internal::Expector<T>::Near(x, y, abs_err); +} + +} // namespace test +} // namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_TENSOR_TESTUTIL_H_ |