aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/tensor_testutil.h
blob: 73afca40ac243c0305a138ad98f2485522d85c76 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_FRAMEWORK_TENSOR_TESTUTIL_H_
#define TENSORFLOW_FRAMEWORK_TENSOR_TESTUTIL_H_

#include <algorithm>
#include <complex>
#include <functional>
#include <initializer_list>
#include <numeric>
#include <type_traits>

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
namespace test {

// Returns a scalar tensor of dtype DataTypeToEnum<T>::value whose single
// element is a copy of 'val'.
template <typename T>
Tensor AsScalar(const T& val) {
  Tensor result(DataTypeToEnum<T>::value, {});
  auto element = result.scalar<T>();
  element() = val;
  return result;
}

// Returns a one-dimensional tensor with vals.size() elements holding a copy
// of 'vals', in order.
template <typename T>
Tensor AsTensor(gtl::ArraySlice<T> vals) {
  const int64 length = static_cast<int64>(vals.size());
  Tensor result(DataTypeToEnum<T>::value, {length});
  const T* const src = vals.data();
  std::copy(src, src + vals.size(), result.flat<T>().data());
  return result;
}

// Constructs a tensor of "shape" with values "vals".
//
// Builds a flat tensor from "vals" and gives it "shape" via Tensor::CopyFrom.
// CHECK-fails (aborting the process rather than failing the test) whenever
// CopyFrom returns false -- presumably when "shape" does not hold exactly
// vals.size() elements; confirm against Tensor::CopyFrom in tensor.h.
template <typename T>
Tensor AsTensor(gtl::ArraySlice<T> vals, const TensorShape& shape) {
  Tensor ret;
  CHECK(ret.CopyFrom(AsTensor(vals), shape));
  return ret;
}

// Overwrites the elements of '*tensor', in flat order, with 'vals'.
// CHECK-fails unless 'vals' has exactly as many entries as the tensor has
// elements. E.g.,
//   Tensor x(&alloc, DT_FLOAT, TensorShape({2, 2}));
//   test::FillValues<float>(&x, {11, 12, 21, 22});
template <typename T>
void FillValues(Tensor* tensor, gtl::ArraySlice<T> vals) {
  auto flat = tensor->flat<T>();
  CHECK_EQ(flat.size(), vals.size());
  // Nothing to copy for an empty tensor; leave its buffer untouched.
  if (flat.size() == 0) return;
  const T* const src = vals.data();
  std::copy(src, src + vals.size(), flat.data());
}

// Overwrites the elements of '*tensor', in flat order, with 'vals', where
// each value is converted to T with a functional cast T(value). This lets a
// braced list of a different element type be used. CHECK-fails unless 'vals'
// has exactly as many entries as the tensor has elements.
template <typename T, typename SrcType>
void FillValues(Tensor* tensor, std::initializer_list<SrcType> vals) {
  auto flat = tensor->flat<T>();
  CHECK_EQ(flat.size(), vals.size());
  // After the size check, an empty tensor implies an empty list, so the loop
  // below is a no-op in that case.
  size_t pos = 0;
  for (const SrcType& value : vals) {
    flat(pos) = T(value);
    ++pos;
  }
}

// Fills '*tensor', in flat order, with 'val' followed by successively
// incremented copies of it (val, val+1, val+2, ...), as std::iota does. E.g.,
//   Tensor x(&alloc, DT_FLOAT, TensorShape({2, 2}));
//   test::FillIota<float>(&x, 1.0);
template <typename T>
void FillIota(Tensor* tensor, const T& val) {
  auto flat = tensor->flat<T>();
  T* const first = flat.data();
  T* const last = first + flat.size();
  std::iota(first, last, val);
}

// Sets the i-th element of '*tensor' (in flat order) to fn(i) for every i.
// E.g.,
//   Tensor x(&alloc, DT_FLOAT, TensorShape({2, 2}));
//   test::FillFn<float>(&x, [](int i)->float { return i*i; });
template <typename T>
void FillFn(Tensor* tensor, std::function<T(int)> fn) {
  auto flat = tensor->flat<T>();
  const auto num_elements = flat.size();
  for (int idx = 0; idx < num_elements; ++idx) {
    flat(idx) = fn(idx);
  }
}

// Expects "x" and "y" to be tensors of the same type and shape with equal
// values. Floating-point and complex elements are compared with
// EXPECT_FLOAT_EQ / EXPECT_DOUBLE_EQ, i.e. up to a few ULPs rather than
// bit-exactly.
template <typename T>
void ExpectTensorEqual(const Tensor& x, const Tensor& y);

// Expects "x" and "y" to be tensors of the same type and shape whose
// corresponding elements differ by at most "abs_err". T must be one of the
// floating-point-like types accepted by internal::is_floating_point_type.
//
// "abs_err" is a double (not a T) so that this declaration matches the
// definition of ExpectTensorNear later in this header. Declaring it as
// "const T&" creates a separate overload that is never defined, which calls
// such as ExpectTensorNear<float>(x, y, 1e-5f) would resolve to and then fail
// to link.
template <typename T>
void ExpectTensorNear(const Tensor& x, const Tensor& y, double abs_err);

// Expects "x" and "y" are tensors of the same type (float or double),
// same shape and element-wise difference between x and y is no more
// than atol + rtol * abs(x).
void ExpectClose(const Tensor& x, const Tensor& y, double atol = 1e-6,
                 double rtol = 1e-6);

// Implementation details.

namespace internal {

// Compile-time predicate whose ::value is true exactly when T is one of the
// element types ExpectTensorNear accepts: Eigen::half, float, double,
// std::complex<float> or std::complex<double>.
template <typename T>
struct is_floating_point_type
    : std::integral_constant<
          bool, std::is_same<T, double>::value ||
                    std::is_same<T, float>::value ||
                    std::is_same<T, Eigen::half>::value ||
                    std::is_same<T, std::complex<double> >::value ||
                    std::is_same<T, std::complex<float> >::value> {};

// Expects scalar "a" to equal "b". The generic version uses EXPECT_EQ; the
// explicit specializations below switch floating-point and complex types to
// gtest's ULP-based comparisons. Explicit specializations of a function
// template are not implicitly inline, so each one is marked inline to allow
// this header to be included from several translation units.
template <typename T>
inline void ExpectEqual(const T& a, const T& b) {
  EXPECT_EQ(a, b);
}

// float: EXPECT_FLOAT_EQ accepts values within 4 ULPs of each other.
template <>
inline void ExpectEqual<float>(const float& a, const float& b) {
  EXPECT_FLOAT_EQ(a, b);
}

// double: EXPECT_DOUBLE_EQ accepts values within 4 ULPs of each other.
template <>
inline void ExpectEqual<double>(const double& a, const double& b) {
  EXPECT_DOUBLE_EQ(a, b);
}

// complex64: real and imaginary parts are compared separately with
// EXPECT_FLOAT_EQ; on failure both full complex values are printed.
template <>
inline void ExpectEqual<complex64>(const complex64& a, const complex64& b) {
  EXPECT_FLOAT_EQ(a.real(), b.real()) << a << " vs. " << b;
  EXPECT_FLOAT_EQ(a.imag(), b.imag()) << a << " vs. " << b;
}

// complex128: as for complex64, but using EXPECT_DOUBLE_EQ.
template <>
inline void ExpectEqual<complex128>(const complex128& a, const complex128& b) {
  EXPECT_DOUBLE_EQ(a.real(), b.real()) << a << " vs. " << b;
  EXPECT_DOUBLE_EQ(a.imag(), b.imag()) << a << " vs. " << b;
}

// Asserts that "x" and "y" have the same dtype and the same shape.
//
// Because these are ASSERT_* macros inside a helper, a failure only returns
// from this function, not from its caller. Callers that must stop on a
// mismatch should wrap the call in ASSERT_NO_FATAL_FAILURE (or check
// ::testing::Test::HasFatalFailure()).
inline void AssertSameTypeDims(const Tensor& x, const Tensor& y) {
  ASSERT_EQ(x.dtype(), y.dtype());
  // Both shapes are printed with the same bracket formatting so the two
  // halves of the message line up.
  ASSERT_TRUE(x.IsSameSize(y))
      << "x.shape [" << x.shape().DebugString() << "] vs "
      << "y.shape [" << y.shape().DebugString() << "]";
}

template <typename T, bool is_fp = is_floating_point_type<T>::value>
struct Expector;

template <typename T>
struct Expector<T, false> {
  static void Equal(const T& a, const T& b) { ExpectEqual(a, b); }

  static void Equal(const Tensor& x, const Tensor& y) {
    ASSERT_EQ(x.dtype(), DataTypeToEnum<T>::v());
    AssertSameTypeDims(x, y);
    auto a = x.flat<T>();
    auto b = y.flat<T>();
    for (int i = 0; i < a.size(); ++i) {
      ExpectEqual(a(i), b(i));
    }
  }
};

// Partial specialization for float and double.
template <typename T>
struct Expector<T, true> {
  static void Equal(const T& a, const T& b) { ExpectEqual(a, b); }

  static void Equal(const Tensor& x, const Tensor& y) {
    ASSERT_EQ(x.dtype(), DataTypeToEnum<T>::v());
    AssertSameTypeDims(x, y);
    auto a = x.flat<T>();
    auto b = y.flat<T>();
    for (int i = 0; i < a.size(); ++i) {
      ExpectEqual(a(i), b(i));
    }
  }

  static void Near(const T& a, const T& b, const double abs_err) {
    if (a != b) {  // Takes care of inf.
      EXPECT_LE(double(Eigen::numext::abs(a - b)), abs_err) << "a = " << a
                                                            << " b = " << b;
    }
  }

  static void Near(const Tensor& x, const Tensor& y, const double abs_err) {
    ASSERT_EQ(x.dtype(), DataTypeToEnum<T>::v());
    AssertSameTypeDims(x, y);
    auto a = x.flat<T>();
    auto b = y.flat<T>();
    for (int i = 0; i < a.size(); ++i) {
      Near(a(i), b(i), abs_err);
    }
  }
};

}  // namespace internal

template <typename T>
void ExpectTensorEqual(const Tensor& x, const Tensor& y) {
  internal::Expector<T>::Equal(x, y);
}

// Expects "x" and "y" to be tensors of dtype T with the same shape whose
// corresponding elements differ by at most "abs_err". Element types other
// than Eigen::half, float, double, complex64 and complex128 are rejected at
// compile time.
template <typename T>
void ExpectTensorNear(const Tensor& x, const Tensor& y, const double abs_err) {
  static_assert(internal::is_floating_point_type<T>::value,
                "T must be a floating point or complex type.");
  internal::Expector<T>::Near(x, y, abs_err);
}

}  // namespace test
}  // namespace tensorflow

#endif  // TENSORFLOW_FRAMEWORK_TENSOR_TESTUTIL_H_