aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/core/kernels/batch_matmul_op_complex.cc28
-rw-r--r--tensorflow/core/kernels/batch_matmul_op_impl.h (renamed from tensorflow/core/kernels/batch_matmul_op.cc)23
-rw-r--r--tensorflow/core/kernels/batch_matmul_op_real.cc33
3 files changed, 63 insertions, 21 deletions
diff --git a/tensorflow/core/kernels/batch_matmul_op_complex.cc b/tensorflow/core/kernels/batch_matmul_op_complex.cc
new file mode 100644
index 0000000000..a58ec02726
--- /dev/null
+++ b/tensorflow/core/kernels/batch_matmul_op_complex.cc
@@ -0,0 +1,28 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/batch_matmul_op_impl.h"
+
+namespace tensorflow {
+
+TF_CALL_complex64(REGISTER_BATCH_MATMUL_CPU);
+TF_CALL_complex128(REGISTER_BATCH_MATMUL_CPU);
+
+#if GOOGLE_CUDA
+TF_CALL_complex64(REGISTER_BATCH_MATMUL_GPU);
+TF_CALL_complex128(REGISTER_BATCH_MATMUL_GPU);
+#endif
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/batch_matmul_op.cc b/tensorflow/core/kernels/batch_matmul_op_impl.h
index f8c12927e3..f2b74fc8c2 100644
--- a/tensorflow/core/kernels/batch_matmul_op.cc
+++ b/tensorflow/core/kernels/batch_matmul_op_impl.h
@@ -433,33 +433,14 @@ class BatchMatMul : public OpKernel {
bool adj_y_;
};
-#define REGISTER_CPU(TYPE) \
+#define REGISTER_BATCH_MATMUL_CPU(TYPE) \
REGISTER_KERNEL_BUILDER( \
Name("BatchMatMul").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"), \
BatchMatMul<CPUDevice, TYPE>)
-#define REGISTER_GPU(TYPE) \
+#define REGISTER_BATCH_MATMUL_GPU(TYPE) \
REGISTER_KERNEL_BUILDER( \
Name("BatchMatMul").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"), \
BatchMatMul<GPUDevice, TYPE>)
-TF_CALL_float(REGISTER_CPU);
-TF_CALL_double(REGISTER_CPU);
-TF_CALL_half(REGISTER_CPU);
-TF_CALL_int32(REGISTER_CPU);
-TF_CALL_complex64(REGISTER_CPU);
-TF_CALL_complex128(REGISTER_CPU);
-
-#if GOOGLE_CUDA
-TF_CALL_float(REGISTER_GPU);
-TF_CALL_double(REGISTER_GPU);
-TF_CALL_complex64(REGISTER_GPU);
-TF_CALL_complex128(REGISTER_GPU);
-#if CUDA_VERSION >= 7050
-TF_CALL_half(REGISTER_GPU);
-#endif
-#endif // GOOGLE_CUDA
-
-#undef REGISTER_CPU
-#undef REGISTER_GPU
} // end namespace tensorflow
diff --git a/tensorflow/core/kernels/batch_matmul_op_real.cc b/tensorflow/core/kernels/batch_matmul_op_real.cc
new file mode 100644
index 0000000000..c719e30c4d
--- /dev/null
+++ b/tensorflow/core/kernels/batch_matmul_op_real.cc
@@ -0,0 +1,33 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/batch_matmul_op_impl.h"
+
+namespace tensorflow {
+
+TF_CALL_float(REGISTER_BATCH_MATMUL_CPU);
+TF_CALL_double(REGISTER_BATCH_MATMUL_CPU);
+TF_CALL_half(REGISTER_BATCH_MATMUL_CPU);
+TF_CALL_int32(REGISTER_BATCH_MATMUL_CPU);
+
+#if GOOGLE_CUDA
+TF_CALL_float(REGISTER_BATCH_MATMUL_GPU);
+TF_CALL_double(REGISTER_BATCH_MATMUL_GPU);
+#if CUDA_VERSION >= 7050
+TF_CALL_half(REGISTER_BATCH_MATMUL_GPU);
+#endif
+#endif // GOOGLE_CUDA
+
+} // namespace tensorflow