aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/fused_conv
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-08-24 00:00:52 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-08-24 00:04:32 -0700
commit91b00a110feb83b7307ea9c280142007090f3cd9 (patch)
tree65b55fce64f9685476deca665686588def6e95b0 /tensorflow/contrib/fused_conv
parent2f531b54ce51fdb9fc2b055548e534a13624ea93 (diff)
Automated g4 rollback of changelist 166276461
PiperOrigin-RevId: 166305887
Diffstat (limited to 'tensorflow/contrib/fused_conv')
-rw-r--r--tensorflow/contrib/fused_conv/BUILD13
-rw-r--r--tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc660
-rw-r--r--tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.h31
-rw-r--r--tensorflow/contrib/fused_conv/kernels/fused_conv_ops_gpu.h74
-rw-r--r--tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc77
-rw-r--r--tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op.py107
-rw-r--r--tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py289
7 files changed, 385 insertions, 866 deletions
diff --git a/tensorflow/contrib/fused_conv/BUILD b/tensorflow/contrib/fused_conv/BUILD
index 9b34cf1bdb..f5d21278db 100644
--- a/tensorflow/contrib/fused_conv/BUILD
+++ b/tensorflow/contrib/fused_conv/BUILD
@@ -60,14 +60,12 @@ tf_kernel_library(
srcs = [
"kernels/fused_conv2d_bias_activation_op.cc",
"kernels/fused_conv2d_bias_activation_op.h",
- "kernels/fused_conv_ops_gpu.h",
],
prefix = "fused_conv2d_bias_activation_op",
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_proto_parsing",
- "//tensorflow/core:stream_executor",
"//tensorflow/core/kernels:bounds_check_lib",
"//tensorflow/core/kernels:conv_2d_hdrs",
"//tensorflow/core/kernels:conv_ops_gpu_hdrs",
@@ -83,7 +81,6 @@ tf_custom_op_library(
srcs = [
"kernels/fused_conv2d_bias_activation_op.cc",
"kernels/fused_conv2d_bias_activation_op.h",
- "kernels/fused_conv_ops_gpu.h",
"ops/fused_conv2d_bias_activation_op.cc",
],
deps = [
@@ -97,8 +94,12 @@ tf_custom_op_library(
)
tf_gen_op_libs(
- op_lib_names = ["fused_conv2d_bias_activation_op"],
- deps = ["//tensorflow/core:lib_proto_parsing"],
+ op_lib_names = [
+ "fused_conv2d_bias_activation_op",
+ ],
+ deps = [
+ "//tensorflow/core:lib_proto_parsing",
+ ],
)
tf_gen_op_wrapper_py(
@@ -108,7 +109,7 @@ tf_gen_op_wrapper_py(
cuda_py_test(
name = "fused_conv2d_bias_activation_op_test",
- size = "large",
+ size = "small",
srcs = ["python/ops/fused_conv2d_bias_activation_op_test.py"],
additional_deps = [
":fused_conv_py",
diff --git a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
index fcdf03b596..dc0701b234 100644
--- a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
+++ b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
@@ -31,8 +31,8 @@ limitations under the License.
#include "tensorflow/core/kernels/conv_2d.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/util/padding.h"
+#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/core/util/use_cudnn.h"
#if GOOGLE_CUDA
@@ -40,72 +40,38 @@ limitations under the License.
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/util/activation_mode.h"
#endif // GOOGLE_CUDA
-
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
-template <typename T>
-struct RawType {
- using type = T;
-};
+template <typename Device, typename T>
+struct LaunchConvOp;
-template <>
-struct RawType<qint8> {
- using type = int8;
-};
-
-// T is the element type of the conv_input, filter and side_input tensors.
-// BiasType is the element type of the bias tensor, which can be different.
-// ScaleType is the type used for conv_input_scale, side_input_scale.
-template <typename Device, typename T, typename BiasType, typename ScaleType>
+template <typename Device, typename T>
class FusedConv2DBiasActivationOp : public OpKernel {
public:
explicit FusedConv2DBiasActivationOp(OpKernelConstruction* context)
: OpKernel(context) {
- string data_format_str, filter_format_str;
- OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format_str));
- OP_REQUIRES(context, FormatFromString(data_format_str, &data_format_),
+ string data_format;
+ OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
+ OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
errors::InvalidArgument("Invalid data format"));
- OP_REQUIRES_OK(context,
- context->GetAttr("filter_format", &filter_format_str));
OP_REQUIRES(context,
- FilterFormatFromString(filter_format_str, &filter_format_),
- errors::InvalidArgument("Invalid filter format"));
-
- std::vector<int32> strides;
- OP_REQUIRES_OK(context, context->GetAttr("strides", &strides));
- OP_REQUIRES(context, strides.size() == 4,
+ (data_format_ == FORMAT_NHWC || data_format_ == FORMAT_NCHW),
+ errors::InvalidArgument("Current implementation only supports "
+ "NHWC and NCHW data formats."));
+ OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
+ OP_REQUIRES(context, strides_.size() == 4,
errors::InvalidArgument("Sliding window strides field must "
"specify 4 dimensions"));
-
- stride_rows_ = GetTensorDim(strides, data_format_, 'H');
- stride_cols_ = GetTensorDim(strides, data_format_, 'W');
- OP_REQUIRES(
- context,
- (GetTensorDim(strides, data_format_, 'N') == 1 &&
- GetTensorDim(strides, data_format_, 'C') == 1),
- errors::InvalidArgument("Convolutional strides are not supported in "
- "the batch or depth dimensions."));
-
- // Note: Only NCHW_VECT_C format is supported for int8.
- // This is because it is expected to be the fastest, and our previous tests
- // found cudnn 6 does not fully support the other formats for int8 mode.
OP_REQUIRES(
context,
- (std::is_same<T, qint8>::value == (data_format_ == FORMAT_NCHW_VECT_C)),
- errors::InvalidArgument(
- "qint8 should be used with data_format NCHW_VECT_C."));
-
- OP_REQUIRES(context,
- (std::is_same<T, qint8>::value ==
- (filter_format_ == FORMAT_OIHW_VECT_I)),
- errors::InvalidArgument(
- "qint8 should be used with filter_format OIHW_VECT_I."));
-
- OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_type_));
- eigen_padding_type_ = BrainPadding2EigenPadding(padding_type_);
+ (GetTensorDim(strides_, data_format_, 'N') == 1 &&
+ GetTensorDim(strides_, data_format_, 'C') == 1),
+ errors::InvalidArgument("Current implementation does not yet support "
+ "strides in the batch and depth dimensions."));
+ OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
string activation_mode_str;
OP_REQUIRES_OK(context,
context->GetAttr("activation_mode", &activation_mode_str));
@@ -113,111 +79,130 @@ class FusedConv2DBiasActivationOp : public OpKernel {
&activation_mode_));
OP_REQUIRES(context, activation_mode_ == ActivationMode::RELU,
errors::InvalidArgument("Current implementation only supports "
- "RELU as the activation function."));
+ "relu as the activation mode."));
cudnn_use_autotune_ = CudnnUseAutotune();
- float conv_input_scale_flt, side_input_scale_flt;
- OP_REQUIRES_OK(context,
- context->GetAttr("conv_input_scale", &conv_input_scale_flt));
- OP_REQUIRES_OK(context,
- context->GetAttr("side_input_scale", &side_input_scale_flt));
- conv_input_scale_ = conv_input_scale_flt;
- side_input_scale_ = side_input_scale_flt;
- }
-
- Status CheckShape(const Tensor& tensor, const string& tensor_name) {
- const int num_dims = tensor.dims();
- for (int i = 0; i < num_dims; i++) {
- if (!FastBoundsCheck(tensor.dim_size(i),
- std::numeric_limits<int32>::max())) {
- return errors::InvalidArgument(tensor_name, " dimension ", i,
- " too large");
- }
- }
- // If there is a 5th dimension it is the VECT_C or VECT_I dimension.
- if (num_dims == 5 && tensor.dim_size(4) != 4) {
- return errors::InvalidArgument("The last dimension of ", tensor_name,
- " must be of size 4 for qint8.");
- }
- return Status::OK();
}
void Compute(OpKernelContext* context) override {
- // The conv_input tensor is one of the following formats:
- // NHWC, NCHW, NCHW_VECT_C.
- const Tensor& conv_input = context->input(0);
- OP_REQUIRES_OK(context, CheckShape(conv_input, "conv_input"));
+ // Input tensor is one of the following shapes:
+ // [ batch, in_rows, in_cols, in_depth ] (for NHWC data format)
+ // [ batch, in_depth, in_rows, in_cols ] (for NCHW data format)
+ const Tensor& input = context->input(0);
- // The filter tensor is one of the following formats:
- // HWIO, OIHW, OIHW_VECT_I.
+ // Input filter is of the following dimensions:
+ // [ filter_rows, filter_cols, in_depth, out_depth ]
const Tensor& filter = context->input(1);
- OP_REQUIRES_OK(context, CheckShape(filter, "filter"));
- // Input bias is a 1-D tensor, with size matching output depth.
+ // Input bias is a 1-D tensor the size of the last
+ // dimension of Output tensor
const Tensor& bias = context->input(2);
- OP_REQUIRES_OK(context, CheckShape(bias, "conv_input"));
- // If side_input_scale != 0, then side_input is not ignored and
- // has the same type and dimensions as the output.
- const Tensor& side_input = context->input(3);
- if (side_input_scale_ != 0) {
- OP_REQUIRES_OK(context, CheckShape(side_input, "side_input"));
+ // For 2D convolution, there should be 4 dimensions.
+ OP_REQUIRES(context, input.dims() == 4,
+ errors::InvalidArgument("input must be 4-dimensional",
+ input.shape().DebugString()));
+ OP_REQUIRES(context, filter.dims() == 4,
+ errors::InvalidArgument("filter must be 4-dimensional: ",
+ filter.shape().DebugString()));
+
+ // Bias should be a 1-D tensor.
+ OP_REQUIRES(context, bias.dims() == 1,
+ errors::InvalidArgument("bias must be 1-dimensional: ",
+ bias.shape().DebugString()));
+
+ for (int i = 0; i < 4; i++) {
+ OP_REQUIRES(context,
+ FastBoundsCheck(filter.dim_size(i),
+ std::numeric_limits<int32>::max()),
+ errors::InvalidArgument("filter dimension too large"));
+ OP_REQUIRES(
+ context,
+ FastBoundsCheck(input.dim_size(i), std::numeric_limits<int32>::max()),
+ errors::InvalidArgument("input dimension too large"));
}
- // TODO(pauldonnelly): Switch to a more efficient mechanism to access
- // dimension indexes and per-dimension attributes.
- const int32 filter_rows = GetFilterDim(filter, filter_format_, 'H');
- const int32 filter_cols = GetFilterDim(filter, filter_format_, 'W');
- const int32 output_depth = GetFilterDim(filter, filter_format_, 'O');
-
- const int32 batch_size = GetTensorDim(conv_input, data_format_, 'N');
- const int32 conv_input_rows = GetTensorDim(conv_input, data_format_, 'H');
- const int32 conv_input_cols = GetTensorDim(conv_input, data_format_, 'W');
-
- int64 output_rows = 0, output_cols = 0, pad_rows = 0, pad_cols = 0;
- OP_REQUIRES_OK(context, GetWindowedOutputSize(conv_input_rows, filter_rows,
- stride_rows_, padding_type_,
- &output_rows, &pad_rows));
- OP_REQUIRES_OK(context, GetWindowedOutputSize(conv_input_cols, filter_cols,
- stride_cols_, padding_type_,
- &output_cols, &pad_cols));
- // Initialize the output tensor shape according to data_format_
- TensorShape output_shape = ShapeFromFormat(
- data_format_, batch_size, output_rows, output_cols, output_depth);
+ // The last dimension for input is in_depth. It must be the same as the
+ // filter's in_depth.
+ const int64 in_depth = GetTensorDim(input, data_format_, 'C');
+ OP_REQUIRES(context, in_depth == filter.dim_size(2),
+ errors::InvalidArgument(
+ "input and filter must have the same depth: ", in_depth,
+ " vs ", filter.dim_size(2)));
+
+ // The last dimension for filter is out_depth.
+ const int32 out_depth = static_cast<int32>(filter.dim_size(3));
+
+ // The second dimension for input is rows/height.
+ // The first dimension for filter is rows/height.
+ const int64 input_rows_raw = GetTensorDim(input, data_format_, 'H');
+ const int32 input_rows = static_cast<int32>(input_rows_raw);
+ const int32 filter_rows = static_cast<int32>(filter.dim_size(0));
+
+ // The third dimension for input is columns/width.
+ // The second dimension for filter is columns/width.
+ const int64 input_cols_raw = GetTensorDim(input, data_format_, 'W');
+ const int32 input_cols = static_cast<int32>(input_cols_raw);
+ const int32 filter_cols = static_cast<int32>(filter.dim_size(1));
+
+ // The first dimension for input is batch.
+ const int64 batch_raw = GetTensorDim(input, data_format_, 'N');
+ const int32 batch = static_cast<int32>(batch_raw);
+
+ // For now we take the stride from the second and third dimensions only (we
+ // do not support striding on the batch or depth dimension).
+ const int32 stride_rows =
+ static_cast<int32>(GetTensorDim(strides_, data_format_, 'H'));
+ const int32 stride_cols =
+ static_cast<int32>(GetTensorDim(strides_, data_format_, 'W'));
+ const int32 bias_size = static_cast<int32>(bias.dim_size(0));
+
+ int64 out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0;
+ OP_REQUIRES_OK(context,
+ GetWindowedOutputSize(input_rows, filter_rows, stride_rows,
+ padding_, &out_rows, &pad_rows));
+ OP_REQUIRES_OK(context,
+ GetWindowedOutputSize(input_cols, filter_cols, stride_cols,
+ padding_, &out_cols, &pad_cols));
+ // Output tensor is of the following dimensions:
+ // [ in_batch, out_rows, out_cols, out_depth ]
+ TensorShape out_shape =
+ ShapeFromFormat(data_format_, batch, out_rows, out_cols, out_depth);
Tensor* output = nullptr;
- OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
+ OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
+
+ // Bias size should be the same as the size of the channel dimension of
+ // output.
+ OP_REQUIRES(context, bias_size == out_depth,
+ errors::InvalidArgument(
+ "bias size should equal the channel "
+ "dimension size of output. bias shape: ",
+ bias.shape().DebugString() +
+ ", output shape: " + output->shape().DebugString()));
- VLOG(2) << "FusedConv2DBiasActivation: conv_input_cols = "
- << conv_input_cols << ", conv_input_rows = " << conv_input_rows
+ VLOG(2) << "FusedConv2DBiasActivation: in_depth = " << in_depth
+ << ", input_cols = " << input_cols
<< ", filter_cols = " << filter_cols
+ << ", input_rows = " << input_rows
<< ", filter_rows = " << filter_rows
- << ", stride_cols = " << stride_cols_
- << ", stride_rows = " << stride_rows_
- << ", output_depth = " << output_depth
- << ", output_cols = " << output_cols
- << ", output_rows = " << output_rows
- << ", output_shape.num_elements = " << output_shape.num_elements();
+ << ", stride_rows = " << stride_rows
+ << ", stride_cols = " << stride_cols
+ << ", bias_size = " << bias_size << ", out_depth = " << out_depth;
// If there is nothing to compute, return.
- if (output_shape.num_elements() == 0) {
+ if (out_shape.num_elements() == 0) {
return;
}
-
- launcher_.launch(context, cudnn_use_autotune_, conv_input,
- conv_input_scale_, filter, stride_rows_, stride_cols_,
- eigen_padding_type_, side_input, side_input_scale_, bias,
- activation_mode_, data_format_, filter_format_, output);
+ launcher_.launch(context, cudnn_use_autotune_, input, filter, stride_rows,
+ stride_cols, bias, activation_mode_,
+ BrainPadding2EigenPadding(padding_), data_format_, output);
}
private:
- int32 stride_rows_, stride_cols_;
- Padding padding_type_;
- Eigen::PaddingType eigen_padding_type_;
+ std::vector<int32> strides_;
+ Padding padding_;
ActivationMode activation_mode_;
TensorFormat data_format_;
- FilterTensorFormat filter_format_;
- ScaleType conv_input_scale_;
- ScaleType side_input_scale_;
- LaunchFusedConv2DBiasActivationOp<Device, T, BiasType, ScaleType> launcher_;
+ LaunchFusedConv2DBiasActivationOp<Device, T> launcher_;
bool cudnn_use_autotune_;
TF_DISALLOW_COPY_AND_ASSIGN(FusedConv2DBiasActivationOp);
@@ -226,71 +211,67 @@ class FusedConv2DBiasActivationOp : public OpKernel {
#if GOOGLE_CUDA
namespace dnn = ::perftools::gputools::dnn;
+dnn::ActivationMode BrainActivationMode2CudnnActivationMode(
+ ActivationMode activation_mode) {
+ switch (activation_mode) {
+ case ActivationMode::SIGMOID:
+ return dnn::ActivationMode::kSigmoid;
+ case ActivationMode::RELU:
+ return dnn::ActivationMode::kRelu;
+ case ActivationMode::RELUX:
+ return dnn::ActivationMode::kReluX;
+ case ActivationMode::RELU6:
+ return dnn::ActivationMode::kRelu6;
+ case ActivationMode::TANH:
+ return dnn::ActivationMode::kTanh;
+ case ActivationMode::BANDPASS:
+ return dnn::ActivationMode::kBandPass;
+ }
+ // Prevent compiler warning about missing return
+ return dnn::ActivationMode::kRelu;
+}
+
// A dummy type to group forward convolution autotune results together.
struct ConvBiasActivationAutoTuneGroup {
static string name() { return "ConvBiasActivation"; }
};
-typedef AutoTuneSingleton<ConvBiasActivationAutoTuneGroup, FusedConvParameters,
- dnn::AlgorithmConfig>
+typedef AutoTuneSingleton<ConvBiasActivationAutoTuneGroup, ConvParameters,
+ perftools::gputools::dnn::AlgorithmConfig>
AutoTuneConvBiasActivation;
-// Allocates 'transformed_tensor' and transforms 'nhwc_tensor' into it
-// using the specified 'batch_size', 'rows', 'cols', and 'depth' dimensions.
-template <typename T, size_t NDIMS>
-Status TransformNHWCToNCHW(OpKernelContext* ctx, const Tensor& nhwc_tensor,
- int batch_size, int rows, int cols, int depth,
- Tensor* transformed_tensor, const Tensor** result) {
- TensorShape nchw_shape =
- ShapeFromFormat(FORMAT_NCHW, batch_size, rows, cols, depth);
- if (depth > 1) {
- TF_RETURN_IF_ERROR(ctx->allocate_temp(DataTypeToEnum<T>::value, nchw_shape,
- transformed_tensor));
- functor::NHWCToNCHW<GPUDevice, T, NDIMS>()(
- ctx->eigen_device<GPUDevice>(), nhwc_tensor.tensor<T, NDIMS>(),
- transformed_tensor->tensor<T, NDIMS>());
- } else {
- // If depth <= 1, then just reshape.
- CHECK(transformed_tensor->CopyFrom(nhwc_tensor, nchw_shape));
- }
- *result = transformed_tensor;
- return Status::OK();
-}
-
-template <typename T, typename BiasType, typename ScaleType>
-void LaunchFusedConv2DBiasActivationOp<GPUDevice, T, BiasType, ScaleType>::
- launch(OpKernelContext* ctx, bool cudnn_use_autotune,
- const Tensor& conv_input_param, ScaleType conv_input_scale,
- const Tensor& filter_param, int32 row_stride, int32 col_stride,
- const Eigen::PaddingType& padding, const Tensor& side_input_param,
- ScaleType side_input_scale, const Tensor& bias,
- ActivationMode activation_mode, TensorFormat data_format,
- FilterTensorFormat filter_format, Tensor* output_param) {
+template <typename T>
+void LaunchFusedConv2DBiasActivationOp<GPUDevice, T>::launch(
+ OpKernelContext* ctx, bool cudnn_use_autotune, const Tensor& input_param,
+ const Tensor& filter, int32 row_stride, int32 col_stride,
+ const Tensor& bias, const ActivationMode& activation_mode,
+ const Eigen::PaddingType& padding, TensorFormat data_format,
+ Tensor* output) {
+ using perftools::gputools::dnn::AlgorithmConfig;
+ using perftools::gputools::dnn::AlgorithmType;
+ using perftools::gputools::dnn::ProfileResult;
+ using perftools::gputools::dnn::kDefaultAlgorithm;
auto* stream = ctx->op_device_context()->stream();
OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available."));
+ Tensor input = input_param;
+
+ perftools::gputools::dnn::ActivationMode cudnn_activation_mode =
+ BrainActivationMode2CudnnActivationMode(activation_mode);
+
// TODO(yangzihao): refactor all the complicated/duplicated code in regular
// conv ops to a shared conv utility.
-
- // Assuming qint8 <--> NCHW_VECT_C, OIHW_VECT_I here.
- constexpr int rank = std::is_same<T, qint8>::value ? 5 : 4;
- constexpr int vect = std::is_same<T, qint8>::value ? 4 : 1;
-
- const int batch_size = GetTensorDim(conv_input_param, data_format, 'N');
- int conv_input_rows = GetTensorDim(conv_input_param, data_format, 'H');
- int conv_input_cols = GetTensorDim(conv_input_param, data_format, 'W');
-
- const int conv_input_depth =
- GetTensorDim(conv_input_param, data_format, 'C') * vect;
- const int output_rows = GetTensorDim(*output_param, data_format, 'H');
- const int output_cols = GetTensorDim(*output_param, data_format, 'W');
- const int output_depth = GetFilterDim(filter_param, filter_format, 'O');
- const int filter_rows = GetFilterDim(filter_param, filter_format, 'H');
- const int filter_cols = GetFilterDim(filter_param, filter_format, 'W');
-
- int padding_rows = 0;
- int padding_cols = 0;
- const Tensor* conv_input = &conv_input_param;
- Tensor maybe_padded_conv_input;
+ int32 padding_rows = 0;
+ int32 padding_cols = 0;
+ const int64 in_batch = GetTensorDim(input, data_format, 'N');
+ int64 in_rows = GetTensorDim(input, data_format, 'H');
+ int64 in_cols = GetTensorDim(input, data_format, 'W');
+ const int64 in_depths = GetTensorDim(input, data_format, 'C');
+ const int64 out_batch = GetTensorDim(*output, data_format, 'N');
+ const int64 out_rows = GetTensorDim(*output, data_format, 'H');
+ const int64 out_cols = GetTensorDim(*output, data_format, 'W');
+ const int64 out_depths = GetTensorDim(*output, data_format, 'C');
+ const int64 patch_rows = filter.dim_size(0);
+ const int64 patch_cols = filter.dim_size(1);
if (padding == Eigen::PADDING_SAME) {
// Total padding on rows and cols is
// Pr = (R' - 1) * S + Kr - R
@@ -300,146 +281,114 @@ void LaunchFusedConv2DBiasActivationOp<GPUDevice, T, BiasType, ScaleType>::
// We pad Pr/2 on the left and Pr - Pr/2 on the right, Pc/2 on the top
// and Pc - Pc/2 on the bottom. When Pr or Pc is odd, this means
// we pad more on the right and bottom than on the top and left.
- padding_rows = std::max<int>(
- 0, (output_rows - 1) * row_stride + filter_rows - conv_input_rows);
- padding_cols = std::max<int>(
- 0, (output_cols - 1) * col_stride + filter_cols - conv_input_cols);
- const int padding_rows_parity = padding_rows & 1;
- const int padding_cols_parity = padding_cols & 1;
- if ((padding_rows_parity | padding_cols_parity) != 0) {
+ padding_rows =
+ std::max<int32>(0, (out_rows - 1) * row_stride + patch_rows - in_rows);
+ padding_cols =
+ std::max<int32>(0, (out_cols - 1) * col_stride + patch_cols - in_cols);
+ const int rows_parity = padding_rows & 1;
+ const int cols_parity = padding_cols & 1;
+ if ((rows_parity | cols_parity) != 0) {
Tensor transformed_input;
- const int new_conv_input_rows = conv_input_rows + padding_rows_parity;
- const int new_conv_input_cols = conv_input_cols + padding_cols_parity;
-
+ int64 new_in_rows = in_rows + rows_parity;
+ int64 new_in_cols = in_cols + cols_parity;
OP_REQUIRES_OK(
- ctx, ctx->allocate_temp(
- DataTypeToEnum<T>::value,
- ShapeFromFormat(data_format, batch_size, new_conv_input_rows,
- new_conv_input_cols, conv_input_depth),
- &maybe_padded_conv_input));
-
- functor::PadInput<GPUDevice, T, int, rank>()(
- ctx->eigen_device<GPUDevice>(),
- To32Bit(conv_input_param.tensor<T, rank>()), {{0, 0}},
- {{padding_rows_parity, padding_cols_parity}},
- To32Bit(maybe_padded_conv_input.tensor<T, rank>()), data_format);
-
- conv_input = &maybe_padded_conv_input;
- conv_input_rows = new_conv_input_rows;
- conv_input_cols = new_conv_input_cols;
+ ctx,
+ ctx->allocate_temp(DataTypeToEnum<T>::value,
+ ShapeFromFormat(data_format, in_batch, new_in_rows,
+ new_in_cols, in_depths),
+ &transformed_input));
+
+ functor::PadInput<GPUDevice, T, int, 4>()(
+ ctx->eigen_device<GPUDevice>(), To32Bit(input_param.tensor<T, 4>()),
+ {{0, 0}}, {{rows_parity, cols_parity}},
+ To32Bit(transformed_input.tensor<T, 4>()), data_format);
+
+ input = transformed_input;
+ in_rows = new_in_rows;
+ in_cols = new_in_cols;
}
}
- Tensor maybe_transformed_conv_input, maybe_transformed_side_input;
- Tensor maybe_transformed_output;
- const Tensor* side_input = &side_input_param;
- Tensor* output = output_param;
-
- // Assuming qint8 <--> NCHW_VECT_C, OIHW_VECT_I here.
- if (!std::is_same<T, qint8>::value && (data_format == FORMAT_NHWC)) {
- OP_REQUIRES_OK(ctx, (TransformNHWCToNCHW<T, rank>(
- ctx, *conv_input, batch_size, conv_input_rows,
- conv_input_cols, conv_input_depth,
- &maybe_transformed_conv_input, &conv_input)));
- if (side_input_scale != 0) {
- OP_REQUIRES_OK(
- ctx, (TransformNHWCToNCHW<T, rank>(
- ctx, side_input_param, batch_size, output_rows, output_cols,
- output_depth, &maybe_transformed_side_input, &side_input)));
- }
- if (output_depth > 1) {
- // Allocate a tensor for the NCHW output of the kernel and point output
- // to it. Afterwards, we will transform it to NHWC while copying back to
- // 'output_param'.
- TensorShape nchw_shape = ShapeFromFormat(
- FORMAT_NCHW, batch_size, output_rows, output_cols, output_depth);
- OP_REQUIRES_OK(ctx,
- ctx->allocate_temp(DataTypeToEnum<T>::value, nchw_shape,
- &maybe_transformed_output));
- output = &maybe_transformed_output;
+ if (data_format == FORMAT_NHWC) {
+ // Convert the input tensor from NHWC to NCHW.
+ TensorShape nchw_shape =
+ ShapeFromFormat(FORMAT_NCHW, in_batch, in_rows, in_cols, in_depths);
+ if (in_depths > 1) {
+ Tensor transformed_input;
+ OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
+ nchw_shape, &transformed_input));
+ functor::NHWCToNCHW<GPUDevice, T, 4>()(
+ ctx->eigen_device<GPUDevice>(),
+ const_cast<const Tensor&>(input).tensor<T, 4>(),
+ transformed_input.tensor<T, 4>());
+ input = transformed_input;
+ } else {
+ // If depth <= 1, then just reshape.
+ CHECK(input.CopyFrom(input, nchw_shape));
}
}
- // Assuming qint8 <--> NCHW_VECT_C, OIHW_VECT_I here.
- constexpr auto data_layout = std::is_same<T, qint8>::value
- ? dnn::DataLayout::kBatchDepthYX4
- : dnn::DataLayout::kBatchDepthYX;
- constexpr auto filter_layout = std::is_same<T, qint8>::value
- ? dnn::FilterLayout::kOutputInputYX4
- : dnn::FilterLayout::kOutputInputYX;
-
- dnn::BatchDescriptor conv_input_desc;
- conv_input_desc.set_count(batch_size)
- .set_feature_map_count(conv_input_depth)
- .set_height(conv_input_rows)
- .set_width(conv_input_cols)
- .set_layout(data_layout);
- dnn::FilterDescriptor filter_desc;
- filter_desc.set_input_filter_height(filter_rows)
- .set_input_filter_width(filter_cols)
- .set_input_feature_map_count(conv_input_depth)
- .set_output_feature_map_count(output_depth)
- .set_layout(filter_layout);
- dnn::BatchDescriptor side_input_desc;
- side_input_desc.set_count(batch_size)
- .set_height(output_rows)
- .set_width(output_cols)
- .set_feature_map_count(output_depth)
- .set_layout(data_layout);
- dnn::BatchDescriptor bias_desc;
- bias_desc.set_count(1)
- .set_height(1)
- .set_width(1)
- .set_feature_map_count(output_depth)
- .set_layout(dnn::DataLayout::kBatchDepthYX);
- dnn::BatchDescriptor output_desc;
- output_desc.set_count(batch_size)
- .set_height(output_rows)
- .set_width(output_cols)
- .set_feature_map_count(output_depth)
- .set_layout(data_layout);
- dnn::ConvolutionDescriptor conv_desc;
+ CHECK(padding_rows >= 0 && padding_cols >= 0)
+ << "Negative row or col paddings: (" << padding_rows << ", "
+ << padding_cols << ")";
+ perftools::gputools::dnn::BatchDescriptor input_desc;
+ input_desc.set_count(in_batch)
+ .set_feature_map_count(in_depths)
+ .set_height(in_rows)
+ .set_width(in_cols)
+ .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
+ perftools::gputools::dnn::BatchDescriptor output_desc;
+ output_desc.set_count(out_batch)
+ .set_height(out_rows)
+ .set_width(out_cols)
+ .set_feature_map_count(out_depths)
+ .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
+ perftools::gputools::dnn::FilterDescriptor filter_desc;
+ filter_desc.set_input_filter_height(filter.dim_size(0))
+ .set_input_filter_width(filter.dim_size(1))
+ .set_input_feature_map_count(filter.dim_size(2))
+ .set_output_feature_map_count(filter.dim_size(3));
+ perftools::gputools::dnn::ConvolutionDescriptor conv_desc;
conv_desc.set_vertical_filter_stride(row_stride)
.set_horizontal_filter_stride(col_stride)
.set_zero_padding_height(padding_rows / 2)
.set_zero_padding_width(padding_cols / 2);
- Tensor maybe_transformed_filter;
- const Tensor* filter;
- if (std::is_same<T, qint8>::value) {
- // We have already checked filter is OIHW_VECT_I in the constructor.
- filter = &filter_param;
- } else if (filter_format == FORMAT_HWIO) {
- // Shuffle filter tensor from HWIO to OIHW:
- OP_REQUIRES_OK(ctx, ctx->allocate_temp(
- DataTypeToEnum<T>::value,
- ShapeFromFilterFormat(
- FORMAT_OIHW, filter_param.shape(), FORMAT_HWIO),
- &maybe_transformed_filter));
- functor::TransformFilter<GPUDevice, T, int, 4>()(
- ctx->eigen_device<GPUDevice>(), To32Bit(filter_param.tensor<T, 4>()),
- To32Bit(maybe_transformed_filter.tensor<T, 4>()));
- filter = &maybe_transformed_filter;
- }
-
- auto conv_input_ptr =
- AsDeviceMemory(reinterpret_cast<const typename RawType<T>::type*>(
- conv_input->template flat<T>().data()),
- conv_input->template flat<T>().size());
+ // Shuffles a filter tensor from:
+ // [<spatial_dims>, in, out]
+ // to:
+ // [out, in, <spatial_dims>]
+ // TODO(yangzihao): Support a data layout tag for the filter weights, and only
+ // do the transform if the weights are not already in the correct layout.
+ Tensor transformed_filter;
+ OP_REQUIRES_OK(ctx, ctx->allocate_temp(
+ DataTypeToEnum<T>::value,
+ TensorShape({filter.dim_size(3), filter.dim_size(2),
+ filter.dim_size(0), filter.dim_size(1)}),
+ &transformed_filter));
+
+ functor::TransformFilter<GPUDevice, T, int, 4>()(
+ ctx->eigen_device<GPUDevice>(), To32Bit(filter.tensor<T, 4>()),
+ To32Bit(transformed_filter.tensor<T, 4>()));
+
+ Tensor transformed_output;
+ OP_REQUIRES_OK(
+ ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
+ ShapeFromFormat(FORMAT_NCHW, out_batch, out_rows,
+ out_cols, out_depths),
+ &transformed_output));
+
+ auto input_ptr = AsDeviceMemory(input.template flat<T>().data(),
+ input.template flat<T>().size());
auto filter_ptr =
- AsDeviceMemory(reinterpret_cast<const typename RawType<T>::type*>(
- filter->template flat<T>().data()),
- filter->template flat<T>().size());
- auto side_input_ptr =
- AsDeviceMemory(reinterpret_cast<const typename RawType<T>::type*>(
- side_input->template flat<T>().data()),
- side_input->template flat<T>().size());
+ AsDeviceMemory(transformed_filter.template flat<T>().data(),
+ transformed_filter.template flat<T>().size());
auto output_ptr =
- AsDeviceMemory(reinterpret_cast<const typename RawType<T>::type*>(
- output->template flat<T>().data()),
- output->template flat<T>().size());
- auto bias_ptr = AsDeviceMemory(bias.template flat<BiasType>().data(),
- bias.template flat<BiasType>().size());
+ AsDeviceMemory(transformed_output.template flat<T>().data(),
+ transformed_output.template flat<T>().size());
+
+ auto bias_ptr = AsDeviceMemory(bias.template flat<T>().data(),
+ bias.template flat<T>().size());
static int64 ConvolveScratchSize = GetCudnnWorkspaceLimit(
// default value is in bytes despite the name of the environment variable
@@ -447,42 +396,38 @@ void LaunchFusedConv2DBiasActivationOp<GPUDevice, T, BiasType, ScaleType>::
);
int device_id = stream->parent()->device_ordinal();
- FusedConvParameters fused_conv_parameters = {
- batch_size,
- conv_input_depth,
- {{conv_input_rows, conv_input_cols}},
- output_depth,
- {{filter_rows, filter_cols}},
+ DataType dtype = input.dtype();
+ ConvParameters conv_parameters = {
+ in_batch,
+ in_depths,
+ {{in_rows, in_cols}},
+ out_depths,
+ {{patch_rows, patch_cols}},
{{row_stride, col_stride}},
{{padding_rows, padding_cols}},
- conv_input->dtype(),
+ dtype,
device_id,
- (side_input_scale != 0),
- activation_mode,
};
- dnn::AlgorithmConfig algorithm_config;
+ AlgorithmConfig algorithm_config;
if (cudnn_use_autotune && !AutoTuneConvBiasActivation::GetInstance()->Find(
- fused_conv_parameters, &algorithm_config)) {
- std::vector<dnn::AlgorithmType> algorithms;
+ conv_parameters, &algorithm_config)) {
+ std::vector<AlgorithmType> algorithms;
CHECK(stream->parent()->GetConvolveAlgorithms(
- fused_conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(),
- &algorithms));
- dnn::ProfileResult best_result;
- dnn::ProfileResult best_result_no_scratch;
+ conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(), &algorithms));
+ ProfileResult best_result;
+ ProfileResult best_result_no_scratch;
for (auto profile_algorithm : algorithms) {
// TODO(zhengxq): profile each algorithm multiple times to better
// accuracy.
CudnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
- dnn::ProfileResult profile_result;
+ ProfileResult profile_result;
bool cudnn_launch_status =
stream
- ->ThenFusedConvolveWithAlgorithm(
- conv_input_desc, conv_input_ptr, conv_input_scale,
- filter_desc, filter_ptr, conv_desc, side_input_ptr,
- side_input_scale, bias_desc, bias_ptr,
- dnn::ActivationMode::kRelu, output_desc, &output_ptr,
- &scratch_allocator, dnn::AlgorithmConfig(profile_algorithm),
+ ->ThenConvolveWithAlgorithm(
+ input_desc, input_ptr, filter_desc, filter_ptr, conv_desc,
+ bias_ptr, cudnn_activation_mode, output_desc, &output_ptr,
+ &scratch_allocator, AlgorithmConfig(profile_algorithm),
&profile_result)
.ok();
if (cudnn_launch_status) {
@@ -509,53 +454,42 @@ void LaunchFusedConv2DBiasActivationOp<GPUDevice, T, BiasType, ScaleType>::
algorithm_config.set_algorithm_no_scratch(
best_result_no_scratch.algorithm());
}
- AutoTuneConvBiasActivation::GetInstance()->Insert(fused_conv_parameters,
+ AutoTuneConvBiasActivation::GetInstance()->Insert(conv_parameters,
algorithm_config);
}
CudnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
bool cudnn_launch_status =
stream
- ->ThenFusedConvolveWithAlgorithm(
- conv_input_desc, conv_input_ptr, conv_input_scale, filter_desc,
- filter_ptr, conv_desc, side_input_ptr, side_input_scale,
- bias_desc, bias_ptr, dnn::ActivationMode::kRelu, output_desc,
- &output_ptr, &scratch_allocator, algorithm_config,
+ ->ThenConvolveWithAlgorithm(
+ input_desc, input_ptr, filter_desc, filter_ptr, conv_desc,
+ bias_ptr, cudnn_activation_mode, output_desc, &output_ptr,
+ &scratch_allocator, algorithm_config,
/*output_profile_result=*/nullptr)
.ok();
if (!cudnn_launch_status) {
- ctx->SetStatus(errors::Internal("cuDNN launch failure : conv_input shape(",
- conv_input->shape().DebugString(),
- ") filter shape(",
- filter->shape().DebugString(), ")"));
+ ctx->SetStatus(errors::Internal(
+ "cuDNN launch failure : input shape(", input.shape().DebugString(),
+ ") filter shape(", filter.shape().DebugString(), ")"));
}
- // Convert the output tensor back from NCHW to NHWC if necessary.
- if (!std::is_same<T, qint8>::value && (data_format == FORMAT_NHWC) &&
- (output_depth > 1)) {
+ // Convert the output tensor back from NCHW to NHWC.
+ if (data_format == FORMAT_NHWC) {
functor::NCHWToNHWC<GPUDevice, T, 4>()(
ctx->eigen_device<GPUDevice>(),
- const_cast<const Tensor*>(output)->tensor<T, 4>(),
- output_param->tensor<T, 4>());
+ const_cast<const Tensor&>(transformed_output).tensor<T, 4>(),
+ output->tensor<T, 4>());
+ } else {
+ *output = transformed_output;
}
}
// Registration of the GPU implementations.
-
-REGISTER_KERNEL_BUILDER(
- Name("FusedConv2DBiasActivation")
- .Device(DEVICE_GPU)
- .TypeConstraint<float>("T")
- .TypeConstraint<float>("Tbias"),
- FusedConv2DBiasActivationOp<GPUDevice, float, float, float>);
-
-REGISTER_KERNEL_BUILDER(
- Name("FusedConv2DBiasActivation")
- .Device(DEVICE_GPU)
- .TypeConstraint<qint8>("T")
- .TypeConstraint<float>("Tbias"),
- FusedConv2DBiasActivationOp<GPUDevice, qint8, float, float>);
+REGISTER_KERNEL_BUILDER(Name("FusedConv2DBiasActivation")
+ .Device(DEVICE_GPU)
+ .TypeConstraint<float>("T"),
+ FusedConv2DBiasActivationOp<GPUDevice, float>);
#endif // GOOGLE_CUDA
diff --git a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.h b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.h
index 7534f5797c..d71b26cf1d 100644
--- a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.h
+++ b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.h
@@ -24,7 +24,7 @@ limitations under the License.
#if GOOGLE_CUDA
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "tensorflow/contrib/fused_conv/kernels/fused_conv_ops_gpu.h"
+#include "tensorflow/core/kernels/conv_ops_gpu.h"
#include "tensorflow/core/platform/stream_executor.h"
#endif // GOOGLE_CUDA
@@ -33,30 +33,27 @@ namespace tensorflow {
// Forward declaration.
class OpKernelContext;
-template <typename Device, typename T, typename BiasType, typename ScaleType>
+template <typename Device, typename T>
class LaunchFusedConv2DBiasActivationOp {
public:
void launch(OpKernelContext* ctx, bool cudnn_use_autotune,
- const Tensor& conv_input, ScaleType conv_input_scale,
- const Tensor& filter, int32 row_stride, int32 col_stride,
- const Eigen::PaddingType& padding, const Tensor& side_input,
- ScaleType side_input_scale, const Tensor& bias,
- ActivationMode activation_mode, TensorFormat data_format,
- FilterTensorFormat filter_format, Tensor* output);
+ const Tensor& input, const Tensor& filter, int row_stride,
+ int col_stride, const Tensor& bias,
+ const ActivationMode& activation_mode,
+ const Eigen::PaddingType& padding, TensorFormat data_format,
+ Tensor* output);
};
#ifdef GOOGLE_CUDA
-template <typename T, typename BiasType, typename ScaleType>
-class LaunchFusedConv2DBiasActivationOp<Eigen::GpuDevice, T, BiasType,
- ScaleType> {
+template <typename T>
+class LaunchFusedConv2DBiasActivationOp<Eigen::GpuDevice, T> {
public:
void launch(OpKernelContext* ctx, bool cudnn_use_autotune,
- const Tensor& conv_input, ScaleType conv_input_scale,
- const Tensor& filter, int32 row_stride, int32 col_stride,
- const Eigen::PaddingType& padding, const Tensor& side_input,
- ScaleType side_input_scale, const Tensor& bias,
- ActivationMode activation_mode, TensorFormat data_format,
- FilterTensorFormat filter_format, Tensor* output);
+ const Tensor& input, const Tensor& filter, int32 row_stride,
+ int32 col_stride, const Tensor& bias,
+ const ActivationMode& activation_mode,
+ const Eigen::PaddingType& padding, TensorFormat data_format,
+ Tensor* output);
};
#endif // GOOGLE_CUDA
diff --git a/tensorflow/contrib/fused_conv/kernels/fused_conv_ops_gpu.h b/tensorflow/contrib/fused_conv/kernels/fused_conv_ops_gpu.h
deleted file mode 100644
index dc43af1158..0000000000
--- a/tensorflow/contrib/fused_conv/kernels/fused_conv_ops_gpu.h
+++ /dev/null
@@ -1,74 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_FUSED_CONV_KERNELS_FUSED_CONV_OPS_GPU_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_FUSED_CONV_KERNELS_FUSED_CONV_OPS_GPU_H_
-
-#if GOOGLE_CUDA
-
-#include "tensorflow/core/kernels/conv_ops_gpu.h"
-#include "tensorflow/core/util/activation_mode.h"
-
-// TODO(pauldonnelly): Merge this file into core/kernels/conv_ops_gpu.h.
-
-namespace tensorflow {
-
-// Add additional parameters specific to fused convolutions.
-class FusedConvParameters : public ConvParameters {
- public:
- FusedConvParameters(int64 batch, int64 in_depths, const SpatialArray& in,
- int64 out_depths, const SpatialArray& filter,
- const SpatialArray& stride, const SpatialArray& padding,
- DataType dtype, int device_id, bool has_side_input,
- ActivationMode activation_mode)
- : ConvParameters(batch, in_depths, in, out_depths, filter, stride,
- padding, dtype, device_id),
- activation_mode_(activation_mode),
- has_side_input_(has_side_input) {
- hash_code_ = Hash64Combine(hash_code_, has_side_input);
- hash_code_ = Hash64Combine(hash_code_, activation_mode);
- }
-
- bool operator==(const FusedConvParameters& other) const {
- return this->get_data_as_tuple() == other.get_data_as_tuple();
- }
-
- bool operator!=(const FusedConvParameters& other) const {
- return !(*this == other);
- }
-
- string ToString() const {
- return strings::StrCat(ConvParameters::ToString(), ", ", has_side_input_,
- ", ", activation_mode_, ", ");
- }
-
- private:
- using ParameterDataType =
- std::tuple<ConvParameters::ParameterDataType, bool, ActivationMode>;
-
- ParameterDataType get_data_as_tuple() const {
- return std::make_tuple(ConvParameters::get_data_as_tuple(), has_side_input_,
- activation_mode_);
- }
-
- ActivationMode activation_mode_;
- bool has_side_input_;
-};
-
-} // namespace tensorflow
-
-#endif // GOOGLE_CUDA
-
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_FUSED_CONV_KERNELS_FUSED_CONV_OPS_GPU_H_
diff --git a/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc b/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc
index 48f058b4c5..6134c5c699 100644
--- a/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc
+++ b/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc
@@ -33,73 +33,40 @@ string GetAllActivationModeAttrString() { return "activation_mode: {'Relu'}"; }
} // namespace
// --------------------------------------------------------------------------
-
-// TODO(pauldonnelly): Add support for double inputs and scales to this Op,
-// (currently Attr does not support double).
-
REGISTER_OP("FusedConv2DBiasActivation")
- .Input("conv_input: T")
+ .Input("input: T")
.Input("filter: T")
- .Input("bias: Tbias")
- .Input("side_input: T")
+ .Input("bias: T")
.Output("output: T")
- .Attr("T: {float, half, qint8}")
- .Attr("Tbias: {float, half}")
- .Attr("conv_input_scale: float = 1.0")
- .Attr("side_input_scale: float = 0.0")
+ .Attr("T: {float}")
.Attr("strides: list(int)")
.Attr(GetPaddingAttrString())
- .Attr("data_format: {'NHWC', 'NCHW', 'NCHW_VECT_C'} = 'NHWC'")
- .Attr("filter_format: {'HWIO', 'OIHW', 'OIHW_VECT_I'} = 'HWIO'")
- .Attr("activation_mode: {'Relu'} = 'Relu'")
+ .Attr(GetConvnetDataFormatAttrString())
+ .Attr(GetAllActivationModeAttrString())
.SetShapeFn(shape_inference::FusedConvBiasActivationShape)
.Doc(R"doc(
- Computes a fused kernel which implements: 2-D convolution, adds side input,
- with separate scaling on convolution and side inputs, then adds bias and
- applies the RELU activation function to the result. Supports both float and
- qint8 data formats. In the case of qint8, the output is clipped to [0..127].
+ Computes a fused 2-D convolution, adds bias, and applies an activation function
+ on the output given 4-D `input`, 4-D `filter`, 1-D `bias` tensors and an activation mode.
- conv_input: A tensor with format as specified by `data_format` (see below).
- filter: A tensor with format depending on `data_format` as follows:
- "NHWC", "NCHW":
- `float [ filter_height, filter_width, in_channels, out_channels ]`
- "NCHW_VECT_C":
- `qint8 [ out_channels, in_channels, filter_height, filter_width ]`
- bias: 1-D float tensor with size matching the `out_channels` dimension of
- `filter`.
- Note: this tensor is still float, even if other inputs are qint8.
- side_input: A tensor with format as specified by `data_format` (see below).
- This tensor will be ignored and can be [] if side_input_scale == 0.
- Otherwise, the size of each dimension must match the `output` tensor.
- output: A tensor with format as specified by `data_format` (see below).
- The dimension sizes are determined automatically based on other inputs
- and attributes.
- T: The element data type of `conv_input`, `side_input` and `output` tensors.
- Note: must match with the `data_format`.
- Tbias: The element data type of `bias`.
- conv_input_scale: scalar float value to be multiplied by `conv_input`.
- (conceptually.. in reality it is applied after convolution).
- side_input_scale: scalar float value to be multiplied by `side_input`.
+ input: A 4-D tensor. The dimension order is interpreted according to the value
+ of `data_format`, see below for details.
+ filter: A 4-D tensor of shape
+ `[filter_height, filter_width, in_channels, out_channels]`
+ bias: 1-D with size of the `out_channels` dimension in filter.
+ output: A 4-D tensor. The dimension order is determined by the value of
+ `data_format`, see below for details.
+ T: The data type for the elements of input, filter, bias, and output Tensors.
strides: 1-D tensor of length 4. The stride of the sliding window for each
dimension of `input`. The dimension order is determined by the value of
`data_format`, see below for details.
- Note: the stride for batch and channel dimensions must be 1.
padding: The type of padding algorithm to use.
- data_format: A string specifying the data format of `conv_input`,
- `side_input` and `output` tensors with the following options:
- "NHWC": `float [ batch, height, width, channels ]`
- "NCHW": `float [ batch, channels, height, width ]`
- "NCHW_VECT_C":
- `qint8 [ batch, channels / 4, height, width, channels % 4 ]`
- Note: for "NCHW_VECT_C", `channels` must be a multiple of 4.
- filter_format: A string specifying the data format of `filter`,
- "HWIO": `float [ kernel_height, kernel_width, input_channels,
- output_channels ]`
- "OIHW_VECT_I":
- `qint8 [ output_channels, input_channels / 4,
- kernel_height, kernel_width, input_channels % 4 ]`
- activation_mode: The activation applied to the output.
- Currently must be "Relu".
+ data_format: Specify the data format of the input and output data. With the
+ default format "NHWC", the data is stored in the order of:
+ [batch, height, width, channels].
+ Alternatively, the format could be "NCHW", the data storage order of:
+ [batch, channels, height, width].
+ activation_mode: Specify the activation function to apply to the output tensor
+ of bias add. Currently only supports "Relu".
)doc");
} // namespace tensorflow
diff --git a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op.py b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op.py
index 8f3f31bad0..41f986dd07 100644
--- a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op.py
+++ b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op.py
@@ -26,83 +26,62 @@ _fused_conv2d_bias_activation_op_so = loader.load_op_library(
resource_loader.get_path_to_datafile("_fused_conv2d_bias_activation_op.so"))
-# pylint: disable=redefined-builtin
-def fused_conv2d_bias_activation(conv_input,
- filter,
+def fused_conv2d_bias_activation(input_tensor,
+ filter_tensor,
bias,
- strides=None,
- padding=None,
- conv_input_scale=1.0,
- side_input_scale=0.0,
- side_input=None,
- activation_mode="Relu",
+ strides,
+ padding,
+ activation_mode,
data_format=None,
- filter_format=None,
name=None):
- """Fused 2D conv, bias and activation with optional side input.
+ """Computes a fused 2-D convolution, adds bias, and applies relu.
- Computes a fused 2-D convolution scaled by conv_input_scale,
- adds an optional side input scaled by side_input_scale, adds biases,
- and applies ReLU. As an equation:
- output = ReLU(conv_input_scale * Conv(conv_input, filter) +
- side_input_scale * side_input + bias)
- Note: In int8 mode, The ReLU will clip the output to the range [0..127].
+ input_tensor: A 4-D tensor. The dimension order is interpreted
+ according to the value of `data_format`, see below for details.
+ filter_tensor: A 4-D tensor of shape
+ `[filter_height, filter_width, in_channels, out_channels]`
+ bias: 1-D with size of the `out_channels` dimension in filter.
+ output: A 4-D tensor. The dimension order is determined by the value of
+ `data_format`, see below for details.
+ T: The data type for the elements of input, filter, bias, and output
+ Tensors.
+ strides: 1-D tensor of length 4. The stride of the sliding window for
+ each
+ dimension of `input`. The dimension order is determined by the value
+ of
+ `data_format`, see below for details.
+ padding: The type of padding algorithm to use.
+ data_format: Specify the data format of the input and output data. With
+ the
+ default format "NHWC", the data is stored in the order of:
+ [batch, height, width, channels].
+ Alternatively, the format could be "NCHW", the data storage order of:
+ [batch, channels, height, width].
+ activation_mode: Specify the activation function to apply to the output
+ tensor
+ of bias add. Currently only supports "Relu".
Args:
- conv_input: A `Tensor` of the format specified by `data_format`.
- filter: A `Tensor` whose format depends on `data_format`:
- if `data_format` is "NCHW_VECT_C", filter should be "OIHW_VECT_I"
- otherwise, it should be "HWIO" format.
- bias: A 1-D `Tensor` of type `float32`, and dimensions equal to the
- number of output channels.
- strides: A list of 4 `ints` specifying convolution strides.
- if `data_format` is "NCHW" or "NCHW_VECT_C", the order should be NCHW.
- if `data_format` is "NHWC", the order should be NHWC.
+ input_tensor: A `Tensor`. Must be one of the following types: `float32`.
+ filter_tensor: A `Tensor`. Must have the same type as `input`.
+ bias: A `Tensor`. Must have the same type as `input`.
+ strides: A list of `ints`.
padding: A `string` from: `"SAME", "VALID"`.
- conv_input_scale: A scalar `float32` that will be multiplied by conv_input.
- This is optional and defaults to 1. However it should be set to
- specify the quantization scale when `data_format` is "NCHW_VECT_C".
- side_input_scale: A scalar `float32` that will be multiplied by side_input.
- This is optional and defaults to 0.
- side_input: A `Tensor` of the format specified by `data_format`.
- This is useful for imlementing ResNet blocks.
- activation_mode: (optional) currently must be the default "Relu".
- Note that in qint8 mode, it also clips to 127, so acts like ReluX.
- data_format: Specifies the data format.
- Possible values are:
- "NHWC" float [batch, height, width, channels]
- "NCHW" float [batch, channels, height, width]
- "NCHW_VECT_C" qint8 [batch, channels / 4, height, width, channels % 4]
- Defaults to `"NHWC"`.
- Performance is worst for `"NHWC"` and best for `"NCHW_VECT_C"`.
- filter_format: Specifies the filter format.
- Possible values are:
- "HWIO" float [kernel_height, kernel_width, input_channels,
- output_channels ]
- "OIHW" float [output_channels, input_channels, kernel_height,
- kernel_width ]
- "OIHW_VECT_I" qint8 [ output_channels, input_channels / 4,
- kernel_height, kernel_width, input_channels % 4 ]
- Defaults to `"HWIO"`.
+ activation_mode: A `string` from: `"Sigmoid", "Relu", "Relu6", "ReluX",
+ "Tanh", "BandPass"`.
+ data_format: An optional `string` from: `"NHWC", "NCHW"`. Defaults to
+ `"NHWC"`.
name: A name for the operation (optional).
Returns:
- A `Tensor` of the format specified by `data_format`.
+ A `Tensor`. Has the same type as `input`.
"""
- if strides is None:
- strides = [1, 1, 1, 1]
- if side_input is None:
- side_input = []
return gen_fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
- conv_input,
- filter,
- bias,
- padding=padding,
+ input=input_tensor,
+ filter=filter_tensor,
+ bias=bias,
strides=strides,
- conv_input_scale=conv_input_scale,
- side_input_scale=side_input_scale,
- side_input=side_input,
+ padding=padding,
activation_mode=activation_mode,
data_format=data_format,
- filter_format=filter_format,
name=name)
diff --git a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py
index 3b8f7d6ed7..5d6a2fa3b8 100644
--- a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py
+++ b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py
@@ -19,16 +19,13 @@ from __future__ import division
from __future__ import print_function
import numpy as np
-
from tensorflow.contrib.fused_conv.python.ops import fused_conv2d_bias_activation_op
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import nn_ops
-from tensorflow.python.ops import random_ops
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging
@@ -487,8 +484,7 @@ class FusedConv2DBiasActivationTest(test.TestCase):
with self.test_session() as sess:
# Illegal strides.
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
- "Convolutional strides are not supported in "
- "the batch or depth dimensions."):
+ "strides in the batch and depth"):
sess.run(
fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
array_ops.placeholder(dtypes.float32),
@@ -498,8 +494,7 @@ class FusedConv2DBiasActivationTest(test.TestCase):
padding="SAME",
activation_mode="Relu"))
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
- "Convolutional strides are not supported in "
- "the batch or depth dimensions."):
+ "strides in the batch and depth"):
sess.run(
fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
array_ops.placeholder(dtypes.float32),
@@ -557,286 +552,6 @@ def GetInceptionFwdTest(input_size, filter_size, stride, padding,
return Test
-def CalculateCovolvedOutputDim(input_dim, filter_dim, stride, padding_type):
- """Calculates the size of an output dimension of a strided convolution.
-
- Given the sizes of the corresponding dimension of the input and filter shapes,
- and the stride and padding_types, calculates the size of the output dimension.
- This function can be called separately for each input dimension.
-
- Args:
- input_dim: An `int` specifying the size of the input dimension.
- filter_dim: An `int` specifying the size of the filter dimension.
- stride: An `int` specifying the step size of the convolution along the
- input dimension.
- padding_type: either 'VALID' or 'SAME'.
-
- Returns:
- The size of the output dimension.
- """
- if padding_type == "VALID":
- return (input_dim - filter_dim + stride) // stride
- else: # padding_type == 'SAME'
- return (input_dim + stride - 1) // stride
-
-
-def NchwVectCToNchw(in_tensor):
- # [N, C / 4, H, W, 4] => [N, C / 4, 4, H, W] == [N, C, H, W]
- t = array_ops.transpose(in_tensor, [0, 1, 4, 2, 3])
- n = in_tensor.shape.dims[0].value
- c = in_tensor.shape.dims[1].value * in_tensor.shape.dims[4].value
- h = in_tensor.shape.dims[2].value
- w = in_tensor.shape.dims[3].value
- return array_ops.reshape(t, [n, c, h, w])
-
-
-def OihwVectIToHwio(in_tensor):
- # [O, I / 4, H, W, 4] => [O, I / 4, 4, H, W] == [O, I, H, W]
- t = array_ops.transpose(in_tensor, [2, 3, 1, 4, 0])
- o = in_tensor.shape.dims[0].value
- i = in_tensor.shape.dims[1].value * in_tensor.shape.dims[4].value
- h = in_tensor.shape.dims[2].value
- w = in_tensor.shape.dims[3].value
- return array_ops.reshape(t, [h, w, i, o])
-
-
-def NchwToNchwVectC(in_tensor):
- n, c, h, w = in_tensor.shape.as_list()
- assert c % 4 == 0
- t = array_ops.reshape(in_tensor, [n, c // 4, 4, h, w])
- return array_ops.transpose(t, [0, 1, 3, 4, 2])
-
-
-def SimulateFusedConv2dBiasActivationInt8(conv_input_scale, conv_input, kernel,
- padding, strides, side_input_scale,
- side_input, biases):
- """Simulates the int8 fused 2-D convolution op using separate float ops.
-
- The arguments and return values have the same format, meanings and
- restrictions as the actual op.
- Args:
- conv_input_scale: A scalar 'float'.
- conv_input: A `Tensor` of type `qint8` in NCHW_VECT_C layout.
- kernel: A `Tensor` of type `qint8` in OIHW_VECT_I layout.
- padding: A `string` from: `"SAME", "VALID"`.
- strides: A list of `ints`.
- side_input_scale: A scalar 'float'.
- side_input: A `Tensor` of type `qint8` in NCHW_VECT_C layout.
- biases: A `Tensor` of type `float32` in NCHW layout.
- Returns:
- A `Tensor` of type `qint8` in NCHW_VECT_C layout.
- """
- conv_result = nn_ops.conv2d(
- NchwVectCToNchw(gen_array_ops.dequantize(conv_input, -128, 127)),
- OihwVectIToHwio(gen_array_ops.dequantize(kernel, -128, 127)),
- strides=strides,
- padding=padding,
- data_format="NCHW") * conv_input_scale
-
- conv_and_side_inputs = conv_result + side_input_scale * NchwVectCToNchw(
- gen_array_ops.dequantize(side_input, -128, 127))
-
- logit = nn_ops.bias_add(conv_and_side_inputs, biases, data_format="NCHW")
-
- result, _, _ = gen_array_ops.quantize_v2(
- NchwToNchwVectC(nn_ops.relu(logit)), -128, 127, dtypes.qint8)
- return result
-
-
-class FusedConvInt8Tests(test.TestCase):
- _test_params = [
- {
- "batch_size": 2,
- "input_channels": 8,
- "output_channels": 16,
- "input_height": 8,
- "input_width": 8,
- "filter_height": 3,
- "filter_width": 3,
- "vertical_stride": 2,
- "horizontal_stride": 2,
- "conv_input_scale": 0.002,
- "side_input_scale": 0.0,
- "bias_scale": 1,
- "padding_type": "VALID"
- },
- {
- "batch_size": 2,
- "input_channels": 8,
- "output_channels": 16,
- "input_height": 8,
- "input_width": 8,
- "filter_height": 3,
- "filter_width": 3,
- "vertical_stride": 2,
- "horizontal_stride": 2,
- "conv_input_scale": 0.002,
- "side_input_scale": 0.0,
- "bias_scale": 1,
- "padding_type": "SAME"
- },
- {
- "batch_size": 2,
- "input_channels": 8,
- "output_channels": 16,
- "input_height": 8,
- "input_width": 8,
- "filter_height": 3,
- "filter_width": 3,
- "vertical_stride": 2,
- "horizontal_stride": 2,
- "conv_input_scale": 0.002,
- "side_input_scale": 0.5,
- "bias_scale": 1,
- "padding_type": "VALID"
- },
- {
- "batch_size": 2,
- "input_channels": 16,
- "output_channels": 16,
- "input_height": 9,
- "input_width": 9,
- "filter_height": 3,
- "filter_width": 3,
- "vertical_stride": 1,
- "horizontal_stride": 1,
- "conv_input_scale": 0.001,
- "side_input_scale": 0.5,
- "bias_scale": 1,
- "padding_type": "SAME"
- },
- {
- "batch_size": 3,
- "input_channels": 8,
- "output_channels": 8,
- "input_height": 9,
- "input_width": 9,
- "filter_height": 5,
- "filter_width": 5,
- "vertical_stride": 1,
- "horizontal_stride": 1,
- "conv_input_scale": 0.001,
- "side_input_scale": 0.5,
- "bias_scale": 1,
- "padding_type": "SAME"
- },
- {
- "batch_size": 3,
- "input_channels": 8,
- "output_channels": 8,
- "input_height": 9,
- "input_width": 9,
- "filter_height": 7,
- "filter_width": 1,
- "vertical_stride": 2,
- "horizontal_stride": 1,
- "conv_input_scale": 0.002,
- "side_input_scale": 0.5,
- "bias_scale": 1,
- "padding_type": "SAME"
- },
- {
- "batch_size": 3,
- "input_channels": 8,
- "output_channels": 8,
- "input_height": 9,
- "input_width": 9,
- "filter_height": 1,
- "filter_width": 7,
- "vertical_stride": 1,
- "horizontal_stride": 1,
- "conv_input_scale": 0.002,
- "side_input_scale": 0.5,
- "bias_scale": 1,
- "padding_type": "SAME"
- },
- ]
-
- def runTest(self, test_param):
- batch_size = test_param["batch_size"]
- input_channels = test_param["input_channels"]
- output_channels = test_param["output_channels"]
- input_height = test_param["input_height"]
- input_width = test_param["input_width"]
- filter_height = test_param["filter_height"]
- filter_width = test_param["filter_width"]
- vertical_stride = test_param["vertical_stride"]
- horizontal_stride = test_param["horizontal_stride"]
- conv_input_scale = test_param["conv_input_scale"]
- side_input_scale = test_param["side_input_scale"]
- bias_scale = test_param["bias_scale"]
- padding_type = test_param["padding_type"]
-
- conv_input, _, _ = gen_array_ops.quantize_v2(
- random_ops.random_uniform(
- [batch_size, input_channels // 4, input_height, input_width, 4],
- minval=-0.0,
- maxval=1.0,
- dtype=dtypes.float32), -1.0, 1.0, dtypes.qint8)
-
- kernel, _, _ = gen_array_ops.quantize_v2(
- random_ops.random_uniform(
- [
- output_channels, input_channels // 4, filter_height,
- filter_width, 4
- ],
- minval=-1.0,
- maxval=1.0,
- dtype=dtypes.float32), -1.0, 1.0, dtypes.qint8)
-
- output_height = CalculateCovolvedOutputDim(input_height, filter_height,
- vertical_stride, padding_type)
- output_width = CalculateCovolvedOutputDim(input_width, filter_width,
- horizontal_stride, padding_type)
- print("output_height=", output_height, ", output_width=", output_width)
-
- side_input, _, _ = gen_array_ops.quantize_v2(
- random_ops.random_uniform(
- [batch_size, output_channels // 4, output_height, output_width, 4],
- minval=0.0,
- maxval=1.0,
- dtype=dtypes.float32), -1.0, 1.0, dtypes.qint8)
-
- biases = random_ops.random_uniform(
- [output_channels],
- minval=-10 * bias_scale,
- maxval=20 * bias_scale,
- dtype=dtypes.float32)
-
- strides = [1, 1, vertical_stride, horizontal_stride]
-
- actual = fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
- conv_input,
- kernel,
- biases,
- strides=strides,
- padding=padding_type,
- conv_input_scale=conv_input_scale,
- side_input_scale=side_input_scale,
- side_input=side_input,
- data_format="NCHW_VECT_C",
- filter_format="OIHW_VECT_I")
-
- expected = SimulateFusedConv2dBiasActivationInt8(
- conv_input_scale, conv_input, kernel, padding_type, strides,
- side_input_scale, side_input, biases)
-
- with self.test_session(use_gpu=True) as sess:
- actual_y, expected_y = sess.run([actual, expected])
- print("actual_y = ", actual_y)
- print("expected_y = ", expected_y)
- self.assertTrue(np.array_equal(actual_y, expected_y))
-
- def testFusedConvInt8(self):
- if not test.is_gpu_available(
- cuda_only=True, min_cuda_compute_capability=(6, 1)):
- tf_logging.info("int8 test skipped because not run with --config=cuda or "
- "no GPUs with compute capability >= 6.1 are available.")
- return
- for test_param in self._test_params:
- self.runTest(test_param)
-
-
if __name__ == "__main__":
for index, (input_size_, filter_size_, output_size_, stride_,
padding_) in enumerate(GetShrunkInceptionShapes()):