path: root/tensorflow/contrib/rnn/kernels
author    A. Unique TensorFlower <gardener@tensorflow.org>    2016-09-29 11:59:42 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>     2016-09-29 13:08:49 -0700
commit    f6157b1d151ffcf758a9e0299cb6b2225e244229 (patch)
tree      65f6b7484f5b6a99b0764a8000a86f190b763970 /tensorflow/contrib/rnn/kernels
parent    c6f15e7a5469895e49e1da675eaec714b4dc0cce (diff)
Trivial change in tf.contrib.rnn: Remove code duplication
Change: 134697206
Diffstat (limited to 'tensorflow/contrib/rnn/kernels')
-rw-r--r--    tensorflow/contrib/rnn/kernels/blas_gemm.cc    69
-rw-r--r--    tensorflow/contrib/rnn/kernels/blas_gemm.h     86
-rw-r--r--    tensorflow/contrib/rnn/kernels/gru_ops.cc      46
-rw-r--r--    tensorflow/contrib/rnn/kernels/gru_ops.h       50
-rw-r--r--    tensorflow/contrib/rnn/kernels/lstm_ops.cc     45
-rw-r--r--    tensorflow/contrib/rnn/kernels/lstm_ops.h      49
6 files changed, 157 insertions(+), 188 deletions(-)
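
Note (not part of the commit): the change moves the cuBLAS/Eigen GEMM helpers that were duplicated in gru_ops and lstm_ops into the new blas_gemm.{h,cc}. A minimal call-site sketch follows; the wrapper function and its argument names are illustrative assumptions, and only TensorBlasGemm<Device, T, USE_CUBLAS>::compute and its signature come from the diff below.

#include "tensorflow/contrib/rnn/kernels/blas_gemm.h"

// Hypothetical shared call site: after this change both the GRU and LSTM
// block kernels can dispatch through the single functor in blas_gemm.h
// instead of carrying local copies of the GEMM code.
template <typename Device, typename T, bool USE_CUBLAS>
void RunCellMatmul(tensorflow::OpKernelContext* ctx,
                   perftools::gputools::Stream* stream, const Device& d,
                   typename tensorflow::TTypes<T>::ConstMatrix x,
                   typename tensorflow::TTypes<T>::ConstMatrix w,
                   typename tensorflow::TTypes<T>::Matrix out) {
  // Computes out = 1.0 * x * w + 0.0 * out; TensorBlasGemm selects the
  // cuBLAS path or the Eigen contraction path depending on USE_CUBLAS.
  tensorflow::functor::TensorBlasGemm<Device, T, USE_CUBLAS>::compute(
      ctx, stream, d, /*transa=*/false, /*transb=*/false, T(1), x, w, T(0),
      out);
}
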
diff --git a/tensorflow/contrib/rnn/kernels/blas_gemm.cc b/tensorflow/contrib/rnn/kernels/blas_gemm.cc
new file mode 100644
index 0000000000..637b872dad
--- /dev/null
+++ b/tensorflow/contrib/rnn/kernels/blas_gemm.cc
@@ -0,0 +1,69 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#define EIGEN_USE_THREADS
+
+#if GOOGLE_CUDA
+#include "tensorflow/core/platform/stream_executor.h"
+#endif // GOOGLE_CUDA
+
+#include "tensorflow/contrib/rnn/kernels/blas_gemm.h"
+#include "tensorflow/core/framework/op_kernel.h"
+namespace tensorflow {
+
+#if GOOGLE_CUDA
+namespace {
+template <typename T>
+perftools::gputools::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory) {
+ perftools::gputools::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory));
+ perftools::gputools::DeviceMemory<T> typed(wrapped);
+ return typed;
+}
+} // namespace
+#endif // GOOGLE_CUDA
+
+namespace functor {
+template <typename T>
+void TensorCuBlasGemm<T>::operator()(OpKernelContext* ctx,
+ perftools::gputools::Stream* stream,
+ bool transa, bool transb, uint64 m,
+ uint64 n, uint64 k, T alpha, const T* a,
+ int lda, const T* b, int ldb, T beta, T* c,
+ int ldc) {
+#if GOOGLE_CUDA
+ perftools::gputools::blas::Transpose trans[] = {
+ perftools::gputools::blas::Transpose::kNoTranspose,
+ perftools::gputools::blas::Transpose::kTranspose};
+
+ auto a_ptr = AsDeviceMemory(a);
+ auto b_ptr = AsDeviceMemory(b);
+ auto c_ptr = AsDeviceMemory(c);
+
+ bool blas_launch_status =
+ stream
+ ->ThenBlasGemm(trans[transa], trans[transb], m, n, k, alpha, a_ptr,
+ lda, b_ptr, ldb, beta, &c_ptr, ldc)
+ .ok();
+ OP_REQUIRES(ctx, blas_launch_status, errors::Aborted("CuBlasGemm failed!"));
+#else
+ ctx->SetStatus(errors::InvalidArgument("CuBlasGemm needs CUDA."));
+#endif
+}
+
+template struct TensorCuBlasGemm<float>;
+template struct TensorCuBlasGemm<double>;
+
+} // end namespace functor
+} // end namespace tensorflow
diff --git a/tensorflow/contrib/rnn/kernels/blas_gemm.h b/tensorflow/contrib/rnn/kernels/blas_gemm.h
new file mode 100644
index 0000000000..9c34b8ae71
--- /dev/null
+++ b/tensorflow/contrib/rnn/kernels/blas_gemm.h
@@ -0,0 +1,86 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_RNN_KERNELS_BLAS_GEMM_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_RNN_KERNELS_BLAS_GEMM_H_
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/kernels/eigen_activations.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace perftools {
+namespace gputools {
+class Stream;
+} // end namespace gputools
+} // end namespace perftools
+
+namespace tensorflow {
+class OpKernelContext;
+namespace functor {
+
+template <typename T>
+struct TensorCuBlasGemm {
+ void operator()(OpKernelContext* ctx, perftools::gputools::Stream* stream,
+ bool transa, bool transb, uint64 m, uint64 n, uint64 k,
+ T alpha, const T* a, int lda, const T* b, int ldb, T beta,
+ T* c, int ldc);
+};
+
+template <typename Device, typename T, bool USE_CUBLAS>
+struct TensorBlasGemm;
+
+template <typename Device, typename T>
+struct TensorBlasGemm<Device, T, true /* USE_CUBLAS */> {
+ static void compute(OpKernelContext* ctx, perftools::gputools::Stream* stream,
+ const Device& d, bool transa, bool transb, T alpha,
+ typename TTypes<T>::ConstMatrix a,
+ typename TTypes<T>::ConstMatrix b, T beta,
+ typename TTypes<T>::Matrix c) {
+ int64 m = c.dimensions()[0];
+ int64 n = c.dimensions()[1];
+ int64 k = transa ? a.dimensions()[0] : a.dimensions()[1];
+
+ TensorCuBlasGemm<T>()(ctx, stream, transb, transa, n, m, k, alpha, b.data(),
+ transb ? k : n, a.data(), transa ? m : k, beta,
+ c.data(), n);
+ }
+};
+
+template <typename Device, typename T>
+struct TensorBlasGemm<Device, T, false /* USE_CUBLAS */> {
+ static void compute(OpKernelContext* ctx, perftools::gputools::Stream* stream,
+ const Device& d, bool transa, bool transb, T alpha,
+ typename TTypes<T>::ConstMatrix a,
+ typename TTypes<T>::ConstMatrix b, T beta,
+ typename TTypes<T>::Matrix c) {
+ Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_pairs;
+ contract_pairs[0] =
+ Eigen::IndexPair<Eigen::DenseIndex>(transa == false, transb == true);
+ if (alpha == T(1) && beta == T(0)) {
+ c.device(d) = a.contract(b, contract_pairs);
+ } else if (alpha == T(1) && beta == T(1)) {
+ c.device(d) += a.contract(b, contract_pairs);
+ } else {
+ c.device(d) = c.constant(alpha) * a.contract(b, contract_pairs) +
+ c.constant(beta) * c;
+ }
+ }
+};
+
+} // namespace functor
+} // namespace tensorflow
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_RNN_KERNELS_BLAS_GEMM_H_
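
Note (not part of the commit): two details of the header above are worth calling out. The cuBLAS specialization swaps the operands (b before a, transb before transa, leading dimensions adjusted) so that a column-major BLAS GEMM produces the row-major result by computing C^T = B^T * A^T. The non-CUBLAS specialization expresses GEMM as an Eigen tensor contraction; the following standalone sketch of that contraction uses hypothetical 2x3 and 3x4 operands and assumes only Eigen's unsupported Tensor module is available.

// Minimal Eigen contraction example. A single IndexPair names the dimensions
// to sum over: (1, 0) contracts a's columns with b's rows, i.e. c = a * b,
// matching the transa == false / transb == false case in the code above.
#include <iostream>
#include <unsupported/Eigen/CXX11/Tensor>

int main() {
  Eigen::Tensor<float, 2> a(2, 3);
  Eigen::Tensor<float, 2> b(3, 4);
  a.setRandom();
  b.setRandom();

  Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> pairs;
  pairs[0] = Eigen::IndexPair<Eigen::DenseIndex>(1, 0);

  Eigen::Tensor<float, 2> c = a.contract(b, pairs);  // 2x4 result
  std::cout << "c is " << c.dimension(0) << "x" << c.dimension(1) << "\n";
  return 0;
}
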
diff --git a/tensorflow/contrib/rnn/kernels/gru_ops.cc b/tensorflow/contrib/rnn/kernels/gru_ops.cc
index 89a4817111..ae25322a40 100644
--- a/tensorflow/contrib/rnn/kernels/gru_ops.cc
+++ b/tensorflow/contrib/rnn/kernels/gru_ops.cc
@@ -28,52 +28,6 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
-#if GOOGLE_CUDA
-namespace {
-template <typename T>
-perftools::gputools::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory) {
- perftools::gputools::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory));
- perftools::gputools::DeviceMemory<T> typed(wrapped);
- return typed;
-}
-} // namespace
-#endif // GOOGLE_CUDA
-
-namespace functor {
-template <typename T>
-// TODO(gitegaurav) : Refactor the matmul operation inside the kernel. Make
-// similar changes in the LSTMBlockCell. Create a new file which contains matmul
-// functionality. It should perform matmul operation using CuBlas when the Cuda
-// support is present otherwise using eigentensors.
-void TensorCuBlasGemm<T>::operator()(OpKernelContext* ctx,
- perftools::gputools::Stream* stream,
- bool transa, bool transb, uint64 m,
- uint64 n, uint64 k, T alpha, const T* a,
- int lda, const T* b, int ldb, T beta, T* c,
- int ldc) {
-#if GOOGLE_CUDA
- perftools::gputools::blas::Transpose trans[] = {
- perftools::gputools::blas::Transpose::kNoTranspose,
- perftools::gputools::blas::Transpose::kTranspose};
-
- auto a_ptr = AsDeviceMemory(a);
- auto b_ptr = AsDeviceMemory(b);
- auto c_ptr = AsDeviceMemory(c);
-
- bool blas_launch_status =
- stream
- ->ThenBlasGemm(trans[transa], trans[transb], m, n, k, alpha, a_ptr,
- lda, b_ptr, ldb, beta, &c_ptr, ldc)
- .ok();
- OP_REQUIRES(ctx, blas_launch_status, errors::Aborted("CuBlasGemm failed!"));
-#else
- ctx->SetStatus(errors::InvalidArgument("CuBlasGemm needs CUDA."));
-#endif
-}
-
-template struct TensorCuBlasGemm<float>;
-} // end namespace functor
-
template <typename Device, typename T, bool USE_CUBLAS>
class GRUCellBlockOp : public OpKernel {
public:
diff --git a/tensorflow/contrib/rnn/kernels/gru_ops.h b/tensorflow/contrib/rnn/kernels/gru_ops.h
index b2ff859b8e..e6c4ad9a03 100644
--- a/tensorflow/contrib/rnn/kernels/gru_ops.h
+++ b/tensorflow/contrib/rnn/kernels/gru_ops.h
@@ -17,6 +17,7 @@ limitations under the License.
#define THIRD_PARTY_TENSORFLOW_CONTRIB_RNN_KERNELS_GRU_OPS_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/contrib/rnn/kernels/blas_gemm.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/platform/types.h"
@@ -32,55 +33,6 @@ class OpKernelContext;
namespace functor {
-template <typename T>
-struct TensorCuBlasGemm {
- void operator()(OpKernelContext* ctx, perftools::gputools::Stream* stream,
- bool transa, bool transb, uint64 m, uint64 n, uint64 k,
- T alpha, const T* a, int lda, const T* b, int ldb, T beta,
- T* c, int ldc);
-};
-
-template <typename Device, typename T, bool USE_CUBLAS>
-struct TensorBlasGemm;
-
-template <typename Device, typename T>
-struct TensorBlasGemm<Device, T, true /* USE_CUBLAS */> {
- static void compute(OpKernelContext* ctx, perftools::gputools::Stream* stream,
- const Device& d, bool transa, bool transb, T alpha,
- typename TTypes<T>::ConstMatrix a,
- typename TTypes<T>::ConstMatrix b, T beta,
- typename TTypes<T>::Matrix c) {
- int64 m = c.dimensions()[0];
- int64 n = c.dimensions()[1];
- int64 k = transa ? a.dimensions()[0] : a.dimensions()[1];
-
- TensorCuBlasGemm<T>()(ctx, stream, transb, transa, n, m, k, alpha, b.data(),
- transb ? k : n, a.data(), transa ? m : k, beta,
- c.data(), n);
- }
-};
-
-template <typename Device, typename T>
-struct TensorBlasGemm<Device, T, false /* USE_CUBLAS */> {
- static void compute(OpKernelContext* ctx, perftools::gputools::Stream* stream,
- const Device& d, bool transa, bool transb, T alpha,
- typename TTypes<T>::ConstMatrix a,
- typename TTypes<T>::ConstMatrix b, T beta,
- typename TTypes<T>::Matrix c) {
- Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_pairs;
- contract_pairs[0] =
- Eigen::IndexPair<Eigen::DenseIndex>(transa == false, transb == true);
- if (alpha == T(1) && beta == T(0)) {
- c.device(d) = a.contract(b, contract_pairs);
- } else if (alpha == T(1) && beta == T(1)) {
- c.device(d) += a.contract(b, contract_pairs);
- } else {
- c.device(d) = c.constant(alpha) * a.contract(b, contract_pairs) +
- c.constant(beta) * c;
- }
- }
-};
-
struct GRUCell {
GRUCell(const int batch_size, const int input_size, const int cell_size)
: batch_size_(batch_size),
diff --git a/tensorflow/contrib/rnn/kernels/lstm_ops.cc b/tensorflow/contrib/rnn/kernels/lstm_ops.cc
index 7bfc119e2c..2749d7d797 100644
--- a/tensorflow/contrib/rnn/kernels/lstm_ops.cc
+++ b/tensorflow/contrib/rnn/kernels/lstm_ops.cc
@@ -43,51 +43,6 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
-#if GOOGLE_CUDA
-
-namespace {
-template <typename T>
-perftools::gputools::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory) {
- perftools::gputools::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory));
- perftools::gputools::DeviceMemory<T> typed(wrapped);
- return typed;
-}
-} // namespace
-
-#endif // GOOGLE_CUDA
-
-namespace functor {
-template <typename T>
-void TensorCuBlasGemm<T>::operator()(OpKernelContext* ctx,
- perftools::gputools::Stream* stream,
- bool transa, bool transb, uint64 m,
- uint64 n, uint64 k, T alpha, const T* a,
- int lda, const T* b, int ldb, T beta, T* c,
- int ldc) {
-#if GOOGLE_CUDA
- perftools::gputools::blas::Transpose trans[] = {
- perftools::gputools::blas::Transpose::kNoTranspose,
- perftools::gputools::blas::Transpose::kTranspose};
-
- auto a_ptr = AsDeviceMemory(a);
- auto b_ptr = AsDeviceMemory(b);
- auto c_ptr = AsDeviceMemory(c);
-
- bool blas_launch_status =
- stream
- ->ThenBlasGemm(trans[transa], trans[transb], m, n, k, alpha, a_ptr,
- lda, b_ptr, ldb, beta, &c_ptr, ldc)
- .ok();
- OP_REQUIRES(ctx, blas_launch_status, errors::Aborted("CuBlasGemm failed!"));
-#else
- ctx->SetStatus(errors::InvalidArgument("CuBlasGemm needs CUDA."));
-#endif
-}
-
-template struct TensorCuBlasGemm<float>;
-// template struct TensorCuBlasGemm<double>;
-} // end namespace functor
-
template <typename Device, typename T, bool USE_CUBLAS>
class LSTMBlockCellOp : public OpKernel {
public:
diff --git a/tensorflow/contrib/rnn/kernels/lstm_ops.h b/tensorflow/contrib/rnn/kernels/lstm_ops.h
index 15bc5d89d2..5a9dda5755 100644
--- a/tensorflow/contrib/rnn/kernels/lstm_ops.h
+++ b/tensorflow/contrib/rnn/kernels/lstm_ops.h
@@ -17,6 +17,7 @@ limitations under the License.
#define THIRD_PARTY_TENSORFLOW_CONTRIB_RNN_KERNELS_LSTM_OPS_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/contrib/rnn/kernels/blas_gemm.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/eigen_activations.h"
#include "tensorflow/core/platform/types.h"
@@ -74,54 +75,6 @@ struct TensorZeroPadding {
}
};
-template <typename T>
-struct TensorCuBlasGemm {
- void operator()(OpKernelContext* ctx, perftools::gputools::Stream* stream,
- bool transa, bool transb, uint64 m, uint64 n, uint64 k,
- T alpha, const T* a, int lda, const T* b, int ldb, T beta,
- T* c, int ldc);
-};
-
-template <typename Device, typename T, bool USE_CUBLAS>
-struct TensorBlasGemm;
-
-template <typename Device, typename T>
-struct TensorBlasGemm<Device, T, true /* USE_CUBLAS */> {
- static void compute(OpKernelContext* ctx, perftools::gputools::Stream* stream,
- const Device& d, bool transa, bool transb, T alpha,
- typename TTypes<T>::ConstMatrix a,
- typename TTypes<T>::ConstMatrix b, T beta,
- typename TTypes<T>::Matrix c) {
- int64 m = c.dimensions()[0];
- int64 n = c.dimensions()[1];
- int64 k = transa ? a.dimensions()[0] : a.dimensions()[1];
-
- TensorCuBlasGemm<T>()(ctx, stream, transb, transa, n, m, k, alpha, b.data(),
- transb ? k : n, a.data(), transa ? m : k, beta,
- c.data(), n);
- }
-};
-
-template <typename Device, typename T>
-struct TensorBlasGemm<Device, T, false /* USE_CUBLAS */> {
- static void compute(OpKernelContext* ctx, perftools::gputools::Stream* stream,
- const Device& d, bool transa, bool transb, T alpha,
- typename TTypes<T>::ConstMatrix a,
- typename TTypes<T>::ConstMatrix b, T beta,
- typename TTypes<T>::Matrix c) {
- Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_pairs;
- contract_pairs[0] =
- Eigen::IndexPair<Eigen::DenseIndex>(transa == false, transb == true);
- if (alpha == T(1) && beta == T(0)) {
- c.device(d) = a.contract(b, contract_pairs);
- } else if (alpha == T(1) && beta == T(1)) {
- c.device(d) += a.contract(b, contract_pairs);
- } else {
- c.device(d) = c.constant(alpha) * a.contract(b, contract_pairs) +
- c.constant(beta) * c;
- }
- }
-};
struct LSTMBlockCell {
LSTMBlockCell(const int batch_size, const int input_size, const int cell_size)