diff options
-rw-r--r-- | tensorflow/core/kernels/batch_matmul_op_complex.cc | 28 | ||||
-rw-r--r-- | tensorflow/core/kernels/batch_matmul_op_impl.h (renamed from tensorflow/core/kernels/batch_matmul_op.cc) | 23 | ||||
-rw-r--r-- | tensorflow/core/kernels/batch_matmul_op_real.cc | 33 |
3 files changed, 63 insertions, 21 deletions
diff --git a/tensorflow/core/kernels/batch_matmul_op_complex.cc b/tensorflow/core/kernels/batch_matmul_op_complex.cc new file mode 100644 index 0000000000..a58ec02726 --- /dev/null +++ b/tensorflow/core/kernels/batch_matmul_op_complex.cc @@ -0,0 +1,28 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/kernels/batch_matmul_op_impl.h" + +namespace tensorflow { + +TF_CALL_complex64(REGISTER_BATCH_MATMUL_CPU); +TF_CALL_complex128(REGISTER_BATCH_MATMUL_CPU); + +#if GOOGLE_CUDA +TF_CALL_complex64(REGISTER_BATCH_MATMUL_GPU); +TF_CALL_complex128(REGISTER_BATCH_MATMUL_GPU); +#endif + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/batch_matmul_op.cc b/tensorflow/core/kernels/batch_matmul_op_impl.h index f8c12927e3..f2b74fc8c2 100644 --- a/tensorflow/core/kernels/batch_matmul_op.cc +++ b/tensorflow/core/kernels/batch_matmul_op_impl.h @@ -433,33 +433,14 @@ class BatchMatMul : public OpKernel { bool adj_y_; }; -#define REGISTER_CPU(TYPE) \ +#define REGISTER_BATCH_MATMUL_CPU(TYPE) \ REGISTER_KERNEL_BUILDER( \ Name("BatchMatMul").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"), \ BatchMatMul<CPUDevice, TYPE>) -#define REGISTER_GPU(TYPE) \ +#define REGISTER_BATCH_MATMUL_GPU(TYPE) \ REGISTER_KERNEL_BUILDER( \ Name("BatchMatMul").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"), \ BatchMatMul<GPUDevice, TYPE>) -TF_CALL_float(REGISTER_CPU); -TF_CALL_double(REGISTER_CPU); -TF_CALL_half(REGISTER_CPU); -TF_CALL_int32(REGISTER_CPU); -TF_CALL_complex64(REGISTER_CPU); -TF_CALL_complex128(REGISTER_CPU); - -#if GOOGLE_CUDA -TF_CALL_float(REGISTER_GPU); -TF_CALL_double(REGISTER_GPU); -TF_CALL_complex64(REGISTER_GPU); -TF_CALL_complex128(REGISTER_GPU); -#if CUDA_VERSION >= 7050 -TF_CALL_half(REGISTER_GPU); -#endif -#endif // GOOGLE_CUDA - -#undef REGISTER_CPU -#undef REGISTER_GPU } // end namespace tensorflow diff --git a/tensorflow/core/kernels/batch_matmul_op_real.cc b/tensorflow/core/kernels/batch_matmul_op_real.cc new file mode 100644 index 0000000000..c719e30c4d --- /dev/null +++ b/tensorflow/core/kernels/batch_matmul_op_real.cc @@ -0,0 +1,33 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/kernels/batch_matmul_op_impl.h" + +namespace tensorflow { + +TF_CALL_float(REGISTER_BATCH_MATMUL_CPU); +TF_CALL_double(REGISTER_BATCH_MATMUL_CPU); +TF_CALL_half(REGISTER_BATCH_MATMUL_CPU); +TF_CALL_int32(REGISTER_BATCH_MATMUL_CPU); + +#if GOOGLE_CUDA +TF_CALL_float(REGISTER_BATCH_MATMUL_GPU); +TF_CALL_double(REGISTER_BATCH_MATMUL_GPU); +#if CUDA_VERSION >= 7050 +TF_CALL_half(REGISTER_BATCH_MATMUL_GPU); +#endif +#endif // GOOGLE_CUDA + +} // namespace tensorflow |