#ifndef TENSORFLOW_FRAMEWORK_TENSOR_TESTUTIL_H_ #define TENSORFLOW_FRAMEWORK_TENSOR_TESTUTIL_H_ #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/public/tensor.h" #include namespace tensorflow { namespace test { // Constructs a scalar tensor with 'val'. template Tensor AsScalar(const T& val) { Tensor ret(DataTypeToEnum::value, {}); ret.scalar()() = val; return ret; } // Constructs a flat tensor with 'vals'. template Tensor AsTensor(gtl::ArraySlice vals) { Tensor ret(DataTypeToEnum::value, {static_cast(vals.size())}); std::copy_n(vals.data(), vals.size(), ret.flat().data()); return ret; } // Constructs a tensor of "shape" with values "vals". template Tensor AsTensor(gtl::ArraySlice vals, const TensorShape& shape) { Tensor ret; CHECK(ret.CopyFrom(AsTensor(vals), shape)); return ret; } // Fills in '*tensor' with 'vals'. E.g., // Tensor x(&alloc, DT_FLOAT, TensorShape({2, 2})); // test::FillValues(&x, {11, 21, 21, 22}); template void FillValues(Tensor* tensor, gtl::ArraySlice vals) { auto flat = tensor->flat(); CHECK_EQ(flat.size(), vals.size()); if (flat.size() > 0) { std::copy_n(vals.data(), vals.size(), flat.data()); } } // Fills in '*tensor' with a sequence of value of val, val+1, val+2, ... // Tensor x(&alloc, DT_FLOAT, TensorShape({2, 2})); // test::FillIota(&x, 1.0); template void FillIota(Tensor* tensor, const T& val) { auto flat = tensor->flat(); std::iota(flat.data(), flat.data() + flat.size(), val); } // Fills in '*tensor' with a sequence of value of fn(0), fn(1), ... // Tensor x(&alloc, DT_FLOAT, TensorShape({2, 2})); // test::FillFn(&x, [](int i)->float { return i*i; }); template void FillFn(Tensor* tensor, std::function fn) { auto flat = tensor->flat(); for (int i = 0; i < flat.size(); ++i) flat(i) = fn(i); } // Expects "x" and "y" are tensors of the same type, same shape, and // identical values. template void ExpectTensorEqual(const Tensor& x, const Tensor& y); // Expects "x" and "y" are tensors of the same type, same shape, and // approxmiate equal values, each within "abs_err". template void ExpectTensorNear(const Tensor& x, const Tensor& y, const T& abs_err); // Expects "x" and "y" are tensors of the same type (float or double), // same shape and element-wise difference between x and y is no more // than atol + rtol * abs(x). void ExpectClose(const Tensor& x, const Tensor& y, double atol = 1e-6, double rtol = 1e-6); // Implementation details. namespace internal { template struct is_floating_point_type { static const bool value = std::is_same::value || std::is_same::value || std::is_same >::value || std::is_same >::value; }; template static void ExpectEqual(const T& a, const T& b) { EXPECT_EQ(a, b); } template <> void ExpectEqual(const float& a, const float& b) { EXPECT_FLOAT_EQ(a, b); } template <> void ExpectEqual(const double& a, const double& b) { EXPECT_DOUBLE_EQ(a, b); } template <> void ExpectEqual(const complex64& a, const complex64& b) { EXPECT_FLOAT_EQ(a.real(), b.real()) << a << " vs. " << b; EXPECT_FLOAT_EQ(a.imag(), b.imag()) << a << " vs. " << b; } inline void AssertSameTypeDims(const Tensor& x, const Tensor& y) { ASSERT_EQ(x.dtype(), y.dtype()); ASSERT_TRUE(x.IsSameSize(y)) << "x.shape [" << x.shape().DebugString() << "] vs " << "y.shape [ " << y.shape().DebugString() << "]"; } template ::value> struct Expector; template struct Expector { static void Equal(const T& a, const T& b) { ExpectEqual(a, b); } static void Equal(const Tensor& x, const Tensor& y) { ASSERT_EQ(x.dtype(), DataTypeToEnum::v()); AssertSameTypeDims(x, y); auto a = x.flat(); auto b = y.flat(); for (int i = 0; i < a.size(); ++i) { ExpectEqual(a(i), b(i)); } } }; // Partial specialization for float and double. template struct Expector { static void Equal(const T& a, const T& b) { ExpectEqual(a, b); } static void Equal(const Tensor& x, const Tensor& y) { ASSERT_EQ(x.dtype(), DataTypeToEnum::v()); AssertSameTypeDims(x, y); auto a = x.flat(); auto b = y.flat(); for (int i = 0; i < a.size(); ++i) { ExpectEqual(a(i), b(i)); } } static void Near(const T& a, const T& b, const double abs_err) { if (a != b) { // Takes care of inf. EXPECT_LE(std::abs(a - b), abs_err) << "a = " << a << " b = " << b; } } static void Near(const Tensor& x, const Tensor& y, const double abs_err) { ASSERT_EQ(x.dtype(), DataTypeToEnum::v()); AssertSameTypeDims(x, y); auto a = x.flat(); auto b = y.flat(); for (int i = 0; i < a.size(); ++i) { Near(a(i), b(i), abs_err); } } }; } // namespace internal template void ExpectTensorEqual(const Tensor& x, const Tensor& y) { internal::Expector::Equal(x, y); } template void ExpectTensorNear(const Tensor& x, const Tensor& y, const double abs_err) { static_assert(internal::is_floating_point_type::value, "T is not a floating point types."); internal::Expector::Near(x, y, abs_err); } } // namespace test } // namespace tensorflow #endif // TENSORFLOW_FRAMEWORK_TENSOR_TESTUTIL_H_