diff options
author | 2016-07-31 22:07:30 -0800 | |
---|---|---|
committer | 2016-07-31 23:17:46 -0700 | |
commit | abe9ab326625105adb3c9d46c027931aec947d1f (patch) | |
tree | d9fa7eb9a2fd9b37bc87f98cf353354391b9eb04 /tensorflow/core/kernels/cwise_ops_test.cc | |
parent | c0637048dbc099eac1f75878b765220cd02ccfc0 (diff) |
Merge changes from github.
Change: 128958134
Diffstat (limited to 'tensorflow/core/kernels/cwise_ops_test.cc')
-rw-r--r-- | tensorflow/core/kernels/cwise_ops_test.cc | 39 |
1 files changed, 23 insertions, 16 deletions
diff --git a/tensorflow/core/kernels/cwise_ops_test.cc b/tensorflow/core/kernels/cwise_ops_test.cc index 2cf51878ba..823e7e14ed 100644 --- a/tensorflow/core/kernels/cwise_ops_test.cc +++ b/tensorflow/core/kernels/cwise_ops_test.cc @@ -23,13 +23,14 @@ limitations under the License. namespace tensorflow { -// Creates a Graph which applies a unary "func" on a 3D float tensor -// of "num" elements. -static Graph* Unary(const string& func, int num) { +// Creates a Graph which applies a unary "func" on a 3D tensor of +// type T with "num" elements. +template <typename T> +static Graph* Unary(const string& func, int num, DataType dtype) { Graph* g = new Graph(OpRegistry::Global()); - Tensor data(DT_FLOAT, TensorShape({64, 64, num / (64 * 64)})); + Tensor data(dtype, TensorShape({64, 64, num / (64 * 64)})); CHECK_GT(data.NumElements(), 0); - data.flat<float>().setRandom(); + data.flat<T>().setRandom(); test::graph::Unary(g, func, test::graph::Constant(g, data), 0); return g; } @@ -40,17 +41,23 @@ static int RowsAndColsArg(int r, int c) { return r * kRows + c; } static int RowsFromArg(int arg) { return (arg / kRows); } static int ColsFromArg(int arg) { return (arg % kRows); } -#define BM_UNARY(DEVICE, FUNC) \ - static void BM_##DEVICE##_##FUNC(int iters, int num) { \ - const int64 tot = static_cast<int64>(iters) * num; \ - testing::ItemsProcessed(tot); \ - testing::BytesProcessed(tot * sizeof(float)); \ - test::Benchmark(#DEVICE, Unary(#FUNC, num)).Run(iters); \ - } \ - BENCHMARK(BM_##DEVICE##_##FUNC)->Range(4 << 10, 1 << 20); - -BM_UNARY(cpu, Floor); -BM_UNARY(gpu, Floor); +#define BM_UNARY(DEVICE, FUNC, T, TYPE) \ + static void BM_##DEVICE##_##FUNC##_##TYPE(int iters, int num) { \ + const int64 tot = static_cast<int64>(iters) * num; \ + testing::ItemsProcessed(tot); \ + testing::BytesProcessed(tot * sizeof(T)); \ + test::Benchmark(#DEVICE, Unary<T>(#FUNC, num, TYPE)).Run(iters); \ + } \ + BENCHMARK(BM_##DEVICE##_##FUNC##_##TYPE)->Range(4 << 10, 1 << 20); + +BM_UNARY(cpu, Floor, float, DT_FLOAT); +BM_UNARY(gpu, Floor, float, DT_FLOAT); +BM_UNARY(cpu, Floor, double, DT_DOUBLE); +BM_UNARY(gpu, Floor, double, DT_DOUBLE); +BM_UNARY(cpu, Conj, std::complex<float>, DT_COMPLEX64); +BM_UNARY(gpu, Conj, std::complex<float>, DT_COMPLEX64); +BM_UNARY(cpu, Conj, std::complex<double>, DT_COMPLEX128); +BM_UNARY(gpu, Conj, std::complex<double>, DT_COMPLEX128); // data func scalar. static Graph* BinaryScalar(int num, const string& func) { |