diff options
author | 2016-07-13 17:54:00 -0800 | |
---|---|---|
committer | 2016-07-13 19:03:24 -0700 | |
commit | 6c7681fbbcc3c244f3e406abc4ea1287fd717752 (patch) | |
tree | 9eef3d30467089a80b43abe6a3d59800705ed7f2 /tensorflow/core/kernels/matmul_op_test.cc | |
parent | 529e29712e681aefbf08539b6fae50fafdae8cc3 (diff) |
Enable complex GPU kernels for tf.matmul and tf.batch_matmul.
Change: 127386123
Diffstat (limited to 'tensorflow/core/kernels/matmul_op_test.cc')
-rw-r--r-- | tensorflow/core/kernels/matmul_op_test.cc | 39 |
1 file changed, 24 insertions, 15 deletions
diff --git a/tensorflow/core/kernels/matmul_op_test.cc b/tensorflow/core/kernels/matmul_op_test.cc index 0db6c2a6f1..1ba774c559 100644 --- a/tensorflow/core/kernels/matmul_op_test.cc +++ b/tensorflow/core/kernels/matmul_op_test.cc @@ -20,28 +20,37 @@ limitations under the License. namespace tensorflow { -static Graph* Matmul(int m, int k, int n, bool transpose_a, bool transpose_b) { +template <typename T> +static Graph* Matmul(int m, int k, int n, bool transpose_a, bool transpose_b, + DataType type) { Graph* g = new Graph(OpRegistry::Global()); - Tensor in0(DT_FLOAT, transpose_a ? TensorShape({k, m}) : TensorShape({m, k})); - in0.flat<float>().setRandom(); - Tensor in1(DT_FLOAT, transpose_b ? TensorShape({n, k}) : TensorShape({k, n})); - in1.flat<float>().setRandom(); + Tensor in0(type, transpose_a ? TensorShape({k, m}) : TensorShape({m, k})); + in0.flat<T>().setRandom(); + Tensor in1(type, transpose_b ? TensorShape({n, k}) : TensorShape({k, n})); + in1.flat<T>().setRandom(); test::graph::Matmul(g, test::graph::Constant(g, in0), test::graph::Constant(g, in1), transpose_a, transpose_b); return g; } -#define BM_MatmulDev(M, K, N, TA, TB, DEVICE) \ - static void BM_Matmul##_##M##_##K##_##N##_##TA##_##TB##_##DEVICE( \ - int iters) { \ - testing::ItemsProcessed(static_cast<int64>(iters) * M * K * N * 2); \ - test::Benchmark(#DEVICE, Matmul(M, K, N, TA, TB)).Run(iters); \ - } \ - BENCHMARK(BM_Matmul##_##M##_##K##_##N##_##TA##_##TB##_##DEVICE); +#define BM_MatmulDev(M, K, N, TA, TB, T, TFTYPE, DEVICE) \ + static void BM_Matmul##_##M##_##K##_##N##_##TA##_##TB##_##TFTYPE##_##DEVICE( \ + int iters) { \ + testing::ItemsProcessed(static_cast<int64>(iters) * M * K * N * 2); \ + test::Benchmark(#DEVICE, Matmul<T>(M, K, N, TA, TB, TFTYPE)).Run(iters); \ + } \ + BENCHMARK(BM_Matmul##_##M##_##K##_##N##_##TA##_##TB##_##TFTYPE##_##DEVICE); -#define BM_Matmul(M, K, N, TA, TB) \ - BM_MatmulDev(M, K, N, TA, TB, cpu); \ - BM_MatmulDev(M, K, N, TA, TB, gpu); +#define BM_Matmul(M, K, N, TA, TB) \ + BM_MatmulDev(M, K, N, TA, TB, float, DT_FLOAT, cpu); \ + BM_MatmulDev(M, K, N, TA, TB, std::complex<float>, DT_COMPLEX64, cpu); \ + BM_MatmulDev(M, K, N, TA, TB, float, DT_FLOAT, gpu); \ + BM_MatmulDev(M, K, N, TA, TB, std::complex<float>, DT_COMPLEX64, gpu); \ +/* Uncomment to enable benchmarks for double/complex128: */ \ +// BM_MatmulDev(M, K, N, TA, TB, double, DT_DOUBLE, cpu); \ +// BM_MatmulDev(M, K, N, TA, TB, std::complex<double>, DT_COMPLEX128, cpu); \ +// BM_MatmulDev(M, K, N, TA, TB, double, DT_DOUBLE, gpu); \ +// BM_MatmulDev(M, K, N, TA, TB, std::complex<double>, DT_COMPLEX128, gpu); // Typical fully connected layers BM_Matmul(8, 512, 512, false, false); |