aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/tensor_testutil.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/framework/tensor_testutil.h')
-rw-r--r--tensorflow/core/framework/tensor_testutil.h189
1 files changed, 189 insertions, 0 deletions
diff --git a/tensorflow/core/framework/tensor_testutil.h b/tensorflow/core/framework/tensor_testutil.h
new file mode 100644
index 0000000000..53d6da0fb2
--- /dev/null
+++ b/tensorflow/core/framework/tensor_testutil.h
@@ -0,0 +1,189 @@
+#ifndef TENSORFLOW_FRAMEWORK_TENSOR_TESTUTIL_H_
+#define TENSORFLOW_FRAMEWORK_TENSOR_TESTUTIL_H_
+
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/public/tensor.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace test {
+
+// Constructs a scalar tensor with 'val'.
+template <typename T>
+Tensor AsScalar(const T& val) {
+ Tensor ret(DataTypeToEnum<T>::value, {});
+ ret.scalar<T>()() = val;
+ return ret;
+}
+
+// Constructs a flat tensor with 'vals'.
+template <typename T>
+Tensor AsTensor(gtl::ArraySlice<T> vals) {
+ Tensor ret(DataTypeToEnum<T>::value, {static_cast<int64>(vals.size())});
+ std::copy_n(vals.data(), vals.size(), ret.flat<T>().data());
+ return ret;
+}
+
+// Constructs a tensor of "shape" with values "vals".
+template <typename T>
+Tensor AsTensor(gtl::ArraySlice<T> vals, const TensorShape& shape) {
+ Tensor ret;
+ CHECK(ret.CopyFrom(AsTensor(vals), shape));
+ return ret;
+}
+
+// Fills in '*tensor' with 'vals'. E.g.,
+// Tensor x(&alloc, DT_FLOAT, TensorShape({2, 2}));
+// test::FillValues<float>(&x, {11, 21, 21, 22});
+template <typename T>
+void FillValues(Tensor* tensor, gtl::ArraySlice<T> vals) {
+ auto flat = tensor->flat<T>();
+ CHECK_EQ(flat.size(), vals.size());
+ if (flat.size() > 0) {
+ std::copy_n(vals.data(), vals.size(), flat.data());
+ }
+}
+
+// Fills in '*tensor' with a sequence of value of val, val+1, val+2, ...
+// Tensor x(&alloc, DT_FLOAT, TensorShape({2, 2}));
+// test::FillIota<float>(&x, 1.0);
+template <typename T>
+void FillIota(Tensor* tensor, const T& val) {
+ auto flat = tensor->flat<T>();
+ std::iota(flat.data(), flat.data() + flat.size(), val);
+}
+
+// Fills in '*tensor' with a sequence of value of fn(0), fn(1), ...
+// Tensor x(&alloc, DT_FLOAT, TensorShape({2, 2}));
+// test::FillFn<float>(&x, [](int i)->float { return i*i; });
+template <typename T>
+void FillFn(Tensor* tensor, std::function<T(int)> fn) {
+ auto flat = tensor->flat<T>();
+ for (int i = 0; i < flat.size(); ++i) flat(i) = fn(i);
+}
+
+// Expects "x" and "y" are tensors of the same type, same shape, and
+// identical values.
+template <typename T>
+void ExpectTensorEqual(const Tensor& x, const Tensor& y);
+
+// Expects "x" and "y" are tensors of the same type, same shape, and
+// approxmiate equal values, each within "abs_err".
+template <typename T>
+void ExpectTensorNear(const Tensor& x, const Tensor& y, const T& abs_err);
+
+// Expects "x" and "y" are tensors of the same type (float or double),
+// same shape and element-wise difference between x and y is no more
+// than atol + rtol * abs(x).
+void ExpectClose(const Tensor& x, const Tensor& y, double atol = 1e-6,
+ double rtol = 1e-6);
+
+// Implementation details.
+
+namespace internal {
+
+template <typename T>
+struct is_floating_point_type {
+ static const bool value = std::is_same<T, float>::value ||
+ std::is_same<T, double>::value ||
+ std::is_same<T, std::complex<float> >::value ||
+ std::is_same<T, std::complex<double> >::value;
+};
+
+template <typename T>
+static void ExpectEqual(const T& a, const T& b) {
+ EXPECT_EQ(a, b);
+}
+
+template <>
+void ExpectEqual<float>(const float& a, const float& b) {
+ EXPECT_FLOAT_EQ(a, b);
+}
+
+template <>
+void ExpectEqual<double>(const double& a, const double& b) {
+ EXPECT_DOUBLE_EQ(a, b);
+}
+
+template <>
+void ExpectEqual<complex64>(const complex64& a, const complex64& b) {
+ EXPECT_FLOAT_EQ(a.real(), b.real()) << a << " vs. " << b;
+ EXPECT_FLOAT_EQ(a.imag(), b.imag()) << a << " vs. " << b;
+}
+
+inline void AssertSameTypeDims(const Tensor& x, const Tensor& y) {
+ ASSERT_EQ(x.dtype(), y.dtype());
+ ASSERT_TRUE(x.IsSameSize(y))
+ << "x.shape [" << x.shape().DebugString() << "] vs "
+ << "y.shape [ " << y.shape().DebugString() << "]";
+}
+
+template <typename T, bool is_fp = is_floating_point_type<T>::value>
+struct Expector;
+
+template <typename T>
+struct Expector<T, false> {
+ static void Equal(const T& a, const T& b) { ExpectEqual(a, b); }
+
+ static void Equal(const Tensor& x, const Tensor& y) {
+ ASSERT_EQ(x.dtype(), DataTypeToEnum<T>::v());
+ AssertSameTypeDims(x, y);
+ auto a = x.flat<T>();
+ auto b = y.flat<T>();
+ for (int i = 0; i < a.size(); ++i) {
+ ExpectEqual(a(i), b(i));
+ }
+ }
+};
+
+// Partial specialization for float and double.
+template <typename T>
+struct Expector<T, true> {
+ static void Equal(const T& a, const T& b) { ExpectEqual(a, b); }
+
+ static void Equal(const Tensor& x, const Tensor& y) {
+ ASSERT_EQ(x.dtype(), DataTypeToEnum<T>::v());
+ AssertSameTypeDims(x, y);
+ auto a = x.flat<T>();
+ auto b = y.flat<T>();
+ for (int i = 0; i < a.size(); ++i) {
+ ExpectEqual(a(i), b(i));
+ }
+ }
+
+ static void Near(const T& a, const T& b, const double abs_err) {
+ if (a != b) { // Takes care of inf.
+ EXPECT_LE(std::abs(a - b), abs_err) << "a = " << a << " b = " << b;
+ }
+ }
+
+ static void Near(const Tensor& x, const Tensor& y, const double abs_err) {
+ ASSERT_EQ(x.dtype(), DataTypeToEnum<T>::v());
+ AssertSameTypeDims(x, y);
+ auto a = x.flat<T>();
+ auto b = y.flat<T>();
+ for (int i = 0; i < a.size(); ++i) {
+ Near(a(i), b(i), abs_err);
+ }
+ }
+};
+
+} // namespace internal
+
+template <typename T>
+void ExpectTensorEqual(const Tensor& x, const Tensor& y) {
+ internal::Expector<T>::Equal(x, y);
+}
+
+template <typename T>
+void ExpectTensorNear(const Tensor& x, const Tensor& y, const double abs_err) {
+ static_assert(internal::is_floating_point_type<T>::value,
+ "T is not a floating point types.");
+ internal::Expector<T>::Near(x, y, abs_err);
+}
+
+} // namespace test
+} // namespace tensorflow
+
+#endif // TENSORFLOW_FRAMEWORK_TENSOR_TESTUTIL_H_