author    Yangzihao Wang <yangzihao@google.com>    2017-07-19 13:15:03 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>    2017-07-19 13:19:10 -0700
commit    831a3a46d34e9bd305555889392c016690ff9bc4 (patch)
tree      454bf198209e43c411c62ec02a7a85a426434c92
parent    03a790d5bff835365a54a534fc1658dc287c0f63 (diff)
Automated g4 rollback of changelist 160314706
PiperOrigin-RevId: 162525519
-rw-r--r--  tensorflow/core/kernels/conv_ops_gpu_3.cu.cc          | 127
-rw-r--r--  tensorflow/python/BUILD                               |  20
-rw-r--r--  tensorflow/python/kernel_tests/BUILD                  |   2
-rw-r--r--  tensorflow/python/kernel_tests/transpose_op_test.py   |  20
-rw-r--r--  tensorflow/python/ops/conv2d_benchmark.py             | 141
-rw-r--r--  tensorflow/python/ops/transpose_benchmark.py          |  64
6 files changed, 348 insertions, 26 deletions
diff --git a/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc b/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc
index 2307c2de0e..bcabda848c 100644
--- a/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc
+++ b/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc
@@ -272,6 +272,87 @@ __global__ void SwapDimension1And2InTensor3UsingTiles(const T* input,
}
}
+// Use shared memory tiles to swap dimension-1 and dimension-2 of a 3D tensor
+// when only one of the dimension sizes is smaller than 16,
+// where dimensions are zero-based: output[i][j][k] = input[i][k][j].
+//
+// small_dim  = the smaller of the two swapped dimension sizes
+// large_dim  = the larger of the two swapped dimension sizes
+// kTileLength = blockDim.x (the tile length along large_dim; currently 64)
+// tile_num_per_block = gridDim.x = (large_dim + kTileLength - 1) / kTileLength
+//
+// Each thread block operates on a single rectangular tile whose width is
+// kTileLength and whose height is small_dim. We set the thread block's X
+// dimension to kTileLength, and its Y and Z dimensions to one.
+template <typename T, bool SmallDim2>
+__global__ void SwapDimension1And2InTensor3SmallDim(const T* input,
+ int batch_per_block,
+ Dimension<3> input_dims,
+ T* output) {
+ // TODO(yangzihao): Avoid shared memory bank conflicts.
+ extern __shared__ __align__(sizeof(T)) unsigned char shmem[];
+ T* shared_memory_tile = reinterpret_cast<T*>(shmem);
+
+ eigen_assert(blockDim.y == 1);
+ eigen_assert(blockDim.z == 1);
+ eigen_assert(gridDim.z == 1);
+
+ int block_offset = blockIdx.x * blockDim.x * batch_per_block;
+
+ int x = threadIdx.x;
+ int tile_height = blockDim.x;
+
+ // Get tile height, width, and thread/block origin indices.
+ int small_dim = SmallDim2 ? input_dims[2] : input_dims[1];
+ int large_dim = SmallDim2 ? input_dims[1] : input_dims[2];
+
+ int global_offset = small_dim * large_dim * (blockIdx.y * batch_per_block) +
+ (SmallDim2 ? block_offset * small_dim : block_offset) + x;
+ if (global_offset > (input_dims[0] * input_dims[1] * input_dims[2])) return;
+
+ for (int batch = 0; batch < batch_per_block; ++batch) {
+ int block_origin_idx =
+ small_dim * large_dim * (blockIdx.y * batch_per_block + batch);
+ int thread_origin_idx =
+ block_origin_idx +
+ (SmallDim2 ? block_offset * small_dim : block_offset) + x;
+
+ if (block_offset + blockDim.x > large_dim) {
+ tile_height = large_dim - block_offset;
+ }
+
+ // Load a contiguous memory region into the shared memory tile.
+ if (x < tile_height) {
+ for (int y = 0; y < small_dim; y++) {
+ int shmem_index =
+ SmallDim2 ? (x + y * tile_height) : (x * small_dim + y);
+ shared_memory_tile[shmem_index] =
+ ldg(input + thread_origin_idx +
+ y * (SmallDim2 ? tile_height : large_dim));
+ }
+ }
+
+ __syncthreads();
+
+ // Get block origin index for output array.
+ int output_block_offset = block_origin_idx;
+ int output_block_idx = SmallDim2 ? block_offset : block_offset * small_dim;
+ int output_block_origin_idx = output_block_offset + output_block_idx;
+
+ // Store the transposed tile from shared memory back to device memory.
+ if (x < tile_height) {
+ for (int y = 0; y < small_dim; y++) {
+ int output_idx = output_block_origin_idx + x +
+ y * (SmallDim2 ? large_dim : tile_height);
+ int shmem_index =
+ SmallDim2 ? (x * small_dim + y) : (x + y * tile_height);
+ output[output_idx] = shared_memory_tile[shmem_index];
+ }
+ }
+ }
+}
+
// A CUDA custom kernel that converts input to output, given proper padding on
// the left and the top. The padded value is zero.
template <typename T, int NDIMS>
@@ -420,25 +501,63 @@ template <typename T>
void RunSwapDimension1And2InTensor3(const GPUDevice& d, const T* input,
const Dimension<3>& input_dims, T* output) {
// If both dimensions are not trivial, use tiles for the actual swapping.
+ // If only one dimension is trivial, use the SmallDim kernel for the swap.
// Otherwise, the trivial swapping relying on the ldg cache is more efficient.
static const int kMinDimensionToUseTiles = 16;
bool use_tiles = (input_dims[1] >= kMinDimensionToUseTiles &&
input_dims[2] >= kMinDimensionToUseTiles);
+ bool use_small_dim = ((input_dims[1] >= kMinDimensionToUseTiles &&
+ input_dims[2] < kMinDimensionToUseTiles)) ||
+ ((input_dims[1] < kMinDimensionToUseTiles &&
+ input_dims[2] >= kMinDimensionToUseTiles));
+ static const int NumSubTiles = 8;
+
if (use_tiles) {
- // We get best performance when TileSize is the number of threads in a warp
- // (32 on our GPUs) and NumSubTiles is 8, so our block size is 8 * 32 = 256
- // threads.
static const int TileSize = 32;
- static const int NumSubTiles = 8;
Dimension<3> input_dims_in_tiles = {
input_dims[0], (input_dims[1] + TileSize - 1) / TileSize,
(input_dims[2] + TileSize - 1) / TileSize,
};
int total_tiles_count = input_dims_in_tiles[0] * input_dims_in_tiles[1] *
input_dims_in_tiles[2];
+ // We get best performance when TileSize is the number of threads in a warp
+ // (32 on our GPUs) and NumSubTiles is 8, so our block size is 8 * 32 = 256
+ // threads.
SwapDimension1And2InTensor3UsingTiles<T, TileSize, NumSubTiles><<<
total_tiles_count, dim3(TileSize, NumSubTiles), 0, d.stream()>>>(
input, input_dims, output);
+ } else if (use_small_dim) {
+ // When only one of the dimensions is smaller than kMinDimensionToUseTiles,
+ // we use one block to process a rectangular region of size
+ // kTileLength * small_dim. We found that setting kTileLength to 64 achieves
+ // the best performance on a Titan X (Maxwell) GPU.
+ // large_dim
+ // +---------------...--------+
+ // | | | |
+ // small_dim | | ... | |
+ // | | | |
+ // +--------------...---------+
+ // \----- ------/ \- -/
+ // V V
+ // kTileLength(tile_height) tile_height
+ static const int kTileLength = 64;
+ static const int kGridDimY = 2048;
+ int small_dim = std::min(input_dims[2], input_dims[1]);
+ int large_dim = std::max(input_dims[2], input_dims[1]);
+ int tile_num_per_block = (large_dim + kTileLength - 1) / kTileLength;
+ int grid_dim_y = input_dims[0] < kGridDimY ? input_dims[0] : kGridDimY;
+ int batch_per_block = (input_dims[0] + grid_dim_y - 1) / grid_dim_y;
+ if (input_dims[2] < input_dims[1]) {
+ SwapDimension1And2InTensor3SmallDim<T, true>
+ <<<dim3(tile_num_per_block, grid_dim_y), kTileLength,
+ kTileLength * small_dim * sizeof(T), d.stream()>>>(
+ input, batch_per_block, input_dims, output);
+ } else {
+ SwapDimension1And2InTensor3SmallDim<T, false>
+ <<<dim3(tile_num_per_block, grid_dim_y), kTileLength,
+ kTileLength * small_dim * sizeof(T), d.stream()>>>(
+ input, batch_per_block, input_dims, output);
+ }
} else {
int total_element_count = input_dims[0] * input_dims[1] * input_dims[2];
CudaLaunchConfig config = GetCudaLaunchConfig(total_element_count, d);
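
For context on the dispatch above, the sketch below is a minimal Python/NumPy rendering (not part of the patch) of what RunSwapDimension1And2InTensor3 computes and how it now chooses among the three kernels. It mirrors the constants from the hunk (kMinDimensionToUseTiles = 16, kTileLength = 64, kGridDimY = 2048); the function names and the returned launch-parameter dict are illustrative assumptions.

import numpy as np

K_MIN_DIMENSION_TO_USE_TILES = 16
K_TILE_LENGTH = 64
K_GRID_DIM_Y = 2048


def reference_swap_dim1_and_2(x):
  # NumPy reference for the op: output[i][j][k] = input[i][k][j].
  return np.transpose(x, [0, 2, 1])


def choose_swap_kernel(input_dims, itemsize=4):
  # itemsize=4 assumes float32 elements.
  d0, d1, d2 = input_dims
  use_tiles = (d1 >= K_MIN_DIMENSION_TO_USE_TILES and
               d2 >= K_MIN_DIMENSION_TO_USE_TILES)
  use_small_dim = ((d1 >= K_MIN_DIMENSION_TO_USE_TILES and
                    d2 < K_MIN_DIMENSION_TO_USE_TILES) or
                   (d1 < K_MIN_DIMENSION_TO_USE_TILES and
                    d2 >= K_MIN_DIMENSION_TO_USE_TILES))
  if use_tiles:
    return "tiled kernel (TileSize = 32, NumSubTiles = 8)"
  elif use_small_dim:
    small_dim, large_dim = min(d1, d2), max(d1, d2)
    tile_num_per_block = (large_dim + K_TILE_LENGTH - 1) // K_TILE_LENGTH
    grid_dim_y = min(d0, K_GRID_DIM_Y)
    batch_per_block = (d0 + grid_dim_y - 1) // grid_dim_y
    shared_mem_bytes = K_TILE_LENGTH * small_dim * itemsize
    return ("SmallDim kernel",
            dict(grid=(tile_num_per_block, grid_dim_y),
                 block=K_TILE_LENGTH,
                 batch_per_block=batch_per_block,
                 shared_mem_bytes=shared_mem_bytes))
  else:
    return "elementwise kernel (relies on the ldg cache)"

For example, choose_swap_kernel([2, 1000000, 3]) takes the SmallDim path with a (15625, 2) grid, 64 threads per block, one batch per block, and 64 * 3 * 4 = 768 bytes of shared memory, matching the launch in the hunk above.
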
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index aa66f25391..b7063d59ba 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -3824,6 +3824,26 @@ cuda_py_test(
)
cuda_py_test(
+ name = "conv2d_benchmark",
+ size = "large",
+ srcs = ["ops/conv2d_benchmark.py"],
+ additional_deps = [
+ ":client",
+ ":client_testlib",
+ ":control_flow_ops",
+ ":framework_for_generated_wrappers",
+ ":nn_ops",
+ ":platform",
+ ":platform_benchmark",
+ ":random_ops",
+ ":variables",
+ "//third_party/py/numpy",
+ "//tensorflow/core:protos_all_py",
+ ],
+ main = "ops/conv2d_benchmark.py",
+)
+
+cuda_py_test(
name = "split_benchmark",
srcs = ["ops/split_benchmark.py"],
additional_deps = [
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index 7363279a49..442b4b2df5 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -1985,7 +1985,7 @@ cuda_py_test(
cuda_py_test(
name = "transpose_op_test",
- size = "large",
+ size = "small",
srcs = ["transpose_op_test.py"],
additional_deps = [
"//third_party/py/numpy",
diff --git a/tensorflow/python/kernel_tests/transpose_op_test.py b/tensorflow/python/kernel_tests/transpose_op_test.py
index 570fa79944..67bdb4237d 100644
--- a/tensorflow/python/kernel_tests/transpose_op_test.py
+++ b/tensorflow/python/kernel_tests/transpose_op_test.py
@@ -229,6 +229,26 @@ class TransposeTest(test.TestCase):
self.assertAllEqual(np_ans, tf_ans)
self.assertShapeEqual(np_ans, y)
+ def testLargeSizeGPU(self):
+ # If no GPU is available, skip the test.
+ if not test.is_gpu_available(cuda_only=True):
+ return
+
+ large_shapes = [[1000000, 31, 3], [3, 1000000, 31], [3, 31, 1000000],
+ [2, 1000, 1000], [1000, 2, 1000], [1000, 1000, 2]]
+ perms = [[0, 2, 1]] * 6
+
+ for input_shape, perm in zip(large_shapes, perms):
+ total_size = np.prod(input_shape)
+ inp = np.arange(1, total_size + 1, dtype=np.float32).reshape(input_shape)
+ np_ans = self._np_transpose(inp, perm)
+ with self.test_session(use_gpu=True):
+ inx = ops.convert_to_tensor(inp)
+ y = array_ops.transpose(inx, perm)
+ tf_ans = y.eval()
+ self.assertAllEqual(np_ans, tf_ans)
+ self.assertShapeEqual(np_ans, y)
+
def testNop(self):
self._compareCpu(np.arange(0, 6).reshape([3, 2]).astype(np.float32), [0, 1])
diff --git a/tensorflow/python/ops/conv2d_benchmark.py b/tensorflow/python/ops/conv2d_benchmark.py
new file mode 100644
index 0000000000..6992fa57ea
--- /dev/null
+++ b/tensorflow/python/ops/conv2d_benchmark.py
@@ -0,0 +1,141 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Benchmark for Conv2D op."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import itertools
+import time
+
+from tensorflow.python.client import session as session_lib
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+def build_graph(device, input_shape, filter_shape, strides, padding, num_iters):
+ """builds a graph containing a sequence of conv2d operations.
+
+ Args:
+ device: String, the device to run on.
+ input_shape: Shape of the input tensor.
+ filter_shape: Shape of the filter tensor.
+ strides: A list of ints. 1-D of length 4. The stride of sliding
+ window for each dimension of input.
+ padding: A string from: "SAME", "VALID". The type of padding
+ algorithm to use.
+ num_iters: Number of iterations to run conv2d.
+
+ Returns:
+ An op that groups the conv2d operations, to be passed to Session.run().
+ """
+ with ops.device("/%s:0" % device):
+ inp = variables.Variable(random_ops.truncated_normal(input_shape))
+ filt = variables.Variable(random_ops.truncated_normal(filter_shape))
+
+ outputs = []
+ conv2d_op = nn_ops.conv2d(inp, filt, strides, padding, data_format="NHWC")
+ outputs.append(conv2d_op)
+ for _ in range(1, num_iters):
+ with ops.control_dependencies([conv2d_op]):
+ conv2d_op = nn_ops.conv2d(
+ inp, filt, strides, padding, data_format="NHWC")
+ outputs.append(conv2d_op)
+ return control_flow_ops.group(*outputs)
+
+
+class Conv2DBenchmark(test.Benchmark):
+ """Benchmark conv2d!"""
+
+ def _run_graph(self, device, input_shape, filter_shape, strides, padding,
+ num_iters):
+ """runs the graph and print its execution time.
+
+ Args:
+ device: String, the device to run on.
+ input_shape: Shape of the input tensor.
+ filter_shape: Shape of the filter tensor.
+ strides: A list of ints. 1-D of length 4. The stride of sliding
+ window for each dimension of input.
+ padding: A string from: "SAME", "VALID". The type of padding
+ algorithm to use.
+ num_iters: Number of iterations to run the benchmark.
+
+ Returns:
+ The duration of the run in seconds.
+ """
+ graph = ops.Graph()
+ with graph.as_default():
+ outputs = build_graph(device, input_shape, filter_shape, strides, padding,
+ num_iters)
+ with session_lib.Session(graph=graph) as session:
+ variables.global_variables_initializer().run()
+ # warmup runs
+ session.run(outputs)
+
+ start_time = time.time()
+ session.run(outputs)
+ duration = (time.time() - start_time) / num_iters
+
+ print("%s inputshape:%s filtershape:%s strides:%s padding:%s "
+ "%d iters: %.8f sec" %
+ (device, str(input_shape).replace(" ", ""),
+ str(filter_shape).replace(" ", ""),
+ str(strides).replace(" ", ""), padding, num_iters, duration))
+
+ name_template = (
+ "conv2d_{device}_input_shape_{inputshape}_filter_shape_{filtershape}_"
+ "strides_{strides}_padding_{padding}")
+
+ self.report_benchmark(
+ name=name_template.format(
+ device=device,
+ inputshape=str(input_shape).replace(" ", ""),
+ filtershape=str(filter_shape).replace(" ", ""),
+ strides=str(strides).replace(" ", ""),
+ padding=padding).replace(" ", ""),
+ iters=num_iters,
+ wall_time=duration)  # duration is already the per-iteration time.
+
+ return duration
+
+ def benchmark_conv2d(self):
+ print("conv2d benchmark:")
+
+ h = 500
+ w = 500
+ fh = 3
+ fw = 3
+ input_shapes = []
+ filter_shapes = []
+ for b, c in itertools.product([4, 16, 32], [i for i in range(3, 16)]):
+ input_shapes += [[b, h, w, c]]
+ filter_shapes += [[fh, fw, c, b]]
+ strides = [[1, 2, 2, 1]]
+ paddings = ["VALID", "SAME"]
+ for ishape, fshape in zip(input_shapes, filter_shapes):
+ for stride in strides:
+ for padding in paddings:
+ self._run_graph("gpu", ishape, fshape, stride, padding, 80)
+
+
+if __name__ == "__main__":
+ test.main()
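
As a quick sanity check on the benchmark configurations in benchmark_conv2d (h = w = 500, fh = fw = 3, stride 2), the sketch below computes the conv2d output spatial size from the standard TensorFlow shape formulas; the helper name is illustrative, not part of the patch.

import math


def conv2d_output_size(in_size, filter_size, stride, padding):
  # Standard TensorFlow spatial-size formulas for conv2d.
  if padding == "VALID":
    return int(math.ceil((in_size - filter_size + 1) / float(stride)))
  return int(math.ceil(in_size / float(stride)))  # "SAME"


print(conv2d_output_size(500, 3, 2, "VALID"))  # 249
print(conv2d_output_size(500, 3, 2, "SAME"))   # 250
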
diff --git a/tensorflow/python/ops/transpose_benchmark.py b/tensorflow/python/ops/transpose_benchmark.py
index 6bd3fe5e5a..cddefacf2e 100644
--- a/tensorflow/python/ops/transpose_benchmark.py
+++ b/tensorflow/python/ops/transpose_benchmark.py
@@ -1,4 +1,4 @@
-# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -32,7 +32,7 @@ from tensorflow.python.platform import test
def build_graph(device, input_shape, perm, datatype, num_iters):
- """Build a graph containing a sequence of conv2d operations.
+ """builds a graph containing a sequence of conv2d operations.
Args:
device: String, the device to run on.
@@ -50,10 +50,12 @@ def build_graph(device, input_shape, perm, datatype, num_iters):
t = constant_op.constant(inp, shape=input_shape)
outputs = []
- outputs.append(array_ops.transpose(t, perm))
- for i in range(1, num_iters):
- with ops.control_dependencies([outputs[i - 1]]):
- outputs.append(array_ops.transpose(t, perm))
+ transpose_op = array_ops.transpose(t, perm)
+ outputs.append(transpose_op)
+ for _ in range(1, num_iters):
+ with ops.control_dependencies([transpose_op]):
+ transpose_op = array_ops.transpose(t, perm)
+ outputs.append(transpose_op)
return control_flow_ops.group(*outputs)
@@ -61,7 +63,7 @@ class TransposeBenchmark(test.Benchmark):
"""Benchmark transpose!"""
def _run_graph(self, device, input_shape, perm, num_iters, datatype):
- """Run the graph and print its execution time.
+ """runs the graph and print its execution time.
Args:
device: String, the device to run on.
@@ -82,9 +84,11 @@ class TransposeBenchmark(test.Benchmark):
session.run(outputs)
start_time = time.time()
session.run(outputs)
+
duration = (time.time() - start_time) / num_iters
- throughput = np.prod(np.array(
- input_shape)) * datatype().itemsize * 2 / duration / 1e9
+ throughput = np.prod(
+ np.array(input_shape)) * datatype().itemsize * 2 / duration / 1e9
+
print("%s %s inputshape:%s perm:%s %d %.6fsec, %.4fGB/s." %
(device, str(datatype), str(input_shape).replace(" ", ""),
str(perm).replace(" ", ""), num_iters, duration, throughput))
@@ -108,19 +112,19 @@ class TransposeBenchmark(test.Benchmark):
datatypes = [np.complex128, np.float64, np.float32, np.float16, np.int8]
- small_shapes = [[2, 20, 20, 20, 16], [2, 16, 20, 20, 20]] * 2 + [[
- 2, 100, 100, 16
- ], [2, 16, 100, 100]] * 2 + [[2, 5000, 16], [2, 16, 5000]] * 2
- small_perms = [[0, 4, 1, 2, 3], [0, 2, 3, 4, 1]] + [[4, 1, 2, 3, 0]] * 2 + [
- [0, 3, 1, 2], [0, 2, 3, 1]
- ] + [[3, 1, 2, 0]] * 2 + [[0, 2, 1]] * 2 + [[2, 1, 0]] * 2
+ small_shapes = [[2, 20, 20, 20, 16], [2, 16, 20, 20, 20]] * 2
+ small_shapes += [[2, 100, 100, 16], [2, 16, 100, 100]] * 2
+ small_shapes += [[2, 5000, 16], [2, 16, 5000]] * 2
+ small_perms = [[0, 4, 1, 2, 3], [0, 2, 3, 4, 1]] + [[4, 1, 2, 3, 0]] * 2
+ small_perms += [[0, 3, 1, 2], [0, 2, 3, 1]] + [[3, 1, 2, 0]] * 2
+ small_perms += [[0, 2, 1]] * 2 + [[2, 1, 0]] * 2
- large_shapes = [[2, 100, 100, 100, 32], [2, 100, 100, 100, 64]] * 2 + [[
- 2, 1000, 1000, 32
- ], [2, 1000, 1000, 64]] * 2 + [[2, 1000000, 32], [2, 1000000, 64]] * 2
- large_perms = [[0, 4, 1, 2, 3], [0, 2, 3, 4, 1]] + [[4, 1, 2, 3, 0]] * 2 + [
- [0, 3, 1, 2], [0, 2, 3, 1]
- ] + [[3, 1, 2, 0]] * 2 + [[0, 2, 1]] * 2 + [[2, 1, 0]] * 2
+ large_shapes = [[2, 100, 100, 100, 32], [2, 100, 100, 100, 64]] * 2
+ large_shapes += [[2, 1000, 1000, 32], [2, 1000, 1000, 64]] * 2
+ large_shapes += [[2, 1000000, 32], [2, 1000000, 64]] * 2
+ large_perms = [[0, 4, 1, 2, 3], [0, 2, 3, 4, 1]] + [[4, 1, 2, 3, 0]] * 2
+ large_perms += [[0, 3, 1, 2], [0, 2, 3, 1]] + [[3, 1, 2, 0]] * 2
+ large_perms += [[0, 2, 1]] * 2 + [[2, 1, 0]] * 2
huge_shapes = [[2, 100, 100, 100, 128], [2, 1000, 1000, 128],
[2, 1000000, 128]] * 2
@@ -143,5 +147,23 @@ class TransposeBenchmark(test.Benchmark):
for ishape, perm in zip(huge_shapes, huge_perms):
self._run_graph("gpu", ishape, perm, num_iters, datatype)
+ small_dim_large_shapes = [[2, 1000000, 3], [2, 3, 1000000], [2, 1000000, 8],
+ [2, 8, 1000000]]
+ small_dim_small_shapes = [[2, 5000, 3], [2, 3, 5000], [2, 5000, 8],
+ [2, 8, 5000]]
+ small_dim_perms = [[0, 2, 1]] * 4
+
+ num_iters = 320
+ small_dim_large_shape_datatypes = [np.float64, np.float32, np.int8]
+ for datatype in small_dim_large_shape_datatypes:
+ for ishape, perm in zip(small_dim_large_shapes, small_dim_perms):
+ self._run_graph("gpu", ishape, perm, num_iters, datatype)
+
+ small_dim_small_shape_datatypes = [np.complex128, np.float16]
+ for datatype in small_dim_small_shape_datatypes:
+ for ishape, perm in zip(small_dim_small_shapes, small_dim_perms):
+ self._run_graph("gpu", ishape, perm, num_iters, datatype)
+
+
if __name__ == "__main__":
test.main()
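
For reference, the GB/s figure printed by _run_graph in transpose_benchmark counts one read and one write of every element per transpose, divided by the per-iteration time. A minimal sketch of that formula (the helper name and the 1 ms timing are illustrative):

import numpy as np


def transpose_throughput_gbps(input_shape, datatype, per_iter_seconds):
  # One read plus one write of every element, reported in GB/s.
  num_bytes = np.prod(input_shape) * np.dtype(datatype).itemsize * 2
  return num_bytes / per_iter_seconds / 1e9


# e.g. a [2, 1000000, 32] float32 tensor transposed in 1 ms per iteration:
print(transpose_throughput_gbps([2, 1000000, 32], np.float32, 1e-3))  # 512.0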