aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/fused_conv
diff options
context:
space:
mode:
authorGravatar Yangzihao Wang <yangzihao@google.com>2017-07-20 09:32:48 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-07-20 09:39:27 -0700
commit4bbd9bd11fb52ebe0e3de6f8553a2372c13146bb (patch)
treee7d5c739afc7cdb7fc1945e9e4948e03edadee1f /tensorflow/contrib/fused_conv
parentb3451058a25201c50573f68556812e51cff56edb (diff)
Add fused_conv2d_bias_activation operator for the forward phase.
PiperOrigin-RevId: 162624917
Diffstat (limited to 'tensorflow/contrib/fused_conv')
-rw-r--r--tensorflow/contrib/fused_conv/BUILD168
-rw-r--r--tensorflow/contrib/fused_conv/__init__.py25
-rw-r--r--tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc497
-rw-r--r--tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.h62
-rw-r--r--tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc72
-rw-r--r--tensorflow/contrib/fused_conv/python/__init__.py19
-rw-r--r--tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_benchmark.py243
-rw-r--r--tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op.py87
-rw-r--r--tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py573
9 files changed, 1746 insertions, 0 deletions
diff --git a/tensorflow/contrib/fused_conv/BUILD b/tensorflow/contrib/fused_conv/BUILD
new file mode 100644
index 0000000000..026ee3df07
--- /dev/null
+++ b/tensorflow/contrib/fused_conv/BUILD
@@ -0,0 +1,168 @@
+# Description:
+# A Fused Conv Bias Activation operator wrapper.
+# APIs are meant to change over time.
+package(
+ default_visibility = ["//visibility:private"],
+ features = ["-parse_headers"],
+)
+
+package_group(
+ name = "friends",
+ packages = [
+ "//tensorflow/...",
+ ],
+)
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load("//tensorflow:tensorflow.bzl", "tf_kernel_library")
+load("//tensorflow:tensorflow.bzl", "cuda_py_test")
+load("//tensorflow:tensorflow.bzl", "tf_custom_op_library")
+load("//tensorflow:tensorflow.bzl", "tf_gen_op_libs")
+load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py")
+load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
+
+tf_custom_op_py_library(
+ name = "fused_conv_py",
+ srcs = ["__init__.py"] + glob(["python/ops/*.py"]),
+ dso = [":python/ops/_fused_conv2d_bias_activation_op.so"],
+ kernels = [
+ ":fused_conv2d_bias_activation_op_kernels",
+ ":fused_conv2d_bias_activation_op_op_lib",
+ ],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+ deps = [
+ ":fused_conv2d_bias_activation_op",
+ "//tensorflow/contrib/util:util_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:nn_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:random_ops",
+ "//tensorflow/python:util",
+ "//tensorflow/python:variables",
+ "//third_party/py/numpy",
+ ],
+)
+
+tf_kernel_library(
+ name = "fused_conv2d_bias_activation_op_kernels",
+ srcs = [
+ "kernels/fused_conv2d_bias_activation_op.cc",
+ "kernels/fused_conv2d_bias_activation_op.h",
+ ],
+ prefix = "fused_conv2d_bias_activation_op",
+ deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib_proto_parsing",
+ "//tensorflow/core/kernels:bounds_check_lib",
+ "//tensorflow/core/kernels:conv_2d_hdrs",
+ "//tensorflow/core/kernels:conv_ops_gpu_hdrs",
+ "//tensorflow/core/kernels:ops_util_hdrs",
+ "//third_party/eigen3",
+ ],
+ alwayslink = 1,
+)
+
+tf_custom_op_library(
+ name = "python/ops/_fused_conv2d_bias_activation_op.so",
+ srcs = [
+ "kernels/fused_conv2d_bias_activation_op.cc",
+ "kernels/fused_conv2d_bias_activation_op.h",
+ "ops/fused_conv2d_bias_activation_op.cc",
+ ],
+ deps = [
+ "//tensorflow/core:lib_proto_parsing",
+ "//tensorflow/core/kernels:bounds_check_lib",
+ "//tensorflow/core/kernels:conv_2d_hdrs",
+ "//tensorflow/core/kernels:conv_ops_gpu_hdrs",
+ "//tensorflow/core/kernels:ops_util_hdrs",
+ ],
+)
+
+tf_gen_op_libs(
+ op_lib_names = [
+ "fused_conv2d_bias_activation_op",
+ ],
+ deps = [
+ "//tensorflow/core:lib_proto_parsing",
+ ],
+)
+
+tf_gen_op_wrapper_py(
+ name = "fused_conv2d_bias_activation_op",
+ deps = [":fused_conv2d_bias_activation_op_op_lib"],
+)
+
+cuda_py_test(
+ name = "fused_conv2d_bias_activation_op_test",
+ size = "small",
+ srcs = ["python/ops/fused_conv2d_bias_activation_op_test.py"],
+ additional_deps = [
+ ":fused_conv_py",
+ "//third_party/py/numpy",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:nn",
+ "//tensorflow/python:nn_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:random_ops",
+ "//tensorflow/python:training",
+ "//tensorflow/python:variables",
+ ],
+ tags = [
+ "manual",
+ "requires_cudnn6",
+ ],
+)
+
+cuda_py_test(
+ name = "fused_conv2d_bias_activation_benchmark",
+ size = "large",
+ srcs = ["python/ops/fused_conv2d_bias_activation_benchmark.py"],
+ additional_deps = [
+ ":fused_conv_py",
+ "//tensorflow/python:client",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:nn_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:platform_benchmark",
+ "//tensorflow/python:random_ops",
+ "//tensorflow/python:variables",
+ "//third_party/py/numpy",
+ "//tensorflow/core:protos_all_py",
+ ],
+ main = "python/ops/fused_conv2d_bias_activation_benchmark.py",
+ tags = [
+ "manual",
+ "requires_cudnn6",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/contrib/fused_conv/__init__.py b/tensorflow/contrib/fused_conv/__init__.py
new file mode 100644
index 0000000000..dd4d3fc707
--- /dev/null
+++ b/tensorflow/contrib/fused_conv/__init__.py
@@ -0,0 +1,25 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Ops and modules related to fused_conv2d_bias_activation."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=wildcard-import
+from tensorflow.contrib.fused_conv.python.ops.fused_conv2d_bias_activation_op import *
+from tensorflow.python.util.all_util import remove_undocumented
+
+remove_undocumented(__name__, ['fused_conv2d_bias_activation'])
diff --git a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
new file mode 100644
index 0000000000..d553d5a0a6
--- /dev/null
+++ b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
@@ -0,0 +1,497 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#define EIGEN_USE_THREADS
+
+#if GOOGLE_CUDA
+#define EIGEN_USE_GPU
+#endif // GOOGLE_CUDA
+
+#include "tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.h"
+
+#include "tensorflow/core/framework/numeric_op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_slice.h"
+#include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/kernels/conv_2d.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/util/padding.h"
+#include "tensorflow/core/util/tensor_format.h"
+#include "tensorflow/core/util/use_cudnn.h"
+
+#if GOOGLE_CUDA
+#include "tensorflow/core/kernels/conv_ops_gpu.h"
+#include "tensorflow/core/platform/stream_executor.h"
+#include "tensorflow/core/util/activation_mode.h"
+#endif // GOOGLE_CUDA
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+template <typename Device, typename T>
+struct LaunchConvOp;
+
+template <typename Device, typename T>
+class FusedConv2DBiasActivationOp : public OpKernel {  // Validates attrs and shapes, then defers to launcher_.
+ public:
+ explicit FusedConv2DBiasActivationOp(OpKernelConstruction* context)  // Reads and validates the op attrs.
+ : OpKernel(context) {
+ string data_format;
+ OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
+ OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
+ errors::InvalidArgument("Invalid data format"));
+ OP_REQUIRES(context,
+ (data_format_ == FORMAT_NHWC || data_format_ == FORMAT_NCHW),
+ errors::InvalidArgument("Current implementation only supports "
+ "NHWC and NCHW data formats."));
+ OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
+ OP_REQUIRES(context, strides_.size() == 4,
+ errors::InvalidArgument("Sliding window strides field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES(
+ context,
+ (GetTensorDim(strides_, data_format_, 'N') == 1 &&
+ GetTensorDim(strides_, data_format_, 'C') == 1),
+ errors::InvalidArgument("Current implementation does not yet support "
+ "strides in the batch and depth dimensions."));
+ OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+ string activation_mode_str;
+ OP_REQUIRES_OK(context,
+ context->GetAttr("activation_mode", &activation_mode_str));
+ OP_REQUIRES_OK(context, GetActivationModeFromString(activation_mode_str,
+ &activation_mode_));
+ OP_REQUIRES(context, activation_mode_ == ActivationMode::RELU,
+ errors::InvalidArgument("Current implementation only supports "
+ "relu as the activation mode."));
+ cudnn_use_autotune_ = CudnnUseAutotune();
+ }
+
+ void Compute(OpKernelContext* context) override {
+ // Input tensor is one of the following shapes:
+ // [ batch, in_rows, in_cols, in_depth ] (for NHWC data format)
+ // [ batch, in_depth, in_rows, in_cols ] (for NCHW data format)
+ const Tensor& input = context->input(0);
+
+ // Input filter is of the following dimensions:
+ // [ filter_rows, filter_cols, in_depth, out_depth ]
+ const Tensor& filter = context->input(1);
+
+ // Input bias is a 1-D tensor the size of the channel
+ // dimension of the output tensor (checked against out_depth below).
+ const Tensor& bias = context->input(2);
+
+ // For 2D convolution, there should be 4 dimensions.
+ OP_REQUIRES(context, input.dims() == 4,
+ errors::InvalidArgument("input must be 4-dimensional: ",
+ input.shape().DebugString()));
+ OP_REQUIRES(context, filter.dims() == 4,
+ errors::InvalidArgument("filter must be 4-dimensional: ",
+ filter.shape().DebugString()));
+
+ // Bias should be a 1-D tensor.
+ OP_REQUIRES(context, bias.dims() == 1,
+ errors::InvalidArgument("bias must be 1-dimensional: ",
+ bias.shape().DebugString()));
+
+ for (int i = 0; i < 4; i++) {  // Every dim must fit in int32: values are narrowed below.
+ OP_REQUIRES(context,
+ FastBoundsCheck(filter.dim_size(i),
+ std::numeric_limits<int32>::max()),
+ errors::InvalidArgument("filter dimension too large"));
+ OP_REQUIRES(
+ context,
+ FastBoundsCheck(input.dim_size(i), std::numeric_limits<int32>::max()),
+ errors::InvalidArgument("input dimension too large"));
+ }
+
+ // The 'C' dimension of input is in_depth. It must be the same as the
+ // filter's in_depth (filter dimension 2).
+ const int64 in_depth = GetTensorDim(input, data_format_, 'C');
+ OP_REQUIRES(context, in_depth == filter.dim_size(2),
+ errors::InvalidArgument(
+ "input and filter must have the same depth: ", in_depth,
+ " vs ", filter.dim_size(2)));
+
+ // The last dimension for filter is out_depth.
+ const int32 out_depth = static_cast<int32>(filter.dim_size(3));
+
+ // The 'H' dimension for input is rows/height (position per data_format).
+ // The first dimension for filter is rows/height.
+ const int64 input_rows_raw = GetTensorDim(input, data_format_, 'H');
+ const int32 input_rows = static_cast<int32>(input_rows_raw);
+ const int32 filter_rows = static_cast<int32>(filter.dim_size(0));
+
+ // The 'W' dimension for input is columns/width (position per data_format).
+ // The second dimension for filter is columns/width.
+ const int64 input_cols_raw = GetTensorDim(input, data_format_, 'W');
+ const int32 input_cols = static_cast<int32>(input_cols_raw);
+ const int32 filter_cols = static_cast<int32>(filter.dim_size(1));
+
+ // The 'N' dimension for input is batch.
+ const int64 batch_raw = GetTensorDim(input, data_format_, 'N');
+ const int32 batch = static_cast<int32>(batch_raw);
+
+ // For now we take the stride from the 'H' and 'W' dimensions only (we
+ // do not support striding on the batch or depth dimension).
+ const int32 stride_rows =
+ static_cast<int32>(GetTensorDim(strides_, data_format_, 'H'));
+ const int32 stride_cols =
+ static_cast<int32>(GetTensorDim(strides_, data_format_, 'W'));
+ const int32 bias_size = static_cast<int32>(bias.dim_size(0));
+
+ int64 out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0;
+ OP_REQUIRES_OK(context,
+ GetWindowedOutputSize(input_rows, filter_rows, stride_rows,
+ padding_, &out_rows, &pad_rows));
+ OP_REQUIRES_OK(context,
+ GetWindowedOutputSize(input_cols, filter_cols, stride_cols,
+ padding_, &out_cols, &pad_cols));
+ // Output tensor follows data_format_; in NHWC order its dimensions are:
+ // [ in_batch, out_rows, out_cols, out_depth ]
+ TensorShape out_shape =
+ ShapeFromFormat(data_format_, batch, out_rows, out_cols, out_depth);
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
+
+ // Bias size should be the same as the size of the channel dimension of
+ // output.
+ OP_REQUIRES(context, bias_size == out_depth,
+ errors::InvalidArgument(
+ "bias size should equal the channel "
+ "dimension size of output. bias shape: ",
+ bias.shape().DebugString() +
+ ", output shape: " + output->shape().DebugString()));
+
+ VLOG(2) << "FusedConv2DBiasActivation: in_depth = " << in_depth
+ << ", input_cols = " << input_cols
+ << ", filter_cols = " << filter_cols
+ << ", input_rows = " << input_rows
+ << ", filter_rows = " << filter_rows
+ << ", stride_rows = " << stride_rows
+ << ", stride_cols = " << stride_cols
+ << ", bias_size = " << bias_size << ", out_depth = " << out_depth;
+
+ // If there is nothing to compute, return.
+ if (out_shape.num_elements() == 0) {
+ return;
+ }
+ launcher_.launch(context, cudnn_use_autotune_, input, filter, stride_rows,
+ stride_cols, bias, activation_mode_,
+ BrainPadding2EigenPadding(padding_), data_format_, output);
+ }
+
+ private:
+ std::vector<int32> strides_;  // "strides" attr; exactly 4 entries.
+ Padding padding_;  // "padding" attr.
+ ActivationMode activation_mode_;  // Only RELU passes the constructor check.
+ TensorFormat data_format_;  // FORMAT_NHWC or FORMAT_NCHW.
+ LaunchFusedConv2DBiasActivationOp<Device, T> launcher_;
+ bool cudnn_use_autotune_;  // Value of CudnnUseAutotune() at construction.
+
+ TF_DISALLOW_COPY_AND_ASSIGN(FusedConv2DBiasActivationOp);
+};
+
+#if GOOGLE_CUDA
+namespace dnn = ::perftools::gputools::dnn;
+
+dnn::ActivationMode BrainActivationMode2CudnnActivationMode(  // Maps ActivationMode to its dnn:: counterpart.
+ ActivationMode activation_mode) {
+ switch (activation_mode) {
+ case ActivationMode::SIGMOID:
+ return dnn::ActivationMode::kSigmoid;
+ case ActivationMode::RELU:
+ return dnn::ActivationMode::kRelu;
+ case ActivationMode::RELUX:
+ return dnn::ActivationMode::kReluX;
+ case ActivationMode::RELU6:
+ return dnn::ActivationMode::kRelu6;
+ case ActivationMode::TANH:
+ return dnn::ActivationMode::kTanh;
+ case ActivationMode::BANDPASS:
+ return dnn::ActivationMode::kBandPass;
+ }
+ // Fallback for values not matched above; also avoids a missing-return warning.
+ return dnn::ActivationMode::kRelu;
+}
+
+// A tag type that groups fused conv-bias-activation autotune results together.
+struct ConvBiasActivationAutoTuneGroup {
+ static string name() { return "ConvBiasActivation"; }
+};
+typedef AutoTuneSingleton<ConvBiasActivationAutoTuneGroup, ConvParameters,
+ perftools::gputools::dnn::AlgorithmConfig>
+ AutoTuneConvBiasActivation;  // Maps ConvParameters to the chosen AlgorithmConfig.
+
+template <typename T>
+void LaunchFusedConv2DBiasActivationOp<GPUDevice, T>::launch(  // GPU path: pad/transpose, autotune, convolve.
+ OpKernelContext* ctx, bool cudnn_use_autotune, const Tensor& input_param,
+ const Tensor& filter, int32 row_stride, int32 col_stride,
+ const Tensor& bias, const ActivationMode& activation_mode,
+ const Eigen::PaddingType& padding, TensorFormat data_format,
+ Tensor* output) {
+ using perftools::gputools::dnn::AlgorithmConfig;
+ using perftools::gputools::dnn::AlgorithmType;
+ using perftools::gputools::dnn::ProfileResult;
+ using perftools::gputools::dnn::kDefaultAlgorithm;
+ auto* stream = ctx->op_device_context()->stream();
+ OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available."));
+
+ Tensor input = input_param;  // May be rebound to a padded/transposed temporary below.
+
+ perftools::gputools::dnn::ActivationMode cudnn_activation_mode =
+ BrainActivationMode2CudnnActivationMode(activation_mode);
+
+ // TODO(yangzihao): refactor all the complicated/duplicated code in regular
+ // conv ops to a shared conv utility.
+ int32 padding_rows = 0;
+ int32 padding_cols = 0;
+ const int64 in_batch = GetTensorDim(input, data_format, 'N');
+ int64 in_rows = GetTensorDim(input, data_format, 'H');
+ int64 in_cols = GetTensorDim(input, data_format, 'W');
+ const int64 in_depths = GetTensorDim(input, data_format, 'C');
+ const int64 out_batch = GetTensorDim(*output, data_format, 'N');
+ const int64 out_rows = GetTensorDim(*output, data_format, 'H');
+ const int64 out_cols = GetTensorDim(*output, data_format, 'W');
+ const int64 out_depths = GetTensorDim(*output, data_format, 'C');
+ const int64 patch_rows = filter.dim_size(0);
+ const int64 patch_cols = filter.dim_size(1);
+ if (padding == Eigen::PADDING_SAME) {
+ // Total padding on rows and cols is
+ // Pr = (R' - 1) * S + Kr - R
+ // Pc = (C' - 1) * S + Kc - C
+ // where (R', C') are output dimensions, (R, C) are input dimensions, S
+ // is stride, (Kr, Kc) are filter dimensions.
+ // We pad Pr/2 on the top and Pr - Pr/2 on the bottom, Pc/2 on the left
+ // and Pc - Pc/2 on the right. When Pr or Pc is odd, this means
+ // we pad more on the bottom and right than on the top and left.
+ padding_rows =
+ std::max<int32>(0, (out_rows - 1) * row_stride + patch_rows - in_rows);
+ padding_cols =
+ std::max<int32>(0, (out_cols - 1) * col_stride + patch_cols - in_cols);
+ const int rows_parity = padding_rows & 1;
+ const int cols_parity = padding_cols & 1;
+ if ((rows_parity | cols_parity) != 0) {  // Odd total padding: grow the input by the extra row/col.
+ Tensor transformed_input;
+ int64 new_in_rows = in_rows + rows_parity;
+ int64 new_in_cols = in_cols + cols_parity;
+ OP_REQUIRES_OK(
+ ctx,
+ ctx->allocate_temp(DataTypeToEnum<T>::value,
+ ShapeFromFormat(data_format, in_batch, new_in_rows,
+ new_in_cols, in_depths),
+ &transformed_input));
+
+ functor::PadInput<GPUDevice, T, int, 4>()(
+ ctx->eigen_device<GPUDevice>(), To32Bit(input_param.tensor<T, 4>()),
+ {{0, 0}}, {{rows_parity, cols_parity}},
+ To32Bit(transformed_input.tensor<T, 4>()), data_format);
+
+ input = transformed_input;
+ in_rows = new_in_rows;
+ in_cols = new_in_cols;
+ }
+ }
+
+ if (data_format == FORMAT_NHWC) {
+ // Convert the input tensor from NHWC to NCHW.
+ TensorShape nchw_shape =
+ ShapeFromFormat(FORMAT_NCHW, in_batch, in_rows, in_cols, in_depths);
+ if (in_depths > 1) {
+ Tensor transformed_input;
+ OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
+ nchw_shape, &transformed_input));
+ functor::NHWCToNCHW<GPUDevice, T, 4>()(
+ ctx->eigen_device<GPUDevice>(),
+ const_cast<const Tensor&>(input).tensor<T, 4>(),
+ transformed_input.tensor<T, 4>());
+ input = transformed_input;
+ } else {
+ // If depth <= 1, no data movement is done: the tensor is just reshaped.
+ CHECK(input.CopyFrom(input, nchw_shape));
+ }
+ }
+
+ CHECK(padding_rows >= 0 && padding_cols >= 0)
+ << "Negative row or col paddings: (" << padding_rows << ", "
+ << padding_cols << ")";
+ perftools::gputools::dnn::BatchDescriptor input_desc;
+ input_desc.set_count(in_batch)
+ .set_feature_map_count(in_depths)
+ .set_height(in_rows)
+ .set_width(in_cols)
+ .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
+ perftools::gputools::dnn::BatchDescriptor output_desc;
+ output_desc.set_count(out_batch)
+ .set_height(out_rows)
+ .set_width(out_cols)
+ .set_feature_map_count(out_depths)
+ .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
+ perftools::gputools::dnn::FilterDescriptor filter_desc;
+ filter_desc.set_input_filter_height(filter.dim_size(0))
+ .set_input_filter_width(filter.dim_size(1))
+ .set_input_feature_map_count(filter.dim_size(2))
+ .set_output_feature_map_count(filter.dim_size(3));
+ perftools::gputools::dnn::ConvolutionDescriptor conv_desc;
+ conv_desc.set_vertical_filter_stride(row_stride)
+ .set_horizontal_filter_stride(col_stride)
+ .set_zero_padding_height(padding_rows / 2)  // Even part only; the odd remainder was padded into the input above.
+ .set_zero_padding_width(padding_cols / 2);
+
+ // Shuffles a filter tensor from:
+ // [<spatial_dims>, in, out]
+ // to:
+ // [out, in, <spatial_dims>]
+ // TODO(yangzihao): Support a data layout tag for the filter weights, and only
+ // do the transform if the weights are not already in the correct layout.
+ Tensor transformed_filter;
+ OP_REQUIRES_OK(ctx, ctx->allocate_temp(
+ DataTypeToEnum<T>::value,
+ TensorShape({filter.dim_size(3), filter.dim_size(2),
+ filter.dim_size(0), filter.dim_size(1)}),
+ &transformed_filter));
+
+ functor::TransformFilter<GPUDevice, T, int, 4>()(
+ ctx->eigen_device<GPUDevice>(), To32Bit(filter.tensor<T, 4>()),
+ To32Bit(transformed_filter.tensor<T, 4>()));
+
+ Tensor transformed_output;
+ OP_REQUIRES_OK(
+ ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
+ ShapeFromFormat(FORMAT_NCHW, out_batch, out_rows,
+ out_cols, out_depths),
+ &transformed_output));
+
+ auto input_ptr = AsDeviceMemory(input.template flat<T>().data(),
+ input.template flat<T>().size());
+ auto filter_ptr =
+ AsDeviceMemory(transformed_filter.template flat<T>().data(),
+ transformed_filter.template flat<T>().size());
+ auto output_ptr =
+ AsDeviceMemory(transformed_output.template flat<T>().data(),
+ transformed_output.template flat<T>().size());
+
+ auto bias_ptr = AsDeviceMemory(bias.template flat<T>().data(),
+ bias.template flat<T>().size());
+
+ static int64 ConvolveScratchSize = GetCudnnWorkspaceLimit(
+ // default value is in bytes despite the name of the environment variable
+ "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32 // 4GB
+ );
+
+ int device_id = stream->parent()->device_ordinal();
+ DataType dtype = input.dtype();
+ ConvParameters conv_parameters = {  // Lookup key for the autotune cache.
+ in_batch,
+ in_depths,
+ {{in_rows, in_cols}},
+ out_depths,
+ {{patch_rows, patch_cols}},
+ {{row_stride, col_stride}},
+ {{padding_rows, padding_cols}},
+ dtype,
+ device_id,
+ };
+
+ AlgorithmConfig algorithm_config;
+ if (cudnn_use_autotune && !AutoTuneConvBiasActivation::GetInstance()->Find(
+ conv_parameters, &algorithm_config)) {
+ std::vector<AlgorithmType> algorithms;
+ CHECK(stream->parent()->GetConvolveAlgorithms(
+ conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(), &algorithms));
+ ProfileResult best_result;
+ ProfileResult best_result_no_scratch;
+ for (auto profile_algorithm : algorithms) {
+ // TODO(zhengxq): profile each algorithm multiple times for better
+ // accuracy.
+ CudnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
+ ProfileResult profile_result;
+ bool cudnn_launch_status =
+ stream
+ ->ThenConvolveWithAlgorithm(
+ input_desc, input_ptr, filter_desc, filter_ptr, conv_desc,
+ bias_ptr, cudnn_activation_mode, output_desc, &output_ptr,
+ &scratch_allocator, AlgorithmConfig(profile_algorithm),
+ &profile_result)
+ .ok();
+ if (cudnn_launch_status) {
+ if (profile_result.is_valid()) {
+ if (profile_result.elapsed_time_in_ms() <
+ best_result.elapsed_time_in_ms()) {
+ best_result = profile_result;
+ }
+ if (scratch_allocator.TotalByteSize() == 0 &&
+ profile_result.elapsed_time_in_ms() <
+ best_result_no_scratch.elapsed_time_in_ms()) {
+ best_result_no_scratch = profile_result;
+ }
+ }
+ }
+ }
+ OP_REQUIRES(
+ ctx,
+ best_result.is_valid() && best_result.algorithm() != kDefaultAlgorithm,
+ errors::NotFound("No algorithm worked!"));
+ OP_REQUIRES(ctx,  // NOTE(review): op fails if every valid algorithm used scratch; confirm intended.
+ best_result_no_scratch.is_valid() &&
+ best_result_no_scratch.algorithm() != kDefaultAlgorithm,
+ errors::NotFound("No algorithm without scratch worked!"));
+ algorithm_config.set_algorithm(best_result.algorithm());
+ algorithm_config.set_algorithm_no_scratch(
+ best_result_no_scratch.algorithm());
+ AutoTuneConvBiasActivation::GetInstance()->Insert(conv_parameters,
+ algorithm_config);
+ }
+
+ CudnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
+ bool cudnn_launch_status =
+ stream
+ ->ThenConvolveWithAlgorithm(
+ input_desc, input_ptr, filter_desc, filter_ptr, conv_desc,
+ bias_ptr, cudnn_activation_mode, output_desc, &output_ptr,
+ &scratch_allocator, algorithm_config,
+ /*output_profile_result=*/nullptr)
+ .ok();
+
+ if (!cudnn_launch_status) {  // NOTE(review): no early return here; the output conversion below still runs.
+ ctx->SetStatus(errors::Internal(
+ "cuDNN launch failure : input shape(", input.shape().DebugString(),
+ ") filter shape(", filter.shape().DebugString(), ")"));
+ }
+
+ // Convert the output tensor back from NCHW to NHWC.
+ if (data_format == FORMAT_NHWC) {
+ functor::NCHWToNHWC<GPUDevice, T, 4>()(
+ ctx->eigen_device<GPUDevice>(),
+ const_cast<const Tensor&>(transformed_output).tensor<T, 4>(),
+ output->tensor<T, 4>());
+ } else {
+ *output = transformed_output;  // Already NCHW: no layout conversion needed.
+ }
+}
+
+// Registration of the GPU implementations.
+REGISTER_KERNEL_BUILDER(Name("FusedConv2DBiasActivation")
+ .Device(DEVICE_GPU)
+ .TypeConstraint<float>("T"),
+ FusedConv2DBiasActivationOp<GPUDevice, float>);
+
+#endif // GOOGLE_CUDA
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.h b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.h
new file mode 100644
index 0000000000..d71b26cf1d
--- /dev/null
+++ b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.h
@@ -0,0 +1,62 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRDPARTY_TENSORFLOW_CONTRIB_KERNELS_FUSED_CONV2D_BIAS_ACTIVATION_OP_H_
+#define THIRDPARTY_TENSORFLOW_CONTRIB_KERNELS_FUSED_CONV2D_BIAS_ACTIVATION_OP_H_
+
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/platform/mem.h"
+#include "tensorflow/core/util/activation_mode.h"
+#include "tensorflow/core/util/tensor_format.h"
+
+#if GOOGLE_CUDA
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/kernels/conv_ops_gpu.h"
+#include "tensorflow/core/platform/stream_executor.h"
+#endif // GOOGLE_CUDA
+
+namespace tensorflow {
+
+// Forward declaration.
+class OpKernelContext;
+
+template <typename Device, typename T>
+class LaunchFusedConv2DBiasActivationOp {  // Generic launcher declaration, parameterized on device and element type.
+ public:
+ void launch(OpKernelContext* ctx, bool cudnn_use_autotune,
+ const Tensor& input, const Tensor& filter, int row_stride,
+ int col_stride, const Tensor& bias,
+ const ActivationMode& activation_mode,
+ const Eigen::PaddingType& padding, TensorFormat data_format,
+ Tensor* output);  // Result is written through `output`.
+};
+
+#if GOOGLE_CUDA
+template <typename T>
+class LaunchFusedConv2DBiasActivationOp<Eigen::GpuDevice, T> {  // GPU specialization of the launcher.
+ public:
+ void launch(OpKernelContext* ctx, bool cudnn_use_autotune,
+ const Tensor& input, const Tensor& filter, int32 row_stride,
+ int32 col_stride, const Tensor& bias,
+ const ActivationMode& activation_mode,
+ const Eigen::PaddingType& padding, TensorFormat data_format,
+ Tensor* output);  // Result is written through `output`.
+};
+#endif // GOOGLE_CUDA
+
+} // namespace tensorflow
+
+#endif
diff --git a/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc b/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc
new file mode 100644
index 0000000000..6134c5c699
--- /dev/null
+++ b/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc
@@ -0,0 +1,72 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <string>
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/activation_mode.h"
+#include "tensorflow/core/util/padding.h"
+#include "tensorflow/core/util/tensor_format.h"
+
+namespace tensorflow {
+
+namespace {
+// Returns the Attr() spec string for `activation_mode` used in REGISTER_OP
+// below. Only 'Relu' is currently listed as an allowed value.
+string GetAllActivationModeAttrString() { return "activation_mode: {'Relu'}"; }
+
+} // namespace
+
+// --------------------------------------------------------------------------
+REGISTER_OP("FusedConv2DBiasActivation")
+ .Input("input: T")
+ .Input("filter: T")
+ .Input("bias: T")
+ .Output("output: T")
+ .Attr("T: {float}")
+ .Attr("strides: list(int)")
+ .Attr(GetPaddingAttrString())
+ .Attr(GetConvnetDataFormatAttrString())
+ .Attr(GetAllActivationModeAttrString())
+ .SetShapeFn(shape_inference::FusedConvBiasActivationShape)
+ .Doc(R"doc(
+ Computes a fused 2-D convolution, adds bias, and applies an activation function
+ on the output given 4-D `input`, 4-D `filter`, 1-D `bias` tensors and an activation mode.
+
+ input: A 4-D tensor. The dimension order is interpreted according to the value
+ of `data_format`, see below for details.
+ filter: A 4-D tensor of shape
+ `[filter_height, filter_width, in_channels, out_channels]`
+ bias: 1-D with size of the `out_channels` dimension in filter.
+ output: A 4-D tensor. The dimension order is determined by the value of
+ `data_format`, see below for details.
+ T: The data type for the elements of input, filter, bias, and output Tensors.
+ strides: 1-D tensor of length 4. The stride of the sliding window for each
+ dimension of `input`. The dimension order is determined by the value of
+ `data_format`, see below for details.
+ padding: The type of padding algorithm to use.
+ data_format: Specify the data format of the input and output data. With the
+ default format "NHWC", the data is stored in the order of:
+ [batch, height, width, channels].
+ Alternatively, the format could be "NCHW", the data storage order of:
+ [batch, channels, height, width].
+ activation_mode: Specify the activation function to apply to the output tensor
+ of bias add. Currently only supports "Relu".
+)doc");
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/fused_conv/python/__init__.py b/tensorflow/contrib/fused_conv/python/__init__.py
new file mode 100644
index 0000000000..23d817cefb
--- /dev/null
+++ b/tensorflow/contrib/fused_conv/python/__init__.py
@@ -0,0 +1,19 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""ops module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
diff --git a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_benchmark.py b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_benchmark.py
new file mode 100644
index 0000000000..a65d4bc50f
--- /dev/null
+++ b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_benchmark.py
@@ -0,0 +1,243 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Benchmark for fused conv2d bias and activation op."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import time
+
+from tensorflow.contrib.fused_conv.python.ops import fused_conv2d_bias_activation_op
+from tensorflow.python.client import session as session_lib
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+def build_conv_bias_relu_graph(device, input_shape, filter_shape, strides,
+ padding, num_iters, data_format):
+ """builds a graph containing a sequence of conv2d operations.
+
+ Args:
+ device: String, the device to run on.
+ input_shape: Shape of the input tensor.
+ filter_shape: Shape of the filter tensor.
+ strides: A list of ints. 1-D of length 4. The stride of sliding
+ window for each dimension of input.
+ padding: A string from: "SAME", "VALID". The type of padding
+ algorithm to use.
+ num_iters: number of iterations to run conv2d.
+ data_format: data format string of input, 'NHWC' and 'NCHW' are
+ supported.
+
+ Returns:
+ An array of tensors to run()
+ """
+ if data_format == "NCHW":
+ input_shape = [
+ input_shape[0], input_shape[3], input_shape[1], input_shape[2]
+ ]
+ with ops.device("/%s:0" % device):
+ inp = variables.Variable(random_ops.truncated_normal(input_shape))
+ filt = variables.Variable(random_ops.truncated_normal(filter_shape))
+ bias_shape = [filter_shape[-1]]
+ bias = variables.Variable(random_ops.truncated_normal(bias_shape))
+
+ outputs = []
+ conv2d_out = nn_ops.conv2d(
+ inp, filt, strides, padding, data_format=data_format)
+ bias_out = nn_ops.bias_add(conv2d_out, bias, data_format=data_format)
+ relu_out = nn_ops.relu(bias_out)
+ outputs.append(relu_out)
+ for _ in range(1, num_iters):
+ with ops.control_dependencies([relu_out]):
+ conv2d_out = nn_ops.conv2d(
+ inp, filt, strides, padding, data_format=data_format)
+ bias_out = nn_ops.bias_add(conv2d_out, bias, data_format=data_format)
+ relu_out = nn_ops.relu(bias_out)
+ outputs.append(relu_out)
+ return control_flow_ops.group(*outputs)
+
+
+def build_fused_conv_bias_relu_graph(device, input_shape, filter_shape, strides,
+ padding, num_iters, data_format):
+ """builds a graph containing a sequence of conv2d operations.
+
+ Args:
+ device: String, the device to run on.
+ input_shape: Shape of the input tensor.
+ filter_shape: Shape of the filter tensor.
+ strides: A list of ints. 1-D of length 4. The stride of sliding
+ window for each dimension of input.
+ padding: A string from: "SAME", "VALID". The type of padding
+ algorithm to use.
+ num_iters: number of iterations to run conv2d.
+ data_format: data format string of input, 'NHWC' and 'NCHW' are
+ supported.
+
+ Returns:
+ An array of tensors to run()
+ """
+ if data_format == "NCHW":
+ input_shape = [
+ input_shape[0], input_shape[3], input_shape[1], input_shape[2]
+ ]
+ with ops.device("/%s:0" % device):
+ inp = variables.Variable(random_ops.truncated_normal(input_shape))
+ filt = variables.Variable(random_ops.truncated_normal(filter_shape))
+ bias_shape = [filter_shape[-1]]
+ bias = variables.Variable(random_ops.truncated_normal(bias_shape))
+
+ outputs = []
+ fused_out = fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
+ inp,
+ filt,
+ bias,
+ strides,
+ padding,
+ data_format=data_format,
+ activation_mode="Relu")
+ outputs.append(fused_out)
+ for _ in range(1, num_iters):
+ with ops.control_dependencies([fused_out]):
+ # pylint: disable=g-line-too-long
+ fused_out = fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
+ inp,
+ filt,
+ bias,
+ strides,
+ padding,
+ data_format=data_format,
+ activation_mode="Relu")
+ outputs.append(fused_out)
+ return control_flow_ops.group(*outputs)
+
+
+class FusedConv2DBiasActivationBenchmark(test.Benchmark):
+ """Benchmark conv2d!"""
+
+ def _run_graph(self, device, input_shape, filter_shape, strides, padding,
+ num_iters, data_format):
+ """runs the graph and print its execution time.
+
+ Args:
+ device: String, the device to run on.
+ input_shape: Shape of the input tensor.
+ filter_shape: Shape of the filter tensor.
+ strides: A list of ints. 1-D of length 4. The stride of sliding
+ window for each dimension of input.
+ padding: A string from: "SAME", "VALID". The type of padding
+ algorithm to use. num_iters: Number of iterations to run the
+ benchmark.
+ num_iters: number of iterations to run conv2d.
+ data_format: data format string of input, 'NHWC' and 'NCHW' are
+ supported.
+
+ Returns:
+ The duration of the run in seconds.
+ """
+ graph = ops.Graph()
+ with graph.as_default():
+ outputs = build_fused_conv_bias_relu_graph(device, input_shape,
+ filter_shape, strides, padding,
+ num_iters, data_format)
+ with session_lib.Session(graph=graph) as session:
+ variables.global_variables_initializer().run()
+ # warmup runs
+ session.run(outputs)
+
+ start_time = time.time()
+ session.run(outputs)
+ duration = (time.time() - start_time) / num_iters
+
+ print("%s inputshape:%s filtershape:%s strides:%s padding:%s "
+ "%d iters: %.8f sec" %
+ (device, str(input_shape).replace(" ", ""),
+ str(filter_shape).replace(" ", ""),
+ str(strides).replace(" ", ""), padding, num_iters, duration))
+ name_template = (
+ "conv2d_{device}_input_shape_{inputshape}_filter_shape_{filtershape}_"
+ "strides_{strides}_padding_{padding}")
+
+ self.report_benchmark(
+ name=name_template.format(
+ device=device,
+ inputshape=str(input_shape).replace(" ", ""),
+ filtershape=str(filter_shape).replace(" ", ""),
+ strides=str(strides).replace(" ", ""),
+ padding=padding).replace(" ", ""),
+ iters=num_iters,
+ wall_time=duration)
+
+ return duration
+
+ def benchmark_fused_conv2d_bias_activation(self):
+
+ stride = [1, 1, 1, 1]
+ paddings = ["VALID", "SAME"]
+ data_formats = ["NHWC", "NCHW"]
+
+ resnet50_input_shapes = [[64, 14, 14, 256], [64, 14, 14, 256], [
+ 64, 14, 14, 1024
+ ], [64, 55, 55, 64], [64, 28, 28, 128], [64, 28, 28, 128], [64, 55, 55, 64],
+ [64, 7, 7, 512], [64, 7, 7, 512],
+ [64, 28, 28, 512], [64, 55, 55,
+ 256], [64, 7, 7, 2048]]
+
+ resnet50_filter_shapes = [[1, 1, 256, 1024], [3, 3, 256, 256], [
+ 1, 1, 1024, 256
+ ], [1, 1, 64, 256], [1, 1, 128, 512], [3, 3, 128, 128], [3, 3, 64, 64], [
+ 3, 3, 512, 512
+ ], [1, 1, 512, 2048], [1, 1, 512, 128], [1, 1, 256, 64], [1, 1, 2048, 512]]
+
+ inception3_input_shapes = [[64, 17, 17, 768], [64, 35, 35, 96], [
+ 64, 35, 35, 288
+ ], [64, 8, 8, 384], [64, 8, 8, 384], [64, 17, 17, 192], [64, 35, 35, 64], [
+ 64, 17, 17, 192
+ ], [64, 17, 17, 160], [64, 17, 17, 160], [64, 17, 17, 768], [
+ 64, 35, 35, 256
+ ], [64, 35, 35, 48], [64, 35, 35, 192], [64, 17, 17, 128], [
+ 64, 17, 17, 160
+ ], [64, 8, 8, 448], [64, 17, 17, 128], [64, 17, 17, 768], [64, 17, 17, 160]]
+ inception3_filter_shapes = [[1, 1, 768, 192], [3, 3, 96, 96], [
+ 1, 1, 288, 64
+ ], [1, 3, 384, 384], [3, 1, 384, 384], [7, 1, 192, 192], [3, 3, 64, 96], [
+ 1, 7, 192, 192
+ ], [7, 1, 160, 160], [1, 7, 160, 160], [1, 1, 768, 160], [1, 1, 256, 64], [
+ 5, 5, 48, 64
+ ], [1, 1, 192, 64], [1, 7, 128, 128], [1, 7, 160, 192], [3, 3, 448, 384],
+ [7, 1, 128, 128], [1, 1, 768,
+ 128], [7, 1, 160, 192]]
+
+ print("fused conv2d bias activation benchmark using resnet50's shapes:")
+ for ishape, fshape in zip(resnet50_input_shapes, resnet50_filter_shapes):
+ for padding in paddings:
+ for data_format in data_formats:
+ self._run_graph("gpu", ishape, fshape, stride, padding, 80,
+ data_format)
+ print("fused conv2d bias activation benchmark using inception3's shapes:")
+ for ishape, fshape in zip(inception3_input_shapes,
+ inception3_filter_shapes):
+ for padding in paddings:
+ for data_format in data_formats:
+ self._run_graph("gpu", ishape, fshape, stride, padding, 80,
+ data_format)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op.py b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op.py
new file mode 100644
index 0000000000..41fd114f0f
--- /dev/null
+++ b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op.py
@@ -0,0 +1,87 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Tensorflow op performing fused conv2d bias_add and relu."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.fused_conv.ops import gen_fused_conv2d_bias_activation_op
+from tensorflow.contrib.util import loader
+from tensorflow.python.platform import resource_loader
+
+_fused_conv2d_bias_activation_op_so = loader.load_op_library(
+ resource_loader.get_path_to_datafile("_fused_conv2d_bias_activation_op.so"))
+
+
+def fused_conv2d_bias_activation(input_tensor,
+ filter_tensor,
+ bias,
+ strides,
+ padding,
+ activation_mode,
+ data_format=None,
+ name=None):
+ """Computes a fused 2-D convolution, adds bias, and applies an activation.
+
+ This is a thin wrapper: every argument is forwarded unchanged to the
+ generated `fused_conv2d_bias_activation` op.
+
+ Args:
+   input_tensor: A 4-D `Tensor` of type `float32`. The dimension order is
+     interpreted according to the value of `data_format`.
+   filter_tensor: A 4-D `Tensor` of shape
+     `[filter_height, filter_width, in_channels, out_channels]`. Must have
+     the same type as `input_tensor`.
+   bias: A 1-D `Tensor` with size of the `out_channels` dimension in
+     `filter_tensor`. Must have the same type as `input_tensor`.
+   strides: A list of `ints` of length 4. The stride of the sliding window
+     for each dimension of `input_tensor`. The dimension order is determined
+     by the value of `data_format`.
+   padding: A `string` from: `"SAME", "VALID"`. The type of padding
+     algorithm to use.
+   activation_mode: A `string` naming the activation function to apply to
+     the output of the bias add. The op documentation states that only
+     `"Relu"` is currently supported.
+   data_format: An optional `string` from: `"NHWC", "NCHW"`. With the op's
+     default format "NHWC", the data is stored in the order of
+     [batch, height, width, channels]; with "NCHW" the storage order is
+     [batch, channels, height, width]. `None` is passed through to the
+     generated op as-is.
+   name: A name for the operation (optional).
+
+ Returns:
+   A 4-D `Tensor` with the same type as `input_tensor`. The dimension order
+   is determined by the value of `data_format`.
+ """
+ return gen_fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
+ input=input_tensor,
+ filter=filter_tensor,
+ bias=bias,
+ strides=strides,
+ padding=padding,
+ activation_mode=activation_mode,
+ data_format=data_format,
+ name=name)
diff --git a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py
new file mode 100644
index 0000000000..5d6a2fa3b8
--- /dev/null
+++ b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py
@@ -0,0 +1,573 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Functional tests for fused conv2d bias and activation operation."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+from tensorflow.contrib.fused_conv.python.ops import fused_conv2d_bias_activation_op
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors_impl
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging
+
+
+def GetShrunkInceptionShapes(shrink=10):
+ """Iterator for smaller versions of convolution shapes in 2015 Inception.
+
+ Relative to inception, each depth value is `depth // shrink`.
+
+ Args:
+ shrink: Factor to shrink each depth value by relative to Inception.
+
+ Yields:
+ Tuple (input_size, filter_size, out_size, stride, padding), the convolution
+ parameters of Inception layers.
+ """
+ input_sizes = [[4, 5, 5, 1248], [4, 8, 8, 384], [4, 8, 8, 384], [
+ 4, 8, 8, 2048
+ ], [4, 8, 8, 448], [4, 8, 8, 2048], [4, 8, 8, 2048], [4, 8, 8, 2048], [
+ 4, 8, 8, 1760
+ ], [4, 8, 8, 1760], [4, 8, 8, 1760], [4, 8, 8, 1760], [4, 17, 17, 192], [
+ 4, 17, 17, 192
+ ], [4, 17, 17, 1248], [4, 17, 17, 128], [4, 17, 17, 1248], [4, 17, 17, 224], [
+ 4, 17, 17, 192
+ ], [4, 17, 17, 192], [4, 17, 17, 1216], [4, 17, 17, 1216], [4, 17, 17, 224], [
+ 4, 17, 17, 192
+ ], [4, 17, 17, 192], [4, 17, 17, 1152], [4, 17, 17, 1152], [4, 17, 17, 192], [
+ 4, 17, 17, 160
+ ], [4, 17, 17, 1152], [4, 17, 17, 1024], [4, 17, 17, 128], [4, 17, 17, 1024],
+ [4, 17, 17, 128], [4, 17, 17, 1024], [4, 17, 17, 128], [
+ 4, 17, 17, 768
+ ], [4, 17, 17, 128], [4, 17, 17, 128], [4, 17, 17, 768],
+ [4, 17, 17, 768], [4, 35, 35, 96], [4, 35, 35, 288], [
+ 4, 35, 35, 64
+ ], [4, 35, 35, 288], [4, 35, 35, 256], [4, 35, 35, 48], [
+ 4, 35, 35, 256
+ ], [4, 35, 35, 96], [4, 35, 35, 192], [4, 35, 35, 192], [
+ 4, 35, 35, 192
+ ], [4, 73, 73, 64], [4, 73, 73, 64], [4, 147, 147, 24]]
+ filter_sizes = [[1, 1, 1248, 128], [1, 3, 384, 384], [3, 1, 384, 384], [
+ 1, 1, 2048, 192
+ ], [3, 3, 448, 384], [1, 1, 2048, 320], [1, 1, 2048, 448], [1, 1, 2048, 384],
+ [1, 1, 1760, 384], [1, 1, 1760, 192], [1, 1, 1760, 448], [
+ 1, 1, 1760, 320
+ ], [3, 3, 192, 192], [3, 3, 192, 192], [1, 1, 1248, 192], [
+ 3, 3, 128, 320
+ ], [1, 1, 1248, 128], [1, 3, 224, 224], [3, 1, 192, 256], [
+ 1, 3, 192, 256
+ ], [1, 1, 1216, 192], [1, 1, 1216, 96], [3, 1, 224, 224], [
+ 3, 3, 192, 224
+ ], [1, 3, 192, 192], [1, 1, 1152, 192], [1, 1, 1152, 128], [
+ 3, 1, 192, 192
+ ], [3, 3, 160, 192], [1, 1, 1152, 160], [1, 1, 1024, 128], [
+ 1, 3, 128, 192
+ ], [1, 1, 1024, 160], [3, 1, 128, 192], [1, 1, 1024, 256], [
+ 3, 1, 128, 128
+ ], [1, 1, 768, 192], [1, 3, 128, 128], [3, 3, 128, 128], [
+ 1, 1, 768, 128
+ ], [1, 1, 768, 320], [3, 3, 96, 96], [3, 3, 288, 384], [
+ 3, 3, 64, 96
+ ], [1, 1, 288, 64], [1, 1, 256, 64], [5, 5, 48, 64],
+ [1, 1, 256, 48], [3, 3, 96, 96], [1, 1, 192, 32], [
+ 1, 1, 192, 64
+ ], [1, 1, 192, 48], [3, 3, 64, 192], [1, 1, 64,
+ 64], [1, 1, 24, 64]]
+ out_sizes = [[4, 5, 5, 128], [4, 8, 8, 384], [4, 8, 8, 384], [4, 8, 8, 192], [
+ 4, 8, 8, 384
+ ], [4, 8, 8, 320], [4, 8, 8, 448], [4, 8, 8, 384], [4, 8, 8, 384], [
+ 4, 8, 8, 192
+ ], [4, 8, 8, 448], [4, 8, 8, 320], [4, 8, 8, 192], [4, 17, 17, 192], [
+ 4, 17, 17, 192
+ ], [4, 8, 8, 320], [4, 17, 17, 128], [4, 17, 17, 224], [4, 17, 17, 256], [
+ 4, 17, 17, 256
+ ], [4, 17, 17, 192], [4, 17, 17, 96], [4, 17, 17, 224], [4, 17, 17, 224], [
+ 4, 17, 17, 192
+ ], [4, 17, 17, 192], [4, 17, 17, 128], [4, 17, 17, 192], [4, 17, 17, 192], [
+ 4, 17, 17, 160
+ ], [4, 17, 17, 128], [4, 17, 17, 192], [4, 17, 17, 160], [4, 17, 17, 192], [
+ 4, 17, 17, 256
+ ], [4, 17, 17, 128], [4, 17, 17, 192], [4, 17, 17, 128], [4, 17, 17, 128], [
+ 4, 17, 17, 128
+ ], [4, 17, 17, 320], [4, 17, 17, 96], [4, 17, 17, 384], [4, 35, 35, 96], [
+ 4, 35, 35, 64
+ ], [4, 35, 35, 64], [4, 35, 35, 64], [4, 35, 35, 48], [4, 35, 35, 96],
+ [4, 35, 35, 32], [4, 35, 35, 64], [4, 35, 35, 48],
+ [4, 71, 71, 192], [4, 73, 73, 64], [4, 147, 147, 64]]
+ strides = [
+ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 1,
+ 1, 1, 1, 1, 1
+ ]
+ # Shrink sizes to make the test faster
+ for i in input_sizes:
+ i[3] //= shrink
+ for f in filter_sizes:
+ f[2] //= shrink
+ f[3] //= shrink
+ for o in out_sizes:
+ o[3] //= shrink
+ # pylint: disable=invalid-name
+ VALID = "VALID"
+ SAME = "SAME"
+ # pylint: enable=invalid-name
+ paddings = [
+ SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME,
+ VALID, SAME, SAME, VALID, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME,
+ SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME,
+ SAME, SAME, SAME, SAME, SAME, VALID, VALID, SAME, SAME, SAME, SAME, SAME,
+ SAME, SAME, SAME, SAME, VALID, VALID, VALID
+ ]
+ for i, f, o, s, p in zip(input_sizes, filter_sizes, out_sizes, strides,
+ paddings):
+ yield i, f, o, s, p
+
+
+def GetTestConfigs():
+ """Get all the valid tests configs to run.
+
+ Returns:
+ all the valid test configs as tuples of data_format and use_gpu.
+ """
+ test_configs = [("NCHW", True), ("NHWC", True)]
+ return test_configs
+
+
+class FusedConv2DBiasActivationTest(test.TestCase):
+ """Tests for the fused conv2d + bias + activation op.
+
+ Value tests compare the fused op against an unfused reference built from
+ conv2d, bias_add and relu on the same tensors; the `expected_output`
+ comments in the individual tests are informational only and are not
+ asserted against. Tests taking `gpu_only` return early (logging a message)
+ when no GPU is available.
+ """
+
+ def _DtypesToTest(self, use_gpu):
+ # Only float32 is exercised; `use_gpu` does not affect the result.
+ return [dtypes.float32]
+
+ def _SetupValuesForDevice(self, tensor_in_sizes, filter_in_sizes, bias,
+ strides, padding, activation_mode, data_format,
+ dtype):
+ """Builds the fused op and an unfused reference for one configuration.
+
+ The input and filter are filled with incrementing values starting at 1.
+ The reference is conv2d followed by bias_add and relu on the same tensors
+ (relu is always used for the reference, independent of `activation_mode`).
+ For "NCHW" the input and strides are converted to NCHW before building
+ both graphs, and both results are converted back to NHWC.
+
+ Args:
+ tensor_in_sizes: Input tensor dimensions in
+ [batch, input_rows, input_cols, input_depth].
+ filter_in_sizes: Filter tensor dimensions in
+ [kernel_rows, kernel_cols, input_depth, output_depth].
+ bias: 1-D bias tensor of length output_depth.
+ strides: Two-element list of spatial strides; it is expanded to the
+ NHWC-ordered list [1, strides[0], strides[1], 1].
+ padding: Padding type.
+ activation_mode: Activation mode passed to the fused op.
+ data_format: Format of the data tensors.
+ dtype: Data type for inputs and outputs.
+ Returns:
+ Symbolic tensor value and reference value that can be used to
+ execute the computation and verify the results.
+ """
+ input_size = np.prod(tensor_in_sizes)
+ filter_size = np.prod(filter_in_sizes)
+ bias_size = filter_in_sizes[-1] # equals to output depth
+ # Initializes the input tensor with array containing incrementing
+ # numbers from 1.
+ x1 = [f * 1.0 for f in range(1, input_size + 1)]
+ x2 = [f * 1.0 for f in range(1, filter_size + 1)]
+ # The bias comes from the caller. It is meant to guarantee that there are
+ # always negative values after bias add so that we can test whether relu
+ # works correctly.
+ x3 = bias
+ # The session is always requested with use_gpu=True here.
+ with self.test_session(use_gpu=True):
+ t1 = constant_op.constant(x1, shape=tensor_in_sizes, dtype=dtype)
+ t2 = constant_op.constant(x2, shape=filter_in_sizes, dtype=dtype)
+ t3 = constant_op.constant(x3, shape=[bias_size], dtype=dtype)
+ strides = [1] + strides + [1]
+ if data_format == "NCHW":
+ t1 = test_util.NHWCToNCHW(t1)
+ strides = test_util.NHWCToNCHW(strides)
+ output = fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
+ t1,
+ t2,
+ t3,
+ strides=strides,
+ padding=padding,
+ data_format=data_format,
+ activation_mode=activation_mode)
+ ref_conv_output = nn_ops.conv2d(
+ t1, t2, strides=strides, padding=padding, data_format=data_format)
+ ref_bias_output = nn_ops.bias_add(
+ ref_conv_output, t3, data_format=data_format)
+ ref_output = nn_ops.relu(ref_bias_output)
+ if data_format == "NCHW":
+ output = test_util.NCHWToNHWC(output)
+ ref_output = test_util.NCHWToNHWC(ref_output)
+
+ return output, ref_output
+
+ def _CompareFwdValues(self, tensor_in_sizes, filter_in_sizes, conv_strides,
+ padding):
+ """Verifies that all configs from GetTestConfigs() produce the same values.
+
+ Random float32 input, filter and bias values are shared by all configs.
+ The fused op (with "Relu") is built once per config and every result is
+ compared against the result of the first config.
+
+ Args:
+ tensor_in_sizes: Input tensor dimensions in
+ [batch, input_rows, input_cols, input_depth].
+ filter_in_sizes: Filter tensor dimensions in
+ [kernel_rows, kernel_cols, input_depth, output_depth].
+ conv_strides: [row_stride, col_stride] for the convolution;
+ padding: Padding type.
+ """
+ x1 = np.random.rand(*tensor_in_sizes).astype(np.float32)
+ x2 = np.random.rand(*filter_in_sizes).astype(np.float32)
+ x3 = np.random.rand(*[filter_in_sizes[-1]]).astype(np.float32)
+
+ def _SetupVal(data_format, use_gpu):
+ # Builds the fused op for one config and returns its output in NHWC.
+ with self.test_session(use_gpu=use_gpu):
+ t1 = constant_op.constant(x1, shape=tensor_in_sizes)
+ t2 = constant_op.constant(x2, shape=filter_in_sizes)
+ t3 = constant_op.constant(x3, shape=[filter_in_sizes[-1]])
+ strides = [1] + conv_strides + [1]
+ if data_format == "NCHW":
+ t1 = test_util.NHWCToNCHW(t1)
+ strides = test_util.NHWCToNCHW(strides)
+ output = fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
+ t1,
+ t2,
+ t3,
+ strides=strides,
+ padding=padding,
+ data_format=data_format,
+ activation_mode="Relu")
+
+ if data_format == "NCHW":
+ output = test_util.NCHWToNHWC(output)
+ return output
+
+ tensors = []
+ for (data_format, use_gpu) in GetTestConfigs():
+ tensors.append(_SetupVal(data_format, use_gpu))
+ with self.test_session() as sess:
+ values = sess.run(tensors)
+ # values[0] is the baseline every other config is compared to.
+ for i in range(1, len(values)):
+ self.assertAllClose(values[0], values[i], rtol=1e-5, atol=1e-5)
+
+ def _VerifyValues(self, tensor_in_sizes, filter_in_sizes, bias, strides,
+ padding):
+ """Checks the fused op against the unfused reference for every config.
+
+ For each config from GetTestConfigs() and each dtype from _DtypesToTest(),
+ the fused "Relu" output and its reference are built, evaluated, printed
+ and compared element-wise; the evaluated shape is also checked against
+ the static shape of the fused output tensor.
+
+ Args:
+ tensor_in_sizes: Input tensor dimensions in
+ [batch, input_rows, input_cols, input_depth].
+ filter_in_sizes: Filter tensor dimensions in
+ [kernel_rows, kernel_cols, input_depth, output_depth].
+ bias: 1-D bias values of length output_depth.
+ strides: Two-element list of spatial strides.
+ padding: Padding type.
+ """
+ tensors = []
+ ref_tensors = []
+ for (data_format, use_gpu) in GetTestConfigs():
+ for dtype in self._DtypesToTest(use_gpu):
+ result, expected = self._SetupValuesForDevice(
+ tensor_in_sizes, filter_in_sizes, bias, strides, padding, "Relu",
+ data_format, dtype)
+ tensors.append(result)
+ ref_tensors.append(expected)
+ with self.test_session() as sess:
+ values = sess.run(tensors)
+ ref_values = sess.run(ref_tensors)
+ for i in range(len(tensors)):
+ conv = tensors[i]
+ value = values[i]
+ ref_value = ref_values[i]
+ print("expected = ", ref_value)
+ print("actual = ", value)
+ tol = 1e-5
+ # Looser tolerance for float16; _DtypesToTest currently returns only
+ # float32, so this branch is not taken at present.
+ if value.dtype == np.float16:
+ tol = 1e-3
+ self.assertAllClose(
+ np.ravel(ref_value), np.ravel(value), atol=tol, rtol=tol)
+ self.assertShapeEqual(value, conv)
+
+ def testConv2D1x1Filter(self, gpu_only=True):
+ if gpu_only and not test.is_gpu_available():
+ tf_logging.info("Skipping Conv2D1x1Filter test.")
+ return
+ # expected_output = [
+ # 0.0, 0.0, 0.0, 21.0, 0.0, 0.0, 57.0, 0.0, 0.0, 93.0, 41.0, 0.0, 129.0,
+ # 86.0, 43.0, 165.0, 131.0, 97.0
+ # ]
+ medians = [-45.0, -130.0, -215.0]
+ self._VerifyValues(
+ tensor_in_sizes=[1, 2, 3, 3],
+ filter_in_sizes=[1, 1, 3, 3],
+ bias=medians,
+ strides=[1, 1],
+ padding="VALID")
+
+ def testConv2DEmpty(self, gpu_only=True):
+ if gpu_only and not test.is_gpu_available():
+ tf_logging.info("Skipping Conv2DEmpty test.")
+ return
+ # Batch dimension of 0: the input tensor has no elements.
+ # expected_output = []
+ self._VerifyValues(
+ tensor_in_sizes=[0, 2, 3, 3],
+ filter_in_sizes=[1, 1, 3, 3],
+ bias=[0.0, 0.0, 0.0],
+ strides=[1, 1],
+ padding="VALID")
+
+ def testConv2D2x2Filter(self, gpu_only=True):
+ if gpu_only and not test.is_gpu_available():
+ tf_logging.info("Skipping Conv2D2x2Filter test.")
+ return
+ # expected_output = [0.0, 0.0, 0.0, 401.0, 533.0, 665.0]
+ self._VerifyValues(
+ tensor_in_sizes=[1, 2, 3, 3],
+ filter_in_sizes=[2, 2, 3, 3],
+ bias=[-2500.0, -2500.0, -2500.0],
+ strides=[1, 1],
+ padding="VALID")
+
+ def testConv2D1x2Filter(self, gpu_only=True):
+ if gpu_only and not test.is_gpu_available():
+ tf_logging.info("Skipping Conv2D1x2Filter test.")
+ return
+ # expected_output = [
+ # 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 190.0, 265.0, 340.0, 343.0, 436.0, 529.0
+ # ]
+ self._VerifyValues(
+ tensor_in_sizes=[1, 2, 3, 3],
+ filter_in_sizes=[1, 2, 3, 3],
+ bias=[-500.0, -500.0, -500.0],
+ strides=[1, 1],
+ padding="VALID")
+
+ def testConv2D2x2FilterStride2(self, gpu_only=True):
+ if gpu_only and not test.is_gpu_available():
+ tf_logging.info("Skipping Conv2D2x2FilterStride2 test.")
+ return
+ # expected_output = [0.0, 67.0, 163.0]
+ self._VerifyValues(
+ tensor_in_sizes=[1, 2, 3, 3],
+ filter_in_sizes=[2, 2, 3, 3],
+ bias=[-2300.0, -2300.0, -2300.0],
+ strides=[2, 2],
+ padding="VALID")
+
+ def testConv2D2x2FilterStride2Same(self, gpu_only=True):
+ if gpu_only and not test.is_gpu_available():
+ tf_logging.info("Skipping Conv2D2x2FilterStride2Same test.")
+ return
+ # expected_output = [0.0, 2367.0, 2463.0, 1230.0, 1305.0, 1380.0]
+ self._VerifyValues(
+ tensor_in_sizes=[1, 2, 3, 3],
+ filter_in_sizes=[2, 2, 3, 3],
+ bias=[-2300.0, -1000.0, -1000.0],
+ strides=[2, 2],
+ padding="SAME")
+
+ def testConv2D2x2FilterStride1x2(self, gpu_only=True):
+ if gpu_only and not test.is_gpu_available():
+ tf_logging.info("Skipping Conv2D2x2FilterStride1x2 test.")
+ return
+ # expected_output = [0.0, 0.0, 8.0, 28.0, 48.0, 68.0]
+ self._VerifyValues(
+ tensor_in_sizes=[1, 3, 6, 1],
+ filter_in_sizes=[2, 2, 1, 1],
+ bias=[-90.0],
+ strides=[1, 2],
+ padding="VALID")
+
+ def testConv2DKernelSmallerThanStrideValid(self, gpu_only=True):
+ if gpu_only and not test.is_gpu_available():
+ tf_logging.info("Skipping Conv2DKernelSmallerThanStrideValid test.")
+ return
+ # expected_output = [0, 0, 175, 205]
+ self._VerifyValues(
+ tensor_in_sizes=[1, 7, 7, 1],
+ filter_in_sizes=[2, 2, 1, 1],
+ bias=[-100.0],
+ strides=[3, 3],
+ padding="VALID")
+
+ def testConv2DKernelSmallerThanStrideSame(self, gpu_only=True):
+ if gpu_only and not test.is_gpu_available():
+ tf_logging.info("Skipping Conv2DKernelSmallerThanStrideSame test.")
+ return
+ # expected = [0, 0, 2, 4]
+ self._VerifyValues(
+ tensor_in_sizes=[1, 3, 3, 1],
+ filter_in_sizes=[1, 1, 1, 1],
+ bias=[-5.0],
+ strides=[2, 2],
+ padding="SAME")
+
+ # expected = [0, 0, 4, 6]
+ self._VerifyValues(
+ tensor_in_sizes=[1, 4, 4, 1],
+ filter_in_sizes=[1, 1, 1, 1],
+ bias=[-5.0],
+ strides=[2, 2],
+ padding="SAME")
+
+ # expected = [4, 0, 1, 0]
+ self._VerifyValues(
+ tensor_in_sizes=[1, 4, 4, 1],
+ filter_in_sizes=[2, 2, 1, 1],
+ bias=[-40.0],
+ strides=[3, 3],
+ padding="SAME")
+
+ def testConv2DKernelSizeMatchesInputSize(self, gpu_only=True):
+ if gpu_only and not test.is_gpu_available():
+ tf_logging.info("Skipping Conv2DKernelSizeMatchesInputSize test.")
+ return
+ # expected = [0, 5]
+ self._VerifyValues(
+ tensor_in_sizes=[1, 2, 2, 1],
+ filter_in_sizes=[2, 2, 1, 2],
+ bias=[-50.0, -55.0],
+ strides=[1, 1],
+ padding="VALID")
+
+ # expected = [0, 2, 282, 322]
+ self._VerifyValues(
+ tensor_in_sizes=[1, 8, 8, 1],
+ filter_in_sizes=[2, 2, 1, 1],
+ bias=[-200.0],
+ strides=[4, 4],
+ padding="SAME")
+
+ def testShapeFunctionEdgeCases(self):
+ # Only builds ops and inspects static shapes; no session is run and there
+ # is no GPU guard in this test.
+ # All shapes unknown.
+ c1 = fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
+ array_ops.placeholder(dtypes.float32),
+ array_ops.placeholder(dtypes.float32),
+ array_ops.placeholder(dtypes.float32),
+ strides=[1, 1, 1, 1],
+ padding="SAME",
+ activation_mode="Relu")
+ self.assertEqual([None, None, None, None], c1.get_shape().as_list())
+
+ # Incorrect input shape.
+ with self.assertRaises(ValueError):
+ fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
+ array_ops.placeholder(dtypes.float32, shape=[1, 3]),
+ array_ops.placeholder(dtypes.float32),
+ array_ops.placeholder(dtypes.float32),
+ strides=[1, 1, 1, 1],
+ padding="SAME",
+ activation_mode="Relu")
+
+ # Incorrect filter shape.
+ with self.assertRaises(ValueError):
+ fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
+ array_ops.placeholder(dtypes.float32),
+ array_ops.placeholder(dtypes.float32, shape=[1, 3]),
+ array_ops.placeholder(dtypes.float32),
+ strides=[1, 1, 1, 1],
+ padding="SAME",
+ activation_mode="Relu")
+
+ # Depth mismatch.
+ with self.assertRaises(ValueError):
+ fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
+ array_ops.placeholder(dtypes.float32, shape=[32, 20, 20, 3]),
+ array_ops.placeholder(dtypes.float32, shape=[4, 4, 2, 2]),
+ array_ops.placeholder(dtypes.float32),
+ strides=[1, 1, 1, 1],
+ padding="SAME",
+ activation_mode="Relu")
+
+ def testOpEdgeCases(self, gpu_only=True):
+ if gpu_only and not test.is_gpu_available():
+ tf_logging.info("Skipping OpEdgeCases tests.")
+ return
+ with self.test_session() as sess:
+ # Illegal strides.
+ with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
+ "strides in the batch and depth"):
+ sess.run(
+ fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
+ array_ops.placeholder(dtypes.float32),
+ array_ops.placeholder(dtypes.float32),
+ array_ops.placeholder(dtypes.float32),
+ strides=[2, 1, 1, 1],
+ padding="SAME",
+ activation_mode="Relu"))
+ with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
+ "strides in the batch and depth"):
+ sess.run(
+ fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
+ array_ops.placeholder(dtypes.float32),
+ array_ops.placeholder(dtypes.float32),
+ array_ops.placeholder(dtypes.float32),
+ strides=[1, 1, 1, 2],
+ padding="SAME",
+ activation_mode="Relu"))
+
+ # Illegal activation mode.
+ with self.assertRaisesRegexp(ValueError,
+ "Op passed string 'Tanh' not in:"):
+ sess.run(
+ fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
+ array_ops.placeholder(dtypes.float32),
+ array_ops.placeholder(dtypes.float32),
+ array_ops.placeholder(dtypes.float32),
+ strides=[1, 1, 1, 1],
+ padding="SAME",
+ activation_mode="Tanh"))
+
+ # Filter larger than input.
+ with self.assertRaisesRegexp(ValueError, "Negative dimension size"):
+ sess.run(
+ fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
+ array_ops.placeholder(dtypes.float32, shape=[32, 20, 20, 3]),
+ array_ops.placeholder(dtypes.float32, shape=[20, 21, 3, 2]),
+ array_ops.placeholder(dtypes.float32, shape=[2]),
+ strides=[1, 1, 1, 1],
+ padding="VALID",
+ activation_mode="Relu"))
+ with self.assertRaisesRegexp(ValueError, "Negative dimension size"):
+ sess.run(
+ fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
+ array_ops.placeholder(dtypes.float32, shape=[32, 20, 20, 3]),
+ array_ops.placeholder(dtypes.float32, shape=[21, 20, 3, 2]),
+ array_ops.placeholder(dtypes.float32, shape=[2]),
+ strides=[1, 1, 1, 1],
+ padding="VALID",
+ activation_mode="Relu"))
+
+
+# Builds and returns a test method for one forward-convolution configuration.
+#
+# The returned function takes only `self`, so it can be attached to a test
+# class with setattr. Each call to this factory captures its own input_size,
+# filter_size, stride, padding and gpu_only values.
+#
+# The same `stride` value is passed for both entries of the stride list given
+# to self._CompareFwdValues.
+def GetInceptionFwdTest(input_size, filter_size, stride, padding,
+ gpu_only=True):
+
+ def Test(self):
+ # Skip (log and return) when a GPU is required but none is available.
+ if gpu_only and not test.is_gpu_available():
+ tf_logging.info("Skipping InceptionFwd %s", (input_size, filter_size,
+ stride, padding))
+ return
+ tf_logging.info("Testing InceptionFwd %s", (input_size, filter_size, stride,
+ padding))
+ self._CompareFwdValues(input_size, filter_size, [stride, stride], padding)
+
+ return Test
+
+
+if __name__ == "__main__":
+ # Register one forward test per shape tuple from GetShrunkInceptionShapes(),
+ # named testInceptionFwd_<index>. output_size_ is unpacked but not used.
+ for index, (input_size_, filter_size_, output_size_, stride_,
+ padding_) in enumerate(GetShrunkInceptionShapes()):
+ setattr(FusedConv2DBiasActivationTest, "testInceptionFwd_" + str(index),
+ GetInceptionFwdTest(input_size_, filter_size_, stride_, padding_))
+
+ # TODO(b/35359731)
+ # The Fwd, BckInput, and BackFilter variants of this test are meant to check
+ # that, for certain input parameter sets, the winograd nonfused algorithm is
+ # excluded from conv autotune. If, in such a case, the winograd nonfused
+ # algorithm is still added as an option for conv autotune and the cuDNN
+ # version is smaller than 7, the following test will fail.
+ # Only the Fwd variant is registered below; oshape is defined but not used.
+ ishape = [1, 400, 400, 1]
+ fshape = [1, 1, 1, 256]
+ oshape = [1, 400, 400, 256]
+ setattr(FusedConv2DBiasActivationTest,
+ "testInceptionFwd_No_Winograd_Nonfused",
+ GetInceptionFwdTest(ishape, fshape, 1, "SAME", gpu_only=True))
+ test.main()