author    Yangzihao Wang <yangzihao@google.com>            2017-10-06 09:35:06 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2017-10-06 09:38:26 -0700
commit    2daa40f9d096d47fc3add05a36fb7e41a00ba69d (patch)
tree      3372aa941a4135f1e060653226fec9d956108096 /tensorflow
parent    3251bc07927c6a60916fc274e11445d42e5ec193 (diff)
Fix transpose bug for large dimension.
Add random tests of large shapes for better coverage. Update the transpose
benchmark with cases that swap one small dimension with one large dimension.

PiperOrigin-RevId: 171302097
Diffstat (limited to 'tensorflow')
-rw-r--r--  tensorflow/core/kernels/conv_ops_gpu_3.cu.cc         | 127
-rw-r--r--  tensorflow/python/BUILD                              |  20
-rw-r--r--  tensorflow/python/kernel_tests/transpose_op_test.py  |  74
-rw-r--r--  tensorflow/python/ops/conv2d_benchmark.py            | 141
-rw-r--r--  tensorflow/python/ops/transpose_benchmark.py         |  48
5 files changed, 393 insertions(+), 17 deletions(-)
diff --git a/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc b/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc
index 3d4670c9ba..9083626fbf 100644
--- a/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc
+++ b/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc
@@ -272,6 +272,88 @@ __global__ void SwapDimension1And2InTensor3UsingTiles(const T* input,
}
}
+// Use shared memory tiles to swap dimension-1 and dimension-2 of a 3D tensor
+// when only one of the dimension sizes is smaller than 16,
+// where dimensions are zero-based: output[i][j][k] = input[i][k][j].
+//
+// small_dim = the smaller dimension size
+// large_dim = the larger dimension size
+// kTileLength = blockDim.x = tile width along large_dim (currently 64)
+// tile_num_per_block = gridDim.x = number of tiles needed to cover large_dim
+//
+// Each thread block operates on a single rectangular tile whose width is
+// kTileLength and whose height is small_dim.
+// We set the thread block's X dimension to kTileLength, and its Y and Z
+// dimensions to one.
+template <typename T, int ShmemSize, bool SmallDim2>
+__global__ void SwapDimension1And2InTensor3SmallDim(const T* input,
+ int batch_per_block,
+ Dimension<3> input_dims,
+ T* output) {
+  // TODO(yangzihao): Avoid shared memory bank conflicts.
+ __shared__ T shared_memory_tile[ShmemSize];
+
+ eigen_assert(blockDim.y == 1);
+ eigen_assert(blockDim.z == 1);
+ eigen_assert(gridDim.z == 1);
+
+ int block_offset = blockIdx.x * blockDim.x;
+
+ int x = threadIdx.x;
+ int tile_height = blockDim.x;
+
+ // Get tile height, width, and thread/block origin indices.
+ int small_dim = SmallDim2 ? input_dims[2] : input_dims[1];
+ int large_dim = SmallDim2 ? input_dims[1] : input_dims[2];
+
+ int global_offset = small_dim * large_dim * (blockIdx.y * batch_per_block) +
+ (SmallDim2 ? block_offset * small_dim : block_offset);
+ if (global_offset >= (input_dims[0] * input_dims[1] * input_dims[2])) return;
+
+ for (int batch = 0; batch < batch_per_block; ++batch) {
+ int block_origin_idx =
+ small_dim * large_dim * (blockIdx.y * batch_per_block + batch);
+ int thread_origin_idx =
+ block_origin_idx +
+ (SmallDim2 ? block_offset * small_dim : block_offset) + x;
+
+ if (block_offset + blockDim.x > large_dim) {
+ tile_height = large_dim - block_offset;
+ }
+
+ __syncthreads();
+
+    // Load a contiguous memory region into the shared memory tile.
+ if (x < tile_height) {
+ for (int y = 0; y < small_dim; y++) {
+ int shmem_index =
+ SmallDim2 ? (x + y * tile_height) : (x * small_dim + y);
+ shared_memory_tile[shmem_index] =
+ ldg(input + thread_origin_idx +
+ y * (SmallDim2 ? tile_height : large_dim));
+ }
+ }
+
+ __syncthreads();
+
+ // Get block origin index for output array.
+ int output_block_offset = block_origin_idx;
+ int output_block_idx = SmallDim2 ? block_offset : block_offset * small_dim;
+ int output_block_origin_idx = output_block_offset + output_block_idx;
+
+    // Store the transposed region from shared memory back to device memory.
+ if (x < tile_height) {
+ for (int y = 0; y < small_dim; y++) {
+ int output_idx = output_block_origin_idx + x +
+ y * (SmallDim2 ? large_dim : tile_height);
+ int shmem_index =
+ SmallDim2 ? (x * small_dim + y) : (x + y * tile_height);
+ output[output_idx] = shared_memory_tile[shmem_index];
+ }
+ }
+ }
+}
+
// A CUDA custom kernel that converts input to output, given proper padding on
// the left and the top. The padded value is zero.
template <typename T, int NDIMS>
@@ -420,25 +502,62 @@ template <typename T>
void RunSwapDimension1And2InTensor3(const GPUDevice& d, const T* input,
const Dimension<3>& input_dims, T* output) {
// If both dimensions are not trivial, use tiles for the actual swapping.
+  // If only one dimension is trivial, use the SmallDim kernel for swapping.
// Otherwise, the trivial swapping relying on the ldg cache is more efficient.
static const int kMinDimensionToUseTiles = 16;
bool use_tiles = (input_dims[1] >= kMinDimensionToUseTiles &&
input_dims[2] >= kMinDimensionToUseTiles);
+ bool use_small_dim = ((input_dims[1] >= kMinDimensionToUseTiles &&
+ input_dims[2] < kMinDimensionToUseTiles)) ||
+ ((input_dims[1] < kMinDimensionToUseTiles &&
+ input_dims[2] >= kMinDimensionToUseTiles));
+ static const int NumSubTiles = 8;
+
if (use_tiles) {
- // We get best performance when TileSize is the number of threads in a warp
- // (32 on our GPUs) and NumSubTiles is 8, so our block size is 8 * 32 = 256
- // threads.
static const int TileSize = 32;
- static const int NumSubTiles = 8;
Dimension<3> input_dims_in_tiles = {
input_dims[0], (input_dims[1] + TileSize - 1) / TileSize,
(input_dims[2] + TileSize - 1) / TileSize,
};
int total_tiles_count = input_dims_in_tiles[0] * input_dims_in_tiles[1] *
input_dims_in_tiles[2];
+ // We get best performance when TileSize is the number of threads in a warp
+ // (32 on our GPUs) and NumSubTiles is 8, so our block size is 8 * 32 = 256
+ // threads.
SwapDimension1And2InTensor3UsingTiles<T, TileSize, NumSubTiles><<<
total_tiles_count, dim3(TileSize, NumSubTiles), 0, d.stream()>>>(
input, input_dims, output);
+ } else if (use_small_dim) {
+ // When only one of the dimensions is smaller than kMinDimensionToUseTiles,
+    // we use one block to process a rectangular region of size
+    // kTileLength * small_dim. We found that setting kTileLength to 64
+    // achieves the best performance on a Titan X (Maxwell) GPU.
+ // large_dim
+ // +---------------...--------+
+ // | | | |
+ // small_dim | | ... | |
+ // | | | |
+ // +--------------...---------+
+ // \----- ------/ \- -/
+ // V V
+ // kTileLength(tile_height) tile_height
+ static const int kTileLength = 64;
+ static const int kGridDimY = 65535;
+ int large_dim = std::max(input_dims[2], input_dims[1]);
+ int tile_num_per_block = (large_dim + kTileLength - 1) / kTileLength;
+ int grid_dim_y = std::min(input_dims[0], kGridDimY);
+ int batch_per_block = (input_dims[0] + grid_dim_y - 1) / grid_dim_y;
+ if (input_dims[2] < input_dims[1]) {
+ SwapDimension1And2InTensor3SmallDim<
+ T, kTileLength * kMinDimensionToUseTiles, true>
+ <<<dim3(tile_num_per_block, grid_dim_y), kTileLength, 0,
+ d.stream()>>>(input, batch_per_block, input_dims, output);
+ } else {
+ SwapDimension1And2InTensor3SmallDim<
+ T, kTileLength * kMinDimensionToUseTiles, false>
+ <<<dim3(tile_num_per_block, grid_dim_y), kTileLength, 0,
+ d.stream()>>>(input, batch_per_block, input_dims, output);
+ }
} else {
int total_element_count = input_dims[0] * input_dims[1] * input_dims[2];
CudaLaunchConfig config = GetCudaLaunchConfig(total_element_count, d);
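Note (not part of this commit): the dispatch logic above picks between three kernels based on how many of the two swapped dimensions are "trivial" (smaller than kMinDimensionToUseTiles). A rough host-side sketch of that decision and of the SmallDim launch arithmetic, written in Python/NumPy for readability; the helper names below are made up, and the constants mirror the ones in the hunk.

import numpy as np

K_MIN_DIM_TO_USE_TILES = 16  # mirrors kMinDimensionToUseTiles
K_TILE_LENGTH = 64           # mirrors kTileLength
K_GRID_DIM_Y = 65535         # mirrors kGridDimY

def plan_swap_dim1_and_dim2(input_dims):
    # Hypothetical helper: which kernel the launcher above would pick.
    d0, d1, d2 = input_dims
    use_tiles = d1 >= K_MIN_DIM_TO_USE_TILES and d2 >= K_MIN_DIM_TO_USE_TILES
    use_small_dim = (d1 >= K_MIN_DIM_TO_USE_TILES) != (d2 >= K_MIN_DIM_TO_USE_TILES)
    if use_tiles:
        return ("SwapDimension1And2InTensor3UsingTiles",)
    if use_small_dim:
        large_dim = max(d1, d2)
        tile_num_per_block = (large_dim + K_TILE_LENGTH - 1) // K_TILE_LENGTH  # gridDim.x
        grid_dim_y = min(d0, K_GRID_DIM_Y)                                     # gridDim.y
        batch_per_block = (d0 + grid_dim_y - 1) // grid_dim_y
        return ("SwapDimension1And2InTensor3SmallDim",
                tile_num_per_block, grid_dim_y, batch_per_block)
    return ("elementwise kernel relying on the ldg cache",)

def swap_dim1_and_dim2_reference(x):
    # Reference semantics of all three kernels: output[i][j][k] = input[i][k][j].
    return np.transpose(x, [0, 2, 1])

# Example: [3, 1000000, 8] has one trivial dimension, so the SmallDim kernel is
# chosen with gridDim.x = ceil(1000000 / 64) = 15625, gridDim.y = 3, and one
# batch per block.
print(plan_swap_dim1_and_dim2([3, 1000000, 8]))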
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index ab3b851ef8..bdbad14660 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -4061,6 +4061,26 @@ cuda_py_test(
)
cuda_py_test(
+ name = "conv2d_benchmark",
+ size = "large",
+ srcs = ["ops/conv2d_benchmark.py"],
+ additional_deps = [
+ ":client",
+ ":client_testlib",
+ ":control_flow_ops",
+ ":framework_for_generated_wrappers",
+ ":nn_ops",
+ ":platform",
+ ":platform_benchmark",
+ ":random_ops",
+ ":variables",
+ "//third_party/py/numpy",
+ "//tensorflow/core:protos_all_py",
+ ],
+ main = "ops/conv2d_benchmark.py",
+)
+
+cuda_py_test(
name = "split_benchmark",
srcs = ["ops/split_benchmark.py"],
additional_deps = [
diff --git a/tensorflow/python/kernel_tests/transpose_op_test.py b/tensorflow/python/kernel_tests/transpose_op_test.py
index 570fa79944..9e1f83395b 100644
--- a/tensorflow/python/kernel_tests/transpose_op_test.py
+++ b/tensorflow/python/kernel_tests/transpose_op_test.py
@@ -229,6 +229,80 @@ class TransposeTest(test.TestCase):
self.assertAllEqual(np_ans, tf_ans)
self.assertShapeEqual(np_ans, y)
+ def testLargeSizeGPU(self):
+    # If no GPU is available, skip the test.
+ if not test.is_gpu_available(cuda_only=True):
+ return
+
+ large_shapes = [[1000000, 31, 3], [3, 1000000, 31], [3, 31, 1000000],
+ [10000, 310, 3], [3, 10000, 310], [3, 310, 10000],
+ [2, 1000, 1000], [1000, 2, 1000], [1000, 1000, 2]]
+ perms = [[0, 2, 1]] * 9
+
+ for input_shape, perm in zip(large_shapes, perms):
+ total_size = np.prod(input_shape)
+ inp = np.arange(1, total_size + 1, dtype=np.float32).reshape(input_shape)
+ np_ans = self._np_transpose(inp, perm)
+ with self.test_session(use_gpu=True):
+ inx = ops.convert_to_tensor(inp)
+ y = array_ops.transpose(inx, perm)
+ tf_ans = y.eval()
+ self.assertAllEqual(np_ans, tf_ans)
+ self.assertShapeEqual(np_ans, y)
+
+ def testRandomizedSmallDimLargeSizeGPU(self):
+    # If no GPU is available, skip the test.
+ if not test.is_gpu_available(cuda_only=True):
+ return
+
+ # Draw 10 random shapes with large dimension sizes.
+ # 40% prob to generate dim[0] size within [1, 2047]
+ # 40% prob to generate dim[0] size within [2048, 4095]
+ # 20% prob to generate dim[0] size within [4096, 100000]
+ # 50% prob to use dim[1] as the small dim (<16)
+ num_samples = 10
+ total_size = 500000
+ small_size_limit = 2048
+ large_size_limit = 95905
+ small_size_percentage = 0.4
+ medium_size_percentage = 0.4
+ large_size_percentage = 0.2
+ perms = [[0, 2, 1]] * num_samples
+ dim_zero_sizes = []
+ dim_zero_sizes += list(
+ np.random.randint(
+ small_size_limit, size=int(small_size_percentage * num_samples)) +
+ 1)
+ dim_zero_sizes += list(
+ np.random.randint(
+ small_size_limit, size=int(medium_size_percentage * num_samples)) +
+ small_size_limit)
+ dim_zero_sizes += list(
+ np.random.randint(
+ large_size_limit, size=int(large_size_percentage * num_samples)) +
+ small_size_limit * 2)
+ input_shapes = []
+ small_dim_limit = 16
+ for dim_zero_size in dim_zero_sizes:
+ small_dim_size = np.random.randint(small_dim_limit - 1) + 1
+ large_dim_size = int(
+ total_size / dim_zero_size / small_dim_size) + small_dim_limit
+ input_shapes += ([[dim_zero_size, small_dim_size, large_dim_size]]
+ if np.random.randint(2) else
+ [[dim_zero_size, large_dim_size, small_dim_size]])
+
+ for input_shape, perm in zip(input_shapes, perms):
+      # Generate input data with random ints from 0 to 9.
+ inp = np.random.randint(10, size=input_shape)
+ np_ans = self._np_transpose(inp, perm)
+ with self.test_session(use_gpu=True):
+ inx = ops.convert_to_tensor(inp)
+ y = array_ops.transpose(inx, perm)
+ tf_ans = y.eval()
+ self.assertAllEqual(np_ans, tf_ans)
+ self.assertShapeEqual(np_ans, y)
+ self._ClearCachedSession()
+
def testNop(self):
self._compareCpu(np.arange(0, 6).reshape([3, 2]).astype(np.float32), [0, 1])
diff --git a/tensorflow/python/ops/conv2d_benchmark.py b/tensorflow/python/ops/conv2d_benchmark.py
new file mode 100644
index 0000000000..6992fa57ea
--- /dev/null
+++ b/tensorflow/python/ops/conv2d_benchmark.py
@@ -0,0 +1,141 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Benchmark for Conv2D op."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import itertools
+import time
+
+from tensorflow.python.client import session as session_lib
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+def build_graph(device, input_shape, filter_shape, strides, padding, num_iters):
+  """Builds a graph containing a sequence of conv2d operations.
+
+ Args:
+ device: String, the device to run on.
+ input_shape: Shape of the input tensor.
+ filter_shape: Shape of the filter tensor.
+ strides: A list of ints. 1-D of length 4. The stride of sliding
+ window for each dimension of input.
+ padding: A string from: "SAME", "VALID". The type of padding
+ algorithm to use.
+ num_iters: number of iterations to run conv2d.
+
+ Returns:
+    An op that groups the conv2d operations to run.
+ """
+ with ops.device("/%s:0" % device):
+ inp = variables.Variable(random_ops.truncated_normal(input_shape))
+ filt = variables.Variable(random_ops.truncated_normal(filter_shape))
+
+ outputs = []
+ conv2d_op = nn_ops.conv2d(inp, filt, strides, padding, data_format="NHWC")
+ outputs.append(conv2d_op)
+ for _ in range(1, num_iters):
+ with ops.control_dependencies([conv2d_op]):
+ conv2d_op = nn_ops.conv2d(
+ inp, filt, strides, padding, data_format="NHWC")
+ outputs.append(conv2d_op)
+ return control_flow_ops.group(*outputs)
+
+
+class Conv2DBenchmark(test.Benchmark):
+ """Benchmark conv2d!"""
+
+ def _run_graph(self, device, input_shape, filter_shape, strides, padding,
+ num_iters):
+    """Runs the graph and prints its execution time.
+
+ Args:
+ device: String, the device to run on.
+ input_shape: Shape of the input tensor.
+ filter_shape: Shape of the filter tensor.
+ strides: A list of ints. 1-D of length 4. The stride of sliding
+ window for each dimension of input.
+      padding: A string from: "SAME", "VALID". The type of padding
+        algorithm to use.
+      num_iters: Number of iterations to run conv2d; the reported time is
+        averaged over them.
+
+ Returns:
+      The average duration of a single conv2d iteration, in seconds.
+ """
+ graph = ops.Graph()
+ with graph.as_default():
+ outputs = build_graph(device, input_shape, filter_shape, strides, padding,
+ num_iters)
+ with session_lib.Session(graph=graph) as session:
+ variables.global_variables_initializer().run()
+      # Warmup run to exclude one-time setup costs from the timing.
+ session.run(outputs)
+
+ start_time = time.time()
+ session.run(outputs)
+ duration = (time.time() - start_time) / num_iters
+
+ print("%s inputshape:%s filtershape:%s strides:%s padding:%s "
+ "%d iters: %.8f sec" %
+ (device, str(input_shape).replace(" ", ""),
+ str(filter_shape).replace(" ", ""),
+ str(strides).replace(" ", ""), padding, num_iters, duration))
+
+ name_template = (
+ "conv2d_{device}_input_shape_{inputshape}_filter_shape_{filtershape}_"
+ "strides_{strides}_padding_{padding}")
+
+ self.report_benchmark(
+ name=name_template.format(
+ device=device,
+ inputshape=str(input_shape).replace(" ", ""),
+ filtershape=str(filter_shape).replace(" ", ""),
+ strides=str(strides).replace(" ", ""),
+ padding=padding).replace(" ", ""),
+ iters=num_iters,
+        wall_time=duration)  # duration is already a per-iteration time.
+
+ return duration
+
+ def benchmark_conv2d(self):
+ print("conv2d benchmark:")
+
+ h = 500
+ w = 500
+ fh = 3
+ fw = 3
+ input_shapes = []
+ filter_shapes = []
+ for b, c in itertools.product([4, 16, 32], [i for i in range(3, 16)]):
+ input_shapes += [[b, h, w, c]]
+ filter_shapes += [[fh, fw, c, b]]
+ strides = [[1, 2, 2, 1]]
+ paddings = ["VALID", "SAME"]
+ for ishape, fshape in zip(input_shapes, filter_shapes):
+ for stride in strides:
+ for padding in paddings:
+ self._run_graph("gpu", ishape, fshape, stride, padding, 80)
+
+
+if __name__ == "__main__":
+ test.main()
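Both benchmarks touched by this change use the same timing pattern: build num_iters copies of the op chained through control dependencies (so the runtime cannot overlap or elide them), run the grouped op once as a warmup, then time a second run and divide by num_iters. A minimal sketch of that pattern, assuming the same TF 1.x graph/session modules the file above imports; time_chained_op and make_op are hypothetical names, not part of the commit.

import time

from tensorflow.python.client import session as session_lib
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import variables


def time_chained_op(make_op, num_iters):
  """Returns the mean per-iteration wall time of num_iters chained ops."""
  graph = ops.Graph()
  with graph.as_default():
    op = make_op()  # make_op builds one instance of the op under test.
    outputs = [op]
    for _ in range(1, num_iters):
      # Chain through a control dependency so the copies run sequentially.
      with ops.control_dependencies([op]):
        op = make_op()
        outputs.append(op)
    group = control_flow_ops.group(*outputs)
  with session_lib.Session(graph=graph) as session:
    variables.global_variables_initializer().run()
    session.run(group)  # Warmup run (autotuning, allocator growth, ...).
    start_time = time.time()
    session.run(group)
    return (time.time() - start_time) / num_iters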
diff --git a/tensorflow/python/ops/transpose_benchmark.py b/tensorflow/python/ops/transpose_benchmark.py
index 63a314295e..6b5f0f20d8 100644
--- a/tensorflow/python/ops/transpose_benchmark.py
+++ b/tensorflow/python/ops/transpose_benchmark.py
@@ -1,4 +1,4 @@
-# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -32,7 +32,7 @@ from tensorflow.python.platform import test
def build_graph(device, input_shape, perm, datatype, num_iters):
- """Build a graph containing a sequence of conv2d operations.
+  """Builds a graph containing a sequence of transpose operations.
Args:
device: String, the device to run on.
@@ -50,10 +50,12 @@ def build_graph(device, input_shape, perm, datatype, num_iters):
t = constant_op.constant(inp, shape=input_shape)
outputs = []
- outputs.append(array_ops.transpose(t, perm))
- for i in range(1, num_iters):
- with ops.control_dependencies([outputs[i - 1]]):
- outputs.append(array_ops.transpose(t, perm))
+ transpose_op = array_ops.transpose(t, perm)
+ outputs.append(transpose_op)
+ for _ in range(1, num_iters):
+ with ops.control_dependencies([transpose_op]):
+ transpose_op = array_ops.transpose(t, perm)
+ outputs.append(transpose_op)
return control_flow_ops.group(*outputs)
@@ -61,7 +63,7 @@ class TransposeBenchmark(test.Benchmark):
"""Benchmark transpose!"""
def _run_graph(self, device, input_shape, perm, num_iters, datatype):
- """Run the graph and print its execution time.
+    """Runs the graph and prints its execution time.
Args:
device: String, the device to run on.
@@ -82,9 +84,11 @@ class TransposeBenchmark(test.Benchmark):
session.run(outputs)
start_time = time.time()
session.run(outputs)
+
duration = (time.time() - start_time) / num_iters
throughput = np.prod(
np.array(input_shape)) * datatype().itemsize * 2 / duration / 1e9
+
print("%s %s inputshape:%s perm:%s %d %.6fsec, %.4fGB/s." %
(device, str(datatype), str(input_shape).replace(" ", ""),
str(perm).replace(" ", ""), num_iters, duration, throughput))
@@ -108,12 +112,12 @@ class TransposeBenchmark(test.Benchmark):
datatypes = [np.complex128, np.float64, np.float32, np.float16, np.int8]
- small_shapes = [[2, 20, 20, 20, 16], [2, 16, 20, 20, 20]] * 2 + [[
- 2, 100, 100, 16
- ], [2, 16, 100, 100]] * 2 + [[2, 5000, 16], [2, 16, 5000]] * 2
- small_perms = [[0, 4, 1, 2, 3], [0, 2, 3, 4, 1]] + [[4, 1, 2, 3, 0]] * 2 + [
- [0, 3, 1, 2], [0, 2, 3, 1]
- ] + [[3, 1, 2, 0]] * 2 + [[0, 2, 1]] * 2 + [[2, 1, 0]] * 2
+ small_shapes = [[2, 20, 20, 20, 16], [2, 16, 20, 20, 20]] * 2
+ small_shapes += [[2, 100, 100, 16], [2, 16, 100, 100]] * 2
+ small_shapes += [[2, 5000, 16], [2, 16, 5000]] * 2
+ small_perms = [[0, 4, 1, 2, 3], [0, 2, 3, 4, 1]] + [[4, 1, 2, 3, 0]] * 2
+ small_perms += [[0, 3, 1, 2], [0, 2, 3, 1]] + [[3, 1, 2, 0]] * 2
+ small_perms += [[0, 2, 1]] * 2 + [[2, 1, 0]] * 2
large_shapes = [[2, 40, 40, 40, 32], [2, 40, 40, 40, 64]] * 2 + [[
2, 300, 300, 32
@@ -132,5 +136,23 @@ class TransposeBenchmark(test.Benchmark):
for ishape, perm in zip(large_shapes, large_perms):
self._run_graph("gpu", ishape, perm, num_iters, datatype)
+ small_dim_large_shapes = [[2, 10000, 3], [2, 3, 10000], [2, 10000, 8],
+ [2, 8, 10000]]
+ small_dim_small_shapes = [[2, 5000, 3], [2, 3, 5000], [2, 5000, 8],
+ [2, 8, 5000]]
+ small_dim_perms = [[0, 2, 1]] * 4
+
+ num_iters = 320
+ small_dim_large_shape_datatypes = [np.float64, np.float32, np.int8]
+ for datatype in small_dim_large_shape_datatypes:
+ for ishape, perm in zip(small_dim_large_shapes, small_dim_perms):
+ self._run_graph("gpu", ishape, perm, num_iters, datatype)
+
+ small_dim_small_shape_datatypes = [np.complex128, np.float16]
+ for datatype in small_dim_small_shape_datatypes:
+ for ishape, perm in zip(small_dim_small_shapes, small_dim_perms):
+ self._run_graph("gpu", ishape, perm, num_iters, datatype)
+
+
if __name__ == "__main__":
test.main()
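As a footnote on the numbers the transpose benchmark prints: the GB/s figure counts each element as read once and written once. A small worked sketch of that formula (not part of the commit; the helper name is made up):

import numpy as np

def transpose_throughput_gb_per_s(input_shape, dtype, per_iter_seconds):
  # Mirrors the benchmark: (bytes read + bytes written) / time, in GB/s.
  bytes_moved = np.prod(np.array(input_shape)) * np.dtype(dtype).itemsize * 2
  return bytes_moved / per_iter_seconds / 1e9

# Example: a [2, 10000, 8] float32 tensor holds 640 KB; one transpose moves
# 1.28 MB, so finishing an iteration in 20 microseconds corresponds to
# 1.28e6 / 2e-5 / 1e9 = 64 GB/s.
print(transpose_throughput_gb_per_s([2, 10000, 8], np.float32, 20e-6))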