diff options
-rw-r--r-- | tensorflow/core/kernels/BUILD | 4 | ||||
-rw-r--r-- | tensorflow/core/kernels/cholesky_op.cc | 23 | ||||
-rw-r--r-- | tensorflow/core/kernels/matrix_band_part_op.cc | 167 | ||||
-rw-r--r-- | tensorflow/core/kernels/matrix_band_part_op.h | 53 | ||||
-rw-r--r-- | tensorflow/core/kernels/matrix_band_part_op_gpu.cu.cc | 80 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/matrix_band_part_op_test.py | 120 |
6 files changed, 318 insertions, 129 deletions
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 9a1e96a131..8bcbe0dd41 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -665,7 +665,9 @@ tf_kernel_library( tf_kernel_library( name = "matrix_band_part_op", prefix = "matrix_band_part_op", - deps = ARRAY_DEPS, + deps = if_cuda([ + ":cuda_solvers", + ]) + ARRAY_DEPS, ) tf_kernel_library( diff --git a/tensorflow/core/kernels/cholesky_op.cc b/tensorflow/core/kernels/cholesky_op.cc index 755ce7c43b..6668b0d654 100644 --- a/tensorflow/core/kernels/cholesky_op.cc +++ b/tensorflow/core/kernels/cholesky_op.cc @@ -76,18 +76,19 @@ class CholeskyOp : public LinearAlgebraOp<Scalar> { typedef Eigen::GpuDevice GPUDevice; namespace functor { -#define DECLARE_GPU_SPEC(T) \ - template <> \ - void MatrixBandPart<GPUDevice, T>::Compute( \ - const GPUDevice& d, Eigen::DenseIndex num_lower, \ - Eigen::DenseIndex num_upper, typename TTypes<T, 3>::ConstTensor input, \ - typename TTypes<T, 3>::Tensor output); \ - extern template struct MatrixBandPart<GPUDevice, T>; +#define DECLARE_GPU_SPEC(T) \ + template <> \ + struct MatrixBandPartFunctor<GPUDevice, T> { \ + void operator()(OpKernelContext* context, const GPUDevice& device, \ + int num_lower_diags, int num_upper_diags, bool transpose, \ + typename TTypes<T, 3>::ConstTensor input, \ + typename TTypes<T, 3>::Tensor output); \ + }; \ + extern template struct MatrixBandPartFunctor<GPUDevice, T>; TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC); TF_CALL_complex64(DECLARE_GPU_SPEC); TF_CALL_complex128(DECLARE_GPU_SPEC); - } // namespace functor template <class Scalar> @@ -131,9 +132,9 @@ class CholeskyOpGpu : public AsyncOpKernel { // before we launch each of the Cholesky factorization kernels in parallel. 
auto input_reshaped = input.template flat_inner_dims<Scalar, 3>(); auto output_reshaped = output->template flat_inner_dims<Scalar, 3>(); - functor::MatrixBandPart<GPUDevice, Scalar>::Compute( - context->eigen_device<GPUDevice>(), n, 0, input_reshaped, - output_reshaped); + functor::MatrixBandPartFunctor<GPUDevice, Scalar> fn; + fn(context, context->eigen_device<GPUDevice>(), n, 0, false /* transpose */, + input_reshaped, output_reshaped); // Launch a Cholesky kernel for each matrix in the batch. const int64 batch_size = input_reshaped.dimension(0); diff --git a/tensorflow/core/kernels/matrix_band_part_op.cc b/tensorflow/core/kernels/matrix_band_part_op.cc index 894b0113c2..8b8accc0b3 100644 --- a/tensorflow/core/kernels/matrix_band_part_op.cc +++ b/tensorflow/core/kernels/matrix_band_part_op.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/core/kernels/matrix_band_part_op.h" +#include <algorithm> #include <memory> #include <vector> #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" @@ -32,6 +33,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" @@ -48,31 +50,50 @@ class MatrixBandPartOp : public OpKernel { void Compute(OpKernelContext* context) override { const Tensor& input = context->input(0); + const TensorShape& input_shape = input.shape(); + // Preliminary validation of sizes. 
+ OP_REQUIRES(context, TensorShapeUtils::IsMatrixOrHigher(input_shape), + errors::InvalidArgument( + "input must be at least 2-dim, received shape: ", + input.shape().DebugString())); + auto input_reshaped = input.flat_inner_dims<T, 3>(); + const Tensor& num_lower_in = context->input(1); OP_REQUIRES(context, TensorShapeUtils::IsScalar(num_lower_in.shape()), errors::InvalidArgument("num_lower must be scalar, got shape ", num_lower_in.shape().DebugString())); const int64 num_lower = num_lower_in.scalar<int64>()(); + OP_REQUIRES( + context, num_lower <= input_reshaped.dimension(1), + errors::InvalidArgument( + "num_lower must be negative or less or equal to number of rows (", + input_reshaped.dimension(1), ") got: ", num_lower)); const Tensor& num_upper_in = context->input(2); OP_REQUIRES(context, TensorShapeUtils::IsScalar(num_upper_in.shape()), errors::InvalidArgument("num_upper must be scalar, got shape ", num_upper_in.shape().DebugString())); const int64 num_upper = num_upper_in.scalar<int64>()(); + OP_REQUIRES(context, num_upper <= input_reshaped.dimension(2), + errors::InvalidArgument("num_upper must be negative or less or " + "equal to number of columns (", + input_reshaped.dimension(2), + ") got: ", num_upper)); + + if ((num_lower < 0 || num_lower == input_reshaped.dimension(1)) && + (num_upper < 0 || num_upper == input_reshaped.dimension(2))) { + // This is a no-op. + context->set_output(0, input); + return; + } - const TensorShape& input_shape = input.shape(); - // Preliminary validation of sizes. 
- OP_REQUIRES(context, TensorShapeUtils::IsMatrixOrHigher(input_shape), - errors::InvalidArgument( - "input must be at least 2-dim, received shape: ", - input.shape().DebugString())); - auto input_reshaped = input.flat_inner_dims<T, 3>(); Tensor* output = nullptr; - OP_REQUIRES_OK(context, context->allocate_output(0, input_shape, &output)); + OP_REQUIRES_OK(context, context->forward_input_or_allocate_output( + {0}, 0, input_shape, &output)); auto output_reshaped = output->flat_inner_dims<T, 3>(); - functor::MatrixBandPart<Device, T>::Compute( - context->eigen_device<Device>(), num_lower, num_upper, input_reshaped, - output_reshaped); + functor::MatrixBandPartFunctor<Device, T> fn; + fn(context, context->eigen_device<Device>(), num_lower, num_upper, + false /* transpose */, input_reshaped, output_reshaped); } private: @@ -98,54 +119,118 @@ TF_CALL_NUMBER_TYPES(REGISTER_BATCH_MATRIX_BAND_PART); // Implementation of the functor specialization for CPU. namespace functor { -template <typename T> -struct MatrixBandPart<CPUDevice, T> { - static void Compute(const CPUDevice& d, int64 num_lower, int64 num_upper, - typename TTypes<T, 3>::ConstTensor input, - typename TTypes<T, 3>::Tensor output) { - if ((num_lower < 0 || num_lower >= input.dimension(1)) && - (num_upper < 0 || num_upper >= input.dimension(2))) { - output.device(d) = input; + +// CPU implementation of BandPartFunctor. 
+typedef Eigen::ThreadPoolDevice CPUDevice; + +template <typename Scalar> +struct MatrixBandPartFunctor<CPUDevice, Scalar> { + void operator()(OpKernelContext* context, const CPUDevice& device, + int num_lower_diags, int num_upper_diags, bool transpose, + typename TTypes<Scalar, 3>::ConstTensor input, + typename TTypes<Scalar, 3>::Tensor output) { + const int64 b = input.dimension(0); + const int64 m = input.dimension(1); + const int64 n = input.dimension(2); + auto thread_pool = + context->device()->tensorflow_cpu_worker_threads()->workers; + const int64 total_rows = b * m; + const int64 row_cost = 10 * n; + const bool in_place = input.data() == output.data(); + CHECK(!(transpose && in_place)); + if (!transpose) { + auto compute_shard = [=, &input, &output](int64 begin, int64 end) { + if (!in_place) { + std::fill(output.data() + begin * n, output.data() + end * n, + Scalar()); + } + const int64 batch_begin = begin / m; + const int64 batch_end = (end + m - 1) / m; + for (int64 batch = batch_begin; batch < batch_end; ++batch) { + const int64 row_begin = begin > batch * m ? begin % m : 0; + const int64 row_end = end < (batch + 1) * m ? end % m : m; + for (int64 row = row_begin; row < row_end; ++row) { + const int64 band_start = + num_lower_diags < 0 + ? 0 + : std::min(n, std::max(0ll, row - num_lower_diags)); + const int64 band_end = num_upper_diags < 0 + ? 
n + : std::min(static_cast<int64>(n), + row + num_upper_diags + 1); + if (in_place) { + if (band_start > 0) { + std::fill(&output(batch, row, 0), + &output(batch, row, band_start), Scalar()); + } + if (band_end < n) { + std::fill(&output(batch, row, band_end), &output(batch, row, n), + Scalar()); + } + } else { + if (band_start < band_end) { + const Eigen::DSizes<Eigen::DenseIndex, 3> indices(batch, row, + band_start); + const Eigen::DSizes<Eigen::DenseIndex, 3> sizes( + 1, 1, band_end - band_start); + output.slice(indices, sizes) = input.slice(indices, sizes); + } + } + } + } + }; + thread_pool->ParallelFor(total_rows, row_cost, std::move(compute_shard)); } else { - output.device(d) = output.constant(T()); - for (int64 r = 0; r < output.dimension(0); ++r) { - for (int64 i = 0; i < output.dimension(1); ++i) { - const int64 band_start = - num_lower < 0 ? 0 : std::max(0ll, i - num_lower); - const int64 band_end = - num_upper < 0 ? output.dimension(2) - : std::min(static_cast<int64>(output.dimension(2)), - i + num_upper + 1); - if (band_start < band_end) { - const Eigen::DSizes<Eigen::DenseIndex, 3> indices(r, i, band_start); - const Eigen::DSizes<Eigen::DenseIndex, 3> sizes( - 1, 1, band_end - band_start); - output.slice(indices, sizes) = input.slice(indices, sizes); + output.device(device) = output.constant(Scalar()); + auto compute_shard = [=, &input, &output](int64 begin, int64 end) { + const int64 batch_begin = begin / m; + const int64 batch_end = (end + m - 1) / m; + for (int64 batch = batch_begin; batch < batch_end; ++batch) { + const int64 row_begin = begin > batch * m ? begin % m : 0; + const int64 row_end = end < (batch + 1) * m ? end % m : m; + for (int64 row = row_begin; row < row_end; ++row) { + const int64 band_start = + num_lower_diags < 0 ? 0 : std::max(0ll, row - num_lower_diags); + const int64 band_end = num_upper_diags < 0 + ? 
n + : std::min(static_cast<int64>(n), + row + num_upper_diags + 1); + for (int64 col = band_start; col < band_end; ++col) { + output(batch, col, row) = input(batch, row, col); + } } } - } + }; + thread_pool->ParallelFor(total_rows, row_cost, std::move(compute_shard)); } } }; +#define DEFINE_CPU_SPEC(T) template struct MatrixBandPartFunctor<CPUDevice, T>; +TF_CALL_POD_TYPES(DEFINE_CPU_SPEC); +#undef DEFINE_CPU_SPEC + } // namespace functor #if GOOGLE_CUDA // Forward declarations of the functor specializations for GPU. namespace functor { -#define DECLARE_GPU_SPEC(T) \ - template <> \ - void MatrixBandPart<GPUDevice, T>::Compute( \ - const GPUDevice& d, Eigen::DenseIndex num_lower, \ - Eigen::DenseIndex num_upper, typename TTypes<T, 3>::ConstTensor input, \ - typename TTypes<T, 3>::Tensor output); \ - extern template struct MatrixBandPart<GPUDevice, T>; +#define DECLARE_GPU_SPEC(T) \ + template <> \ + struct MatrixBandPartFunctor<GPUDevice, T> { \ + void operator()(OpKernelContext* context, const GPUDevice& device, \ + int num_lower_diags, int num_upper_diags, bool transpose, \ + typename TTypes<T, 3>::ConstTensor input, \ + typename TTypes<T, 3>::Tensor output); \ + }; \ + extern template struct MatrixBandPartFunctor<GPUDevice, T>; TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC); TF_CALL_bool(DECLARE_GPU_SPEC); TF_CALL_complex64(DECLARE_GPU_SPEC); TF_CALL_complex128(DECLARE_GPU_SPEC); +#undef DECLARE_GPU_SPEC } // namespace functor // Registration of the GPU implementations. diff --git a/tensorflow/core/kernels/matrix_band_part_op.h b/tensorflow/core/kernels/matrix_band_part_op.h index b601255b25..43b6724dae 100644 --- a/tensorflow/core/kernels/matrix_band_part_op.h +++ b/tensorflow/core/kernels/matrix_band_part_op.h @@ -16,61 +16,22 @@ limitations under the License. #ifndef TENSORFLOW_KERNELS_MATRIX_DIAG_OP_H_ #define TENSORFLOW_KERNELS_MATRIX_DIAG_OP_H_ -// Generator definition for MatrixBandPartOp, must be compilable by nvcc. 
- -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { - -namespace generator { - -template <typename T> -class MatrixBandPartGenerator { - public: - EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE MatrixBandPartGenerator( - Eigen::DenseIndex num_lower, Eigen::DenseIndex num_upper, - typename TTypes<T, 3>::ConstTensor input) - : num_lower_(num_lower), num_upper_(num_upper), input_(input) {} - - EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T - operator()(const Eigen::array<Eigen::DenseIndex, 3>& coords) const { - return (((num_lower_ < 0 || coords[1] - coords[2] <= num_lower_) && - (num_upper_ < 0 || coords[2] - coords[1] <= num_upper_)) - ? input_(coords) - : T()); - } - - private: - const Eigen::DenseIndex num_lower_; - const Eigen::DenseIndex num_upper_; - typename TTypes<T, 3>::ConstTensor input_; -}; - -} // namespace generator - namespace functor { -template <typename Device, typename T> -struct MatrixBandPart { - EIGEN_ALWAYS_INLINE static void Compute( - const Device& d, Eigen::DenseIndex num_lower, Eigen::DenseIndex num_upper, - typename TTypes<T, 3>::ConstTensor input, - typename TTypes<T, 3>::Tensor output) { - if ((num_lower < 0 || num_lower >= input.dimension(1)) && - (num_upper < 0 || num_upper >= input.dimension(2))) { - output.device(d) = input; - } else { - generator::MatrixBandPartGenerator<T> generator(num_lower, num_upper, - input); - output.device(d) = output.generate(generator); - } - } +template <typename Device, typename Scalar> +struct MatrixBandPartFunctor { + void operator()(OpKernelContext* context, const Device& device, + int num_lower_diags, int num_upper_diags, bool transpose, + typename TTypes<Scalar, 3>::ConstTensor input, + typename TTypes<Scalar, 3>::Tensor output); }; } // namespace functor - } // namespace tensorflow #endif // TENSORFLOW_KERNELS_MATRIX_DIAG_OP_H_ diff 
--git a/tensorflow/core/kernels/matrix_band_part_op_gpu.cu.cc b/tensorflow/core/kernels/matrix_band_part_op_gpu.cu.cc index ccc10ebada..afebdacdca 100644 --- a/tensorflow/core/kernels/matrix_band_part_op_gpu.cu.cc +++ b/tensorflow/core/kernels/matrix_band_part_op_gpu.cu.cc @@ -17,22 +17,92 @@ limitations under the License. #define EIGEN_USE_GPU +#include <complex> +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/kernels/cuda_solvers.h" #include "tensorflow/core/kernels/matrix_band_part_op.h" +#include "tensorflow/core/util/cuda_kernel_helper.h" namespace tensorflow { - +namespace functor { typedef Eigen::GpuDevice GPUDevice; -#define DEFINE_GPU_SPEC(T) \ - template class generator::MatrixBandPartGenerator<T>; \ - template struct functor::MatrixBandPart<GPUDevice, T>; +template <bool transpose, typename Scalar> +__global__ void MatrixBandPartKernel(const int num_threads, + const int batch_size, const int m, + const int n, const int num_lower_diags, + const int num_upper_diags, + const Scalar* input_ptr, + Scalar* output_ptr) { + if (!transpose) { + CUDA_1D_KERNEL_LOOP(index, num_threads) { + const int col = index % n; + const int row = (index / n) % m; + const int band_start = (num_lower_diags < 0 ? 0 : row - num_lower_diags); + const int band_end = + (num_upper_diags < 0 ? n : row + num_upper_diags + 1); + if (col < band_start || col >= band_end) { + output_ptr[index] = Scalar(); + } else { + output_ptr[index] = input_ptr[index]; + } + } + } else { + const int matrix_size = m * n; + CUDA_1D_KERNEL_LOOP(index, num_threads) { + const int col = index % n; + const int row = (index / n) % m; + const int batch = index / matrix_size; + const int transpose_index = batch * matrix_size + n * col + row; + const int band_start = (num_lower_diags < 0 ? 0 : row - num_lower_diags); + const int band_end = + (num_upper_diags < 0 ? 
n : row + num_upper_diags + 1); + if (col < band_start || col >= band_end) { + output_ptr[transpose_index] = Scalar(); + } else { + output_ptr[transpose_index] = input_ptr[index]; + } + } + } +} + +template <typename Scalar> +struct MatrixBandPartFunctor<GPUDevice, Scalar> { + void operator()(OpKernelContext* context, const GPUDevice& device, + int num_lower_diags, int num_upper_diags, bool transpose, + typename TTypes<Scalar, 3>::ConstTensor input, + typename TTypes<Scalar, 3>::Tensor output) { + using CudaType = typename CUDAComplexT<Scalar>::type; + const int batch_size = input.dimension(0); + const int m = input.dimension(1); + const int n = input.dimension(2); + const CudaType* input_ptr = reinterpret_cast<const CudaType*>(input.data()); + CudaType* output_ptr = reinterpret_cast<CudaType*>(output.data()); + CudaLaunchConfig config = GetCudaLaunchConfig(batch_size * m * n, device); + if (transpose) { + MatrixBandPartKernel<true> + <<<config.block_count, config.thread_per_block, 0, device.stream()>>>( + config.virtual_thread_count, batch_size, m, n, num_lower_diags, + num_upper_diags, input_ptr, output_ptr); + } else { + MatrixBandPartKernel<false> + <<<config.block_count, config.thread_per_block, 0, device.stream()>>>( + config.virtual_thread_count, batch_size, m, n, num_lower_diags, + num_upper_diags, input_ptr, output_ptr); + } + } +}; + +#define DEFINE_GPU_SPEC(T) template struct MatrixBandPartFunctor<GPUDevice, T>; TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPEC); TF_CALL_bool(DEFINE_GPU_SPEC); TF_CALL_complex64(DEFINE_GPU_SPEC); TF_CALL_complex128(DEFINE_GPU_SPEC); -} // end namespace tensorflow +#undef DEFINE_GPU_SPEC +} // namespace functor +} // namespace tensorflow #endif // GOOGLE_CUDA diff --git a/tensorflow/python/kernel_tests/matrix_band_part_op_test.py b/tensorflow/python/kernel_tests/matrix_band_part_op_test.py index e641d5511f..317b8dc05b 100644 --- a/tensorflow/python/kernel_tests/matrix_band_part_op_test.py +++ 
b/tensorflow/python/kernel_tests/matrix_band_part_op_test.py @@ -19,13 +19,24 @@ from __future__ import print_function import numpy as np +from tensorflow.python.client import session from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gradient_checker -from tensorflow.python.platform import test +from tensorflow.python.ops import variables +from tensorflow.python.platform import test as test_lib -class MatrixBandPartTest(test.TestCase): +def _AddTest(test, op_name, testcase_name, fn): + test_name = "_".join(["test", op_name, testcase_name]) + if hasattr(test, test_name): + raise RuntimeError("Test %s defined more than once" % test_name) + setattr(test, test_name, fn) + + +class MatrixBandPartTest(test_lib.TestCase): pass # Filled in below @@ -34,23 +45,23 @@ def _GetMatrixBandPartTest(dtype_, batch_shape_, shape_): def Test(self): mat = np.ones(shape_).astype(dtype_) batch_mat = np.tile(mat, batch_shape_ + (1, 1)) - with self.test_session(use_gpu=True): - for lower in -1, 0, 1, shape_[-2] - 1: - for upper in -1, 0, 1, shape_[-1] - 1: - band_np = mat - if lower >= 0: - band_np = np.triu(band_np, -lower) - if upper >= 0: - band_np = np.tril(band_np, upper) - if batch_shape_ is not (): - band_np = np.tile(band_np, batch_shape + (1, 1)) + for lower in -1, 0, 1, shape_[-2] - 1: + for upper in -1, 0, 1, shape_[-1] - 1: + band_np = mat + if lower >= 0: + band_np = np.triu(band_np, -lower) + if upper >= 0: + band_np = np.tril(band_np, upper) + if batch_shape_ is not (): + band_np = np.tile(band_np, batch_shape_ + (1, 1)) + with self.test_session(use_gpu=False): band = array_ops.matrix_band_part(batch_mat, lower, upper) self.assertAllEqual(band_np, band.eval()) return Test -class MatrixBandPartGradTest(test.TestCase): +class MatrixBandPartGradTest(test_lib.TestCase): pass # Filled in below @@ -59,7 
+70,7 @@ def _GetMatrixBandPartGradTest(dtype_, batch_shape_, shape_): def Test(self): shape = batch_shape_ + shape_ x = constant_op.constant(np.random.rand(*shape), dtype=dtype_) - with self.test_session(use_gpu=True): + with self.test_session(use_gpu=False): for lower in -1, 0, 1, shape_[-2] - 1: for upper in -1, 0, 1, shape_[-1] - 1: y = array_ops.matrix_band_part(x, lower, upper) @@ -70,18 +81,77 @@ def _GetMatrixBandPartGradTest(dtype_, batch_shape_, shape_): return Test -if __name__ == '__main__': - for dtype in ( - np.int32, np.int64, np.float32, np.float64, np.complex64, np.complex128): +class MatrixBandPartBenchmark(test_lib.Benchmark): + + shapes = [ + (10, 16, 16), + (10, 101, 101), + (10, 256, 256), + (10, 1000, 1000), + (10, 1024, 1024), + (10, 2048, 2048), + (10, 10, 4, 4), + (10, 10, 10, 10), + (10, 10, 16, 16), + (10, 10, 101, 101), + (10, 10, 256, 256), + (10, 10, 1000, 1000), + (10, 10, 1024, 1024), + (10, 10, 2048, 2048), + ] + + def benchmarkMatrixBandPartOp(self): + for shape_ in self.shapes: + for limits in (-1, -1), (-1, 0), (0, -1), (2, 2): + with ops.Graph().as_default(), \ + session.Session() as sess, \ + ops.device("/cpu:0"): + matrix = variables.Variable(array_ops.ones(shape_)) + band = array_ops.matrix_band_part(matrix, limits[0], limits[1]) + variables.global_variables_initializer().run() + self.run_op_benchmark( + sess, + control_flow_ops.group(band), + min_iters=10, + name="matrix_band_part_cpu_{shape}_{limits}".format( + shape=shape_, limits=limits)) + + if test_lib.is_gpu_available(True): + with ops.Graph().as_default(), \ + session.Session() as sess, \ + ops.device("/gpu:0"): + matrix = variables.Variable(array_ops.ones(shape_)) + band = array_ops.matrix_band_part(matrix, limits[0], limits[1]) + variables.global_variables_initializer().run() + self.run_op_benchmark( + sess, + control_flow_ops.group(band), + min_iters=10, + name="matrix_band_part_gpu_{shape}_{limits}".format( + shape=shape_, limits=limits)) + + +if __name__ == 
"__main__": + dtypes = (np.bool, np.int32, np.int64, np.float32, np.float64, np.complex64, + np.complex128) + for dtype in dtypes: for batch_shape in ((), (2,), (1, 3, 2)): for rows in 1, 2, 7: for cols in 1, 2, 7: shape = (rows, cols) - name = '%s_%s' % (dtype.__name__, '_'.join(map(str, shape))) - setattr(MatrixBandPartTest, 'testMatrixBandPart_' + name, - _GetMatrixBandPartTest(dtype, batch_shape, shape)) - if dtype == np.float32 or dtype == np.float64: - setattr(MatrixBandPartGradTest, 'testMatrixBandPartGrad_' + name, - _GetMatrixBandPartGradTest(dtype, batch_shape, shape)) - - test.main() + name = "%s_%s" % (dtype.__name__, + "_".join(map(str, batch_shape + shape))) + _AddTest(MatrixBandPartTest, "MatrixBandPart", name, + _GetMatrixBandPartTest(dtype, batch_shape, shape)) + + for dtype in (np.float32, np.float64): + for batch_shape in ((), (2,)): + for rows in 1, 2, 7: + for cols in 1, 2, 7: + shape = (rows, cols) + name = "%s_%s" % (dtype.__name__, + "_".join(map(str, batch_shape + shape))) + _AddTest(MatrixBandPartGradTest, "MatrixBandPartGrad", name, + _GetMatrixBandPartGradTest(dtype, batch_shape, shape)) + + test_lib.main() |