author    | A. Unique TensorFlower <gardener@tensorflow.org> | 2016-09-29 11:59:42 -0800
committer | TensorFlower Gardener <gardener@tensorflow.org>  | 2016-09-29 13:08:49 -0700
commit    | f6157b1d151ffcf758a9e0299cb6b2225e244229 (patch)
tree      | 65f6b7484f5b6a99b0764a8000a86f190b763970 /tensorflow/contrib/rnn/kernels
parent    | c6f15e7a5469895e49e1da675eaec714b4dc0cce (diff)
Trivial change in tf.contrib.rnn: Remove code duplication
Change: 134697206
Diffstat (limited to 'tensorflow/contrib/rnn/kernels')
-rw-r--r-- | tensorflow/contrib/rnn/kernels/blas_gemm.cc | 69
-rw-r--r-- | tensorflow/contrib/rnn/kernels/blas_gemm.h  | 86
-rw-r--r-- | tensorflow/contrib/rnn/kernels/gru_ops.cc   | 46
-rw-r--r-- | tensorflow/contrib/rnn/kernels/gru_ops.h    | 50
-rw-r--r-- | tensorflow/contrib/rnn/kernels/lstm_ops.cc  | 45
-rw-r--r-- | tensorflow/contrib/rnn/kernels/lstm_ops.h   | 49
6 files changed, 157 insertions, 188 deletions
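The change replaces two private copies of the cuBLAS/Eigen matmul helpers with a single shared implementation: `TensorCuBlasGemm` (the raw cuBLAS call) and `TensorBlasGemm` (the `USE_CUBLAS`-templated dispatcher) move out of `gru_ops` and `lstm_ops` into the new `blas_gemm.{h,cc}` pair, and both cell headers now simply include `blas_gemm.h`. This also resolves the TODO(gitegaurav) comment deleted from `gru_ops.cc` below. When `USE_CUBLAS` is false, `compute()` falls back to an Eigen tensor contraction; here is a standalone sketch of that fallback's alpha/beta logic in plain Eigen (the shapes and values are made up for illustration, and `Eigen::Tensor` stands in for TensorFlow's `TTypes<T>::Matrix` maps):

```cpp
#include <unsupported/Eigen/CXX11/Tensor>

#include <iostream>

int main() {
  // a: 2x3, b: 3x2, c: 2x2, mirroring compute()'s a, b and c arguments.
  Eigen::Tensor<float, 2> a(2, 3), b(3, 2), c(2, 2);
  a.setConstant(1.0f);
  b.setConstant(2.0f);
  c.setConstant(5.0f);

  // With transa == false and transb == false, the patch's
  // IndexPair(transa == false, transb == true) is IndexPair(1, 0):
  // contract a's columns with b's rows, i.e. an ordinary matmul.
  Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_pairs;
  contract_pairs[0] = Eigen::IndexPair<Eigen::DenseIndex>(1, 0);

  // The general branch of the patch: c = alpha * (a * b) + beta * c.
  const float alpha = 0.5f;
  const float beta = 1.0f;
  c = c.constant(alpha) * a.contract(b, contract_pairs) +
      c.constant(beta) * c;

  std::cout << c << std::endl;  // every entry: 0.5 * 6 + 1.0 * 5 = 8
  return 0;
}
```

The `alpha == 1 && beta == 0` and `alpha == 1 && beta == 1` branches in the real code exist so that the common cases skip the two extra elementwise passes over `c`. The full diff follows.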
```diff
diff --git a/tensorflow/contrib/rnn/kernels/blas_gemm.cc b/tensorflow/contrib/rnn/kernels/blas_gemm.cc
new file mode 100644
index 0000000000..637b872dad
--- /dev/null
+++ b/tensorflow/contrib/rnn/kernels/blas_gemm.cc
@@ -0,0 +1,69 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#define EIGEN_USE_THREADS
+
+#if GOOGLE_CUDA
+#include "tensorflow/core/platform/stream_executor.h"
+#endif  // GOOGLE_CUDA
+
+#include "tensorflow/contrib/rnn/kernels/blas_gemm.h"
+#include "tensorflow/core/framework/op_kernel.h"
+namespace tensorflow {
+
+#if GOOGLE_CUDA
+namespace {
+template <typename T>
+perftools::gputools::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory) {
+  perftools::gputools::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory));
+  perftools::gputools::DeviceMemory<T> typed(wrapped);
+  return typed;
+}
+}  // namespace
+#endif  // GOOGLE_CUDA
+
+namespace functor {
+template <typename T>
+void TensorCuBlasGemm<T>::operator()(OpKernelContext* ctx,
+                                     perftools::gputools::Stream* stream,
+                                     bool transa, bool transb, uint64 m,
+                                     uint64 n, uint64 k, T alpha, const T* a,
+                                     int lda, const T* b, int ldb, T beta, T* c,
+                                     int ldc) {
+#if GOOGLE_CUDA
+  perftools::gputools::blas::Transpose trans[] = {
+      perftools::gputools::blas::Transpose::kNoTranspose,
+      perftools::gputools::blas::Transpose::kTranspose};
+
+  auto a_ptr = AsDeviceMemory(a);
+  auto b_ptr = AsDeviceMemory(b);
+  auto c_ptr = AsDeviceMemory(c);
+
+  bool blas_launch_status =
+      stream
+          ->ThenBlasGemm(trans[transa], trans[transb], m, n, k, alpha, a_ptr,
+                         lda, b_ptr, ldb, beta, &c_ptr, ldc)
+          .ok();
+  OP_REQUIRES(ctx, blas_launch_status, errors::Aborted("CuBlasGemm failed!"));
+#else
+  ctx->SetStatus(errors::InvalidArgument("CuBlasGemm needs CUDA."));
+#endif
+}
+
+template struct TensorCuBlasGemm<float>;
+template struct TensorCuBlasGemm<double>;
+
+}  // end namespace functor
+}  // end namespace tensorflow
diff --git a/tensorflow/contrib/rnn/kernels/blas_gemm.h b/tensorflow/contrib/rnn/kernels/blas_gemm.h
new file mode 100644
index 0000000000..9c34b8ae71
--- /dev/null
+++ b/tensorflow/contrib/rnn/kernels/blas_gemm.h
@@ -0,0 +1,86 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_RNN_KERNELS_BLAS_GEMM_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_RNN_KERNELS_BLAS_GEMM_H_
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/kernels/eigen_activations.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace perftools {
+namespace gputools {
+class Stream;
+}  // end namespace gputools
+}  // end namespace perftools
+
+namespace tensorflow {
+class OpKernelContext;
+namespace functor {
+
+template <typename T>
+struct TensorCuBlasGemm {
+  void operator()(OpKernelContext* ctx, perftools::gputools::Stream* stream,
+                  bool transa, bool transb, uint64 m, uint64 n, uint64 k,
+                  T alpha, const T* a, int lda, const T* b, int ldb, T beta,
+                  T* c, int ldc);
+};
+
+template <typename Device, typename T, bool USE_CUBLAS>
+struct TensorBlasGemm;
+
+template <typename Device, typename T>
+struct TensorBlasGemm<Device, T, true /* USE_CUBLAS */> {
+  static void compute(OpKernelContext* ctx, perftools::gputools::Stream* stream,
+                      const Device& d, bool transa, bool transb, T alpha,
+                      typename TTypes<T>::ConstMatrix a,
+                      typename TTypes<T>::ConstMatrix b, T beta,
+                      typename TTypes<T>::Matrix c) {
+    int64 m = c.dimensions()[0];
+    int64 n = c.dimensions()[1];
+    int64 k = transa ? a.dimensions()[0] : a.dimensions()[1];
+
+    TensorCuBlasGemm<T>()(ctx, stream, transb, transa, n, m, k, alpha, b.data(),
+                          transb ? k : n, a.data(), transa ? m : k, beta,
+                          c.data(), n);
+  }
+};
+
+template <typename Device, typename T>
+struct TensorBlasGemm<Device, T, false /* USE_CUBLAS */> {
+  static void compute(OpKernelContext* ctx, perftools::gputools::Stream* stream,
+                      const Device& d, bool transa, bool transb, T alpha,
+                      typename TTypes<T>::ConstMatrix a,
+                      typename TTypes<T>::ConstMatrix b, T beta,
+                      typename TTypes<T>::Matrix c) {
+    Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_pairs;
+    contract_pairs[0] =
+        Eigen::IndexPair<Eigen::DenseIndex>(transa == false, transb == true);
+    if (alpha == T(1) && beta == T(0)) {
+      c.device(d) = a.contract(b, contract_pairs);
+    } else if (alpha == T(1) && beta == T(1)) {
+      c.device(d) += a.contract(b, contract_pairs);
+    } else {
+      c.device(d) = c.constant(alpha) * a.contract(b, contract_pairs) +
+                    c.constant(beta) * c;
+    }
+  }
+};
+
+}  // namespace functor
+}  // namespace tensorflow
+
+#endif  // THIRD_PARTY_TENSORFLOW_CONTRIB_RNN_KERNELS_BLAS_GEMM_H_
diff --git a/tensorflow/contrib/rnn/kernels/gru_ops.cc b/tensorflow/contrib/rnn/kernels/gru_ops.cc
index 89a4817111..ae25322a40 100644
--- a/tensorflow/contrib/rnn/kernels/gru_ops.cc
+++ b/tensorflow/contrib/rnn/kernels/gru_ops.cc
@@ -28,52 +28,6 @@ namespace tensorflow {
 typedef Eigen::ThreadPoolDevice CPUDevice;
 typedef Eigen::GpuDevice GPUDevice;
 
-#if GOOGLE_CUDA
-namespace {
-template <typename T>
-perftools::gputools::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory) {
-  perftools::gputools::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory));
-  perftools::gputools::DeviceMemory<T> typed(wrapped);
-  return typed;
-}
-}  // namespace
-#endif  // GOOGLE_CUDA
-
-namespace functor {
-template <typename T>
-// TODO(gitegaurav) : Refactor the matmul operation inside the kernel. Make
-// similar changes in the LSTMBlockCell. Create a new file which contains matmul
-// functionality. It should perform matmul operation using CuBlas when the Cuda
-// support is present otherwise using eigentensors.
-void TensorCuBlasGemm<T>::operator()(OpKernelContext* ctx,
-                                     perftools::gputools::Stream* stream,
-                                     bool transa, bool transb, uint64 m,
-                                     uint64 n, uint64 k, T alpha, const T* a,
-                                     int lda, const T* b, int ldb, T beta, T* c,
-                                     int ldc) {
-#if GOOGLE_CUDA
-  perftools::gputools::blas::Transpose trans[] = {
-      perftools::gputools::blas::Transpose::kNoTranspose,
-      perftools::gputools::blas::Transpose::kTranspose};
-
-  auto a_ptr = AsDeviceMemory(a);
-  auto b_ptr = AsDeviceMemory(b);
-  auto c_ptr = AsDeviceMemory(c);
-
-  bool blas_launch_status =
-      stream
-          ->ThenBlasGemm(trans[transa], trans[transb], m, n, k, alpha, a_ptr,
-                         lda, b_ptr, ldb, beta, &c_ptr, ldc)
-          .ok();
-  OP_REQUIRES(ctx, blas_launch_status, errors::Aborted("CuBlasGemm failed!"));
-#else
-  ctx->SetStatus(errors::InvalidArgument("CuBlasGemm needs CUDA."));
-#endif
-}
-
-template struct TensorCuBlasGemm<float>;
-}  // end namespace functor
-
 template <typename Device, typename T, bool USE_CUBLAS>
 class GRUCellBlockOp : public OpKernel {
  public:
diff --git a/tensorflow/contrib/rnn/kernels/gru_ops.h b/tensorflow/contrib/rnn/kernels/gru_ops.h
index b2ff859b8e..e6c4ad9a03 100644
--- a/tensorflow/contrib/rnn/kernels/gru_ops.h
+++ b/tensorflow/contrib/rnn/kernels/gru_ops.h
@@ -17,6 +17,7 @@ limitations under the License.
 #define THIRD_PARTY_TENSORFLOW_CONTRIB_RNN_KERNELS_GRU_OPS_H_
 
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/contrib/rnn/kernels/blas_gemm.h"
 #include "tensorflow/core/framework/tensor_types.h"
 #include "tensorflow/core/platform/types.h"
 
@@ -32,55 +33,6 @@ class OpKernelContext;
 
 namespace functor {
 
-template <typename T>
-struct TensorCuBlasGemm {
-  void operator()(OpKernelContext* ctx, perftools::gputools::Stream* stream,
-                  bool transa, bool transb, uint64 m, uint64 n, uint64 k,
-                  T alpha, const T* a, int lda, const T* b, int ldb, T beta,
-                  T* c, int ldc);
-};
-
-template <typename Device, typename T, bool USE_CUBLAS>
-struct TensorBlasGemm;
-
-template <typename Device, typename T>
-struct TensorBlasGemm<Device, T, true /* USE_CUBLAS */> {
-  static void compute(OpKernelContext* ctx, perftools::gputools::Stream* stream,
-                      const Device& d, bool transa, bool transb, T alpha,
-                      typename TTypes<T>::ConstMatrix a,
-                      typename TTypes<T>::ConstMatrix b, T beta,
-                      typename TTypes<T>::Matrix c) {
-    int64 m = c.dimensions()[0];
-    int64 n = c.dimensions()[1];
-    int64 k = transa ? a.dimensions()[0] : a.dimensions()[1];
-
-    TensorCuBlasGemm<T>()(ctx, stream, transb, transa, n, m, k, alpha, b.data(),
-                          transb ? k : n, a.data(), transa ? m : k, beta,
-                          c.data(), n);
-  }
-};
-
-template <typename Device, typename T>
-struct TensorBlasGemm<Device, T, false /* USE_CUBLAS */> {
-  static void compute(OpKernelContext* ctx, perftools::gputools::Stream* stream,
-                      const Device& d, bool transa, bool transb, T alpha,
-                      typename TTypes<T>::ConstMatrix a,
-                      typename TTypes<T>::ConstMatrix b, T beta,
-                      typename TTypes<T>::Matrix c) {
-    Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_pairs;
-    contract_pairs[0] =
-        Eigen::IndexPair<Eigen::DenseIndex>(transa == false, transb == true);
-    if (alpha == T(1) && beta == T(0)) {
-      c.device(d) = a.contract(b, contract_pairs);
-    } else if (alpha == T(1) && beta == T(1)) {
-      c.device(d) += a.contract(b, contract_pairs);
-    } else {
-      c.device(d) = c.constant(alpha) * a.contract(b, contract_pairs) +
-                    c.constant(beta) * c;
-    }
-  }
-};
-
 struct GRUCell {
   GRUCell(const int batch_size, const int input_size, const int cell_size)
       : batch_size_(batch_size),
diff --git a/tensorflow/contrib/rnn/kernels/lstm_ops.cc b/tensorflow/contrib/rnn/kernels/lstm_ops.cc
index 7bfc119e2c..2749d7d797 100644
--- a/tensorflow/contrib/rnn/kernels/lstm_ops.cc
+++ b/tensorflow/contrib/rnn/kernels/lstm_ops.cc
@@ -43,51 +43,6 @@ namespace tensorflow {
 typedef Eigen::ThreadPoolDevice CPUDevice;
 typedef Eigen::GpuDevice GPUDevice;
 
-#if GOOGLE_CUDA
-
-namespace {
-template <typename T>
-perftools::gputools::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory) {
-  perftools::gputools::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory));
-  perftools::gputools::DeviceMemory<T> typed(wrapped);
-  return typed;
-}
-}  // namespace
-
-#endif  // GOOGLE_CUDA
-
-namespace functor {
-template <typename T>
-void TensorCuBlasGemm<T>::operator()(OpKernelContext* ctx,
-                                     perftools::gputools::Stream* stream,
-                                     bool transa, bool transb, uint64 m,
-                                     uint64 n, uint64 k, T alpha, const T* a,
-                                     int lda, const T* b, int ldb, T beta, T* c,
-                                     int ldc) {
-#if GOOGLE_CUDA
-  perftools::gputools::blas::Transpose trans[] = {
-      perftools::gputools::blas::Transpose::kNoTranspose,
-      perftools::gputools::blas::Transpose::kTranspose};
-
-  auto a_ptr = AsDeviceMemory(a);
-  auto b_ptr = AsDeviceMemory(b);
-  auto c_ptr = AsDeviceMemory(c);
-
-  bool blas_launch_status =
-      stream
-          ->ThenBlasGemm(trans[transa], trans[transb], m, n, k, alpha, a_ptr,
-                         lda, b_ptr, ldb, beta, &c_ptr, ldc)
-          .ok();
-  OP_REQUIRES(ctx, blas_launch_status, errors::Aborted("CuBlasGemm failed!"));
-#else
-  ctx->SetStatus(errors::InvalidArgument("CuBlasGemm needs CUDA."));
-#endif
-}
-
-template struct TensorCuBlasGemm<float>;
-// template struct TensorCuBlasGemm<double>;
-}  // end namespace functor
-
 template <typename Device, typename T, bool USE_CUBLAS>
 class LSTMBlockCellOp : public OpKernel {
  public:
diff --git a/tensorflow/contrib/rnn/kernels/lstm_ops.h b/tensorflow/contrib/rnn/kernels/lstm_ops.h
index 15bc5d89d2..5a9dda5755 100644
--- a/tensorflow/contrib/rnn/kernels/lstm_ops.h
+++ b/tensorflow/contrib/rnn/kernels/lstm_ops.h
@@ -17,6 +17,7 @@ limitations under the License.
 #define THIRD_PARTY_TENSORFLOW_CONTRIB_RNN_KERNELS_LSTM_OPS_H_
 
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/contrib/rnn/kernels/blas_gemm.h"
 #include "tensorflow/core/framework/tensor_types.h"
 #include "tensorflow/core/kernels/eigen_activations.h"
 #include "tensorflow/core/platform/types.h"
@@ -74,54 +75,6 @@ struct TensorZeroPadding {
   }
 };
 
-template <typename T>
-struct TensorCuBlasGemm {
-  void operator()(OpKernelContext* ctx, perftools::gputools::Stream* stream,
-                  bool transa, bool transb, uint64 m, uint64 n, uint64 k,
-                  T alpha, const T* a, int lda, const T* b, int ldb, T beta,
-                  T* c, int ldc);
-};
-
-template <typename Device, typename T, bool USE_CUBLAS>
-struct TensorBlasGemm;
-
-template <typename Device, typename T>
-struct TensorBlasGemm<Device, T, true /* USE_CUBLAS */> {
-  static void compute(OpKernelContext* ctx, perftools::gputools::Stream* stream,
-                      const Device& d, bool transa, bool transb, T alpha,
-                      typename TTypes<T>::ConstMatrix a,
-                      typename TTypes<T>::ConstMatrix b, T beta,
-                      typename TTypes<T>::Matrix c) {
-    int64 m = c.dimensions()[0];
-    int64 n = c.dimensions()[1];
-    int64 k = transa ? a.dimensions()[0] : a.dimensions()[1];
-
-    TensorCuBlasGemm<T>()(ctx, stream, transb, transa, n, m, k, alpha, b.data(),
-                          transb ? k : n, a.data(), transa ? m : k, beta,
-                          c.data(), n);
-  }
-};
-
-template <typename Device, typename T>
-struct TensorBlasGemm<Device, T, false /* USE_CUBLAS */> {
-  static void compute(OpKernelContext* ctx, perftools::gputools::Stream* stream,
-                      const Device& d, bool transa, bool transb, T alpha,
-                      typename TTypes<T>::ConstMatrix a,
-                      typename TTypes<T>::ConstMatrix b, T beta,
-                      typename TTypes<T>::Matrix c) {
-    Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_pairs;
-    contract_pairs[0] =
-        Eigen::IndexPair<Eigen::DenseIndex>(transa == false, transb == true);
-    if (alpha == T(1) && beta == T(0)) {
-      c.device(d) = a.contract(b, contract_pairs);
-    } else if (alpha == T(1) && beta == T(1)) {
-      c.device(d) += a.contract(b, contract_pairs);
-    } else {
-      c.device(d) = c.constant(alpha) * a.contract(b, contract_pairs) +
-                    c.constant(beta) * c;
-    }
-  }
-};
 struct LSTMBlockCell {
   LSTMBlockCell(const int batch_size, const int input_size,
                 const int cell_size)
```
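A subtlety the consolidation preserves: the `true /* USE_CUBLAS */` specialization of `TensorBlasGemm::compute` deliberately calls `TensorCuBlasGemm` with everything swapped, passing `(transb, transa, n, m, k)`, the operands in the order `b, a`, and leading dimensions `(transb ? k : n, transa ? m : k, n)`. BLAS is column-major while TensorFlow matrices are row-major, and a row-major M x N buffer read column-major is exactly the N x M transpose, so the call computes C^T = B^T * A^T and the memory layout turns that back into C for free. A minimal self-contained check of that identity (the naive `gemm_colmajor` below is a hypothetical stand-in for the real `stream->ThenBlasGemm(...)` call, not an API from the patch):

```cpp
#include <cstdio>

// Naive column-major GEMM, no transposes: C (m x n) = A (m x k) * B (k x n).
// Hypothetical stand-in for cuBLAS / ThenBlasGemm.
void gemm_colmajor(int m, int n, int k, const float* a, int lda,
                   const float* b, int ldb, float* c, int ldc) {
  for (int j = 0; j < n; ++j) {
    for (int i = 0; i < m; ++i) {
      float sum = 0.0f;
      for (int p = 0; p < k; ++p) sum += a[i + p * lda] * b[p + j * ldb];
      c[i + j * ldc] = sum;
    }
  }
}

int main() {
  // Row-major A (2x3) and B (3x2); we want row-major C (2x2) = A * B.
  const int m = 2, n = 2, k = 3;
  const float a[] = {1, 2, 3,
                     4, 5, 6};  // A, row-major
  const float b[] = {1, 0,
                     0, 1,
                     1, 1};     // B, row-major
  float c[4] = {0};

  // The swap from the patch: dimensions (n, m, k), operands (b, a), leading
  // dimensions (n, k, n). Read column-major, the row-major buffers are B^T
  // and A^T, so this computes C^T = B^T * A^T -- whose column-major layout
  // is exactly row-major C.
  gemm_colmajor(n, m, k, b, n, a, k, c, n);

  for (float v : c) std::printf("%g ", v);  // prints: 4 5 10 11
  std::printf("\n");
  return 0;
}
```

The Eigen fallback needs no such trick, since the contraction is expressed over logical indices rather than a storage convention.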