author     RJ Ryan <rjryan@google.com>  2016-07-13 17:54:00 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>  2016-07-13 19:03:24 -0700
commit     6c7681fbbcc3c244f3e406abc4ea1287fd717752 (patch)
tree       9eef3d30467089a80b43abe6a3d59800705ed7f2 /tensorflow/core/kernels/matmul_op_test.cc
parent     529e29712e681aefbf08539b6fae50fafdae8cc3 (diff)
Enable complex GPU kernels for tf.matmul and tf.batch_matmul.
Change: 127386123
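For illustration (not part of the commit itself): a minimal Python sketch of what this change enables, assuming a GPU-enabled TensorFlow build from around this commit (mid-2016 API); the shapes and values below are arbitrary.

import numpy as np
import tensorflow as tf

# Pin the complex64 matmul to the GPU; before this commit the complex
# kernels for these ops were CPU-only.
with tf.device("/gpu:0"):
    a = tf.constant(np.random.rand(4, 8) + 1j * np.random.rand(4, 8),
                    dtype=tf.complex64)
    b = tf.constant(np.random.rand(8, 2) + 1j * np.random.rand(8, 2),
                    dtype=tf.complex64)
    c = tf.matmul(a, b)  # tf.batch_matmul gains the same complex support

with tf.Session() as sess:
    print(sess.run(c))

On the C++ side, the templated BM_MatmulDev macro in the diff below registers float and complex64 variants of each benchmark shape on both cpu and gpu.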
Diffstat (limited to 'tensorflow/core/kernels/matmul_op_test.cc')
-rw-r--r--  tensorflow/core/kernels/matmul_op_test.cc  |  39 ++++++++++++++++++++++++---------------
1 file changed, 24 insertions(+), 15 deletions(-)
diff --git a/tensorflow/core/kernels/matmul_op_test.cc b/tensorflow/core/kernels/matmul_op_test.cc
index 0db6c2a6f1..1ba774c559 100644
--- a/tensorflow/core/kernels/matmul_op_test.cc
+++ b/tensorflow/core/kernels/matmul_op_test.cc
@@ -20,28 +20,37 @@ limitations under the License.
namespace tensorflow {
-static Graph* Matmul(int m, int k, int n, bool transpose_a, bool transpose_b) {
+template <typename T>
+static Graph* Matmul(int m, int k, int n, bool transpose_a, bool transpose_b,
+ DataType type) {
Graph* g = new Graph(OpRegistry::Global());
- Tensor in0(DT_FLOAT, transpose_a ? TensorShape({k, m}) : TensorShape({m, k}));
- in0.flat<float>().setRandom();
- Tensor in1(DT_FLOAT, transpose_b ? TensorShape({n, k}) : TensorShape({k, n}));
- in1.flat<float>().setRandom();
+ Tensor in0(type, transpose_a ? TensorShape({k, m}) : TensorShape({m, k}));
+ in0.flat<T>().setRandom();
+ Tensor in1(type, transpose_b ? TensorShape({n, k}) : TensorShape({k, n}));
+ in1.flat<T>().setRandom();
test::graph::Matmul(g, test::graph::Constant(g, in0),
test::graph::Constant(g, in1), transpose_a, transpose_b);
return g;
}
-#define BM_MatmulDev(M, K, N, TA, TB, DEVICE) \
- static void BM_Matmul##_##M##_##K##_##N##_##TA##_##TB##_##DEVICE( \
- int iters) { \
- testing::ItemsProcessed(static_cast<int64>(iters) * M * K * N * 2); \
- test::Benchmark(#DEVICE, Matmul(M, K, N, TA, TB)).Run(iters); \
- } \
- BENCHMARK(BM_Matmul##_##M##_##K##_##N##_##TA##_##TB##_##DEVICE);
+#define BM_MatmulDev(M, K, N, TA, TB, T, TFTYPE, DEVICE) \
+ static void BM_Matmul##_##M##_##K##_##N##_##TA##_##TB##_##TFTYPE##_##DEVICE( \
+ int iters) { \
+ testing::ItemsProcessed(static_cast<int64>(iters) * M * K * N * 2); \
+ test::Benchmark(#DEVICE, Matmul<T>(M, K, N, TA, TB, TFTYPE)).Run(iters); \
+ } \
+ BENCHMARK(BM_Matmul##_##M##_##K##_##N##_##TA##_##TB##_##TFTYPE##_##DEVICE);
-#define BM_Matmul(M, K, N, TA, TB) \
- BM_MatmulDev(M, K, N, TA, TB, cpu); \
- BM_MatmulDev(M, K, N, TA, TB, gpu);
+#define BM_Matmul(M, K, N, TA, TB) \
+ BM_MatmulDev(M, K, N, TA, TB, float, DT_FLOAT, cpu); \
+ BM_MatmulDev(M, K, N, TA, TB, std::complex<float>, DT_COMPLEX64, cpu); \
+ BM_MatmulDev(M, K, N, TA, TB, float, DT_FLOAT, gpu); \
+ BM_MatmulDev(M, K, N, TA, TB, std::complex<float>, DT_COMPLEX64, gpu); \
+/* Uncomment to enable benchmarks for double/complex128: */ \
+// BM_MatmulDev(M, K, N, TA, TB, double, DT_DOUBLE, cpu); \
+// BM_MatmulDev(M, K, N, TA, TB, std::complex<double>, DT_COMPLEX128, cpu); \
+// BM_MatmulDev(M, K, N, TA, TB, double, DT_DOUBLE, gpu); \
+// BM_MatmulDev(M, K, N, TA, TB, std::complex<double>, DT_COMPLEX128, gpu);
// Typical fully connected layers
BM_Matmul(8, 512, 512, false, false);