aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/fused_conv
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-09-06 13:20:30 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-09-06 13:24:41 -0700
commit2b15badd96c651d4d191426975a1773dff4a03b8 (patch)
tree30406a237f324cb3993c8fbb2c49dd0c1f9ed624 /tensorflow/contrib/fused_conv
parentca65468a02d4b2ceb78cf5c130ad275a4eefe6bb (diff)
Add int8 version of fused_conv2d_bias_activation operator for the forward phase,
and support side_input and scaling parameters in float and int8 versions. PiperOrigin-RevId: 167763219
Diffstat (limited to 'tensorflow/contrib/fused_conv')
-rw-r--r--tensorflow/contrib/fused_conv/BUILD13
-rw-r--r--tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc698
-rw-r--r--tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.h31
-rw-r--r--tensorflow/contrib/fused_conv/kernels/fused_conv_ops_gpu.h74
-rw-r--r--tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc77
-rw-r--r--tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op.py107
-rw-r--r--tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py289
7 files changed, 901 insertions, 388 deletions
diff --git a/tensorflow/contrib/fused_conv/BUILD b/tensorflow/contrib/fused_conv/BUILD
index f5d21278db..9b34cf1bdb 100644
--- a/tensorflow/contrib/fused_conv/BUILD
+++ b/tensorflow/contrib/fused_conv/BUILD
@@ -60,12 +60,14 @@ tf_kernel_library(
srcs = [
"kernels/fused_conv2d_bias_activation_op.cc",
"kernels/fused_conv2d_bias_activation_op.h",
+ "kernels/fused_conv_ops_gpu.h",
],
prefix = "fused_conv2d_bias_activation_op",
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_proto_parsing",
+ "//tensorflow/core:stream_executor",
"//tensorflow/core/kernels:bounds_check_lib",
"//tensorflow/core/kernels:conv_2d_hdrs",
"//tensorflow/core/kernels:conv_ops_gpu_hdrs",
@@ -81,6 +83,7 @@ tf_custom_op_library(
srcs = [
"kernels/fused_conv2d_bias_activation_op.cc",
"kernels/fused_conv2d_bias_activation_op.h",
+ "kernels/fused_conv_ops_gpu.h",
"ops/fused_conv2d_bias_activation_op.cc",
],
deps = [
@@ -94,12 +97,8 @@ tf_custom_op_library(
)
tf_gen_op_libs(
- op_lib_names = [
- "fused_conv2d_bias_activation_op",
- ],
- deps = [
- "//tensorflow/core:lib_proto_parsing",
- ],
+ op_lib_names = ["fused_conv2d_bias_activation_op"],
+ deps = ["//tensorflow/core:lib_proto_parsing"],
)
tf_gen_op_wrapper_py(
@@ -109,7 +108,7 @@ tf_gen_op_wrapper_py(
cuda_py_test(
name = "fused_conv2d_bias_activation_op_test",
- size = "small",
+ size = "large",
srcs = ["python/ops/fused_conv2d_bias_activation_op_test.py"],
additional_deps = [
":fused_conv_py",
diff --git a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
index dc0701b234..675ff2be38 100644
--- a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
+++ b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
@@ -13,8 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#define EIGEN_USE_THREADS
-
#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif // GOOGLE_CUDA
@@ -31,8 +29,8 @@ limitations under the License.
#include "tensorflow/core/kernels/conv_2d.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/util/padding.h"
-#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/core/util/use_cudnn.h"
#if GOOGLE_CUDA
@@ -40,38 +38,84 @@ limitations under the License.
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/util/activation_mode.h"
#endif // GOOGLE_CUDA
+
namespace tensorflow {
-typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
-template <typename Device, typename T>
-struct LaunchConvOp;
+template <typename T>
+struct RawType {
+ using type = T;
+};
+
+template <>
+struct RawType<qint8> {
+ using type = int8;
+};
+
+// Template struct to convert int8x4 to int32.
+// (for NCHW_VECT_C with element type int8, we can consider it to be
+// an NCHW layout with element type int32 for operations like padding).
+template <typename T>
+struct Int8x4ToInt32 {
+ // By default, do not change T.
+ using type = T;
+};
+
+template <>
+struct Int8x4ToInt32<int8> {
+ using type = int32;
+};
-template <typename Device, typename T>
+// T is the element type of the conv_input, filter and side_input tensors.
+// BiasType is the element type of the bias tensor, which can be different.
+// ScaleType is the type used for conv_input_scale, side_input_scale.
+template <typename Device, typename T, typename BiasType, typename ScaleType>
class FusedConv2DBiasActivationOp : public OpKernel {
public:
explicit FusedConv2DBiasActivationOp(OpKernelConstruction* context)
: OpKernel(context) {
- string data_format;
- OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
- OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
+ string data_format_str, filter_format_str;
+ OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format_str));
+ OP_REQUIRES(context, FormatFromString(data_format_str, &data_format_),
errors::InvalidArgument("Invalid data format"));
+ OP_REQUIRES_OK(context,
+ context->GetAttr("filter_format", &filter_format_str));
OP_REQUIRES(context,
- (data_format_ == FORMAT_NHWC || data_format_ == FORMAT_NCHW),
- errors::InvalidArgument("Current implementation only supports "
- "NHWC and NCHW data formats."));
- OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
- OP_REQUIRES(context, strides_.size() == 4,
+ FilterFormatFromString(filter_format_str, &filter_format_),
+ errors::InvalidArgument("Invalid filter format"));
+
+ std::vector<int32> strides;
+ OP_REQUIRES_OK(context, context->GetAttr("strides", &strides));
+ OP_REQUIRES(context, strides.size() == 4,
errors::InvalidArgument("Sliding window strides field must "
"specify 4 dimensions"));
+
+ stride_rows_ = GetTensorDim(strides, data_format_, 'H');
+ stride_cols_ = GetTensorDim(strides, data_format_, 'W');
OP_REQUIRES(
context,
- (GetTensorDim(strides_, data_format_, 'N') == 1 &&
- GetTensorDim(strides_, data_format_, 'C') == 1),
- errors::InvalidArgument("Current implementation does not yet support "
- "strides in the batch and depth dimensions."));
- OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+ (GetTensorDim(strides, data_format_, 'N') == 1 &&
+ GetTensorDim(strides, data_format_, 'C') == 1),
+ errors::InvalidArgument("Convolutional strides are not supported in "
+ "the batch or depth dimensions."));
+
+ // Assuming qint8 <--> NCHW_VECT_C, OIHW_VECT_I (int8x4) here.
+ constexpr bool is_int8x4 = std::is_same<T, qint8>::value;
+
+ // Note: Only NCHW_VECT_C format is supported for int8.
+ // This is because it is expected to be the fastest, and our previous tests
+ // found cudnn 6 does not fully support the other formats for int8 mode.
+ OP_REQUIRES(context, (is_int8x4 == (data_format_ == FORMAT_NCHW_VECT_C)),
+ errors::InvalidArgument(
+ "qint8 should be used with data_format NCHW_VECT_C."));
+
+ OP_REQUIRES(context, (is_int8x4 == (filter_format_ == FORMAT_OIHW_VECT_I)),
+ errors::InvalidArgument(
+ "qint8 should be used with filter_format OIHW_VECT_I."));
+
+ OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_type_));
+ eigen_padding_type_ = BrainPadding2EigenPadding(padding_type_);
string activation_mode_str;
OP_REQUIRES_OK(context,
context->GetAttr("activation_mode", &activation_mode_str));
@@ -79,130 +123,111 @@ class FusedConv2DBiasActivationOp : public OpKernel {
&activation_mode_));
OP_REQUIRES(context, activation_mode_ == ActivationMode::RELU,
errors::InvalidArgument("Current implementation only supports "
- "relu as the activation mode."));
+ "RELU as the activation function."));
cudnn_use_autotune_ = CudnnUseAutotune();
+ float conv_input_scale_flt, side_input_scale_flt;
+ OP_REQUIRES_OK(context,
+ context->GetAttr("conv_input_scale", &conv_input_scale_flt));
+ OP_REQUIRES_OK(context,
+ context->GetAttr("side_input_scale", &side_input_scale_flt));
+ conv_input_scale_ = conv_input_scale_flt;
+ side_input_scale_ = side_input_scale_flt;
+ }
+
+ Status CheckShape(const Tensor& tensor, const string& tensor_name) {
+ const int num_dims = tensor.dims();
+ for (int i = 0; i < num_dims; i++) {
+ if (!FastBoundsCheck(tensor.dim_size(i),
+ std::numeric_limits<int32>::max())) {
+ return errors::InvalidArgument(tensor_name, " dimension ", i,
+ " too large");
+ }
+ }
+ // If there is a 5th dimension it is the VECT_C or VECT_I dimension.
+ if (num_dims == 5 && tensor.dim_size(4) != 4) {
+ return errors::InvalidArgument("The last dimension of ", tensor_name,
+ " must be of size 4 for qint8.");
+ }
+ return Status::OK();
}
void Compute(OpKernelContext* context) override {
- // Input tensor is one of the following shapes:
- // [ batch, in_rows, in_cols, in_depth ] (for NHWC data format)
- // [ batch, in_depth, in_rows, in_cols ] (for NCHW data format)
- const Tensor& input = context->input(0);
+ // The conv_input tensor is one of the following formats:
+ // NHWC, NCHW, NCHW_VECT_C.
+ const Tensor& conv_input = context->input(0);
+ OP_REQUIRES_OK(context, CheckShape(conv_input, "conv_input"));
- // Input filter is of the following dimensions:
- // [ filter_rows, filter_cols, in_depth, out_depth ]
+ // The filter tensor is one of the following formats:
+ // HWIO, OIHW, OIHW_VECT_I.
const Tensor& filter = context->input(1);
+ OP_REQUIRES_OK(context, CheckShape(filter, "filter"));
- // Input bias is a 1-D tensor the size of the last
- // dimension of Output tensor
+ // Input bias is a 1-D tensor, with size matching output depth.
const Tensor& bias = context->input(2);
+ OP_REQUIRES_OK(context, CheckShape(bias, "conv_input"));
- // For 2D convolution, there should be 4 dimensions.
- OP_REQUIRES(context, input.dims() == 4,
- errors::InvalidArgument("input must be 4-dimensional",
- input.shape().DebugString()));
- OP_REQUIRES(context, filter.dims() == 4,
- errors::InvalidArgument("filter must be 4-dimensional: ",
- filter.shape().DebugString()));
-
- // Bias should be a 1-D tensor.
- OP_REQUIRES(context, bias.dims() == 1,
- errors::InvalidArgument("bias must be 1-dimensional: ",
- bias.shape().DebugString()));
-
- for (int i = 0; i < 4; i++) {
- OP_REQUIRES(context,
- FastBoundsCheck(filter.dim_size(i),
- std::numeric_limits<int32>::max()),
- errors::InvalidArgument("filter dimension too large"));
- OP_REQUIRES(
- context,
- FastBoundsCheck(input.dim_size(i), std::numeric_limits<int32>::max()),
- errors::InvalidArgument("input dimension too large"));
+ // If side_input_scale != 0, then side_input is not ignored and
+ // has the same type and dimensions as the output.
+ const Tensor& side_input = context->input(3);
+ if (side_input_scale_ != 0) {
+ OP_REQUIRES_OK(context, CheckShape(side_input, "side_input"));
}
- // The last dimension for input is in_depth. It must be the same as the
- // filter's in_depth.
- const int64 in_depth = GetTensorDim(input, data_format_, 'C');
- OP_REQUIRES(context, in_depth == filter.dim_size(2),
- errors::InvalidArgument(
- "input and filter must have the same depth: ", in_depth,
- " vs ", filter.dim_size(2)));
-
- // The last dimension for filter is out_depth.
- const int32 out_depth = static_cast<int32>(filter.dim_size(3));
-
- // The second dimension for input is rows/height.
- // The first dimension for filter is rows/height.
- const int64 input_rows_raw = GetTensorDim(input, data_format_, 'H');
- const int32 input_rows = static_cast<int32>(input_rows_raw);
- const int32 filter_rows = static_cast<int32>(filter.dim_size(0));
-
- // The third dimension for input is columns/width.
- // The second dimension for filter is columns/width.
- const int64 input_cols_raw = GetTensorDim(input, data_format_, 'W');
- const int32 input_cols = static_cast<int32>(input_cols_raw);
- const int32 filter_cols = static_cast<int32>(filter.dim_size(1));
-
- // The first dimension for input is batch.
- const int64 batch_raw = GetTensorDim(input, data_format_, 'N');
- const int32 batch = static_cast<int32>(batch_raw);
-
- // For now we take the stride from the second and third dimensions only (we
- // do not support striding on the batch or depth dimension).
- const int32 stride_rows =
- static_cast<int32>(GetTensorDim(strides_, data_format_, 'H'));
- const int32 stride_cols =
- static_cast<int32>(GetTensorDim(strides_, data_format_, 'W'));
- const int32 bias_size = static_cast<int32>(bias.dim_size(0));
-
- int64 out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0;
- OP_REQUIRES_OK(context,
- GetWindowedOutputSize(input_rows, filter_rows, stride_rows,
- padding_, &out_rows, &pad_rows));
- OP_REQUIRES_OK(context,
- GetWindowedOutputSize(input_cols, filter_cols, stride_cols,
- padding_, &out_cols, &pad_cols));
- // Output tensor is of the following dimensions:
- // [ in_batch, out_rows, out_cols, out_depth ]
- TensorShape out_shape =
- ShapeFromFormat(data_format_, batch, out_rows, out_cols, out_depth);
+ // TODO(pauldonnelly): Switch to a more efficient mechanism to access
+ // dimension indexes and per-dimension attributes.
+ const int32 filter_rows = GetFilterDim(filter, filter_format_, 'H');
+ const int32 filter_cols = GetFilterDim(filter, filter_format_, 'W');
+ const int32 output_depth = GetFilterDim(filter, filter_format_, 'O');
+
+ const int32 batch_size = GetTensorDim(conv_input, data_format_, 'N');
+ const int32 conv_input_rows = GetTensorDim(conv_input, data_format_, 'H');
+ const int32 conv_input_cols = GetTensorDim(conv_input, data_format_, 'W');
+
+ int64 output_rows = 0, output_cols = 0, pad_rows = 0, pad_cols = 0;
+ OP_REQUIRES_OK(context, GetWindowedOutputSize(conv_input_rows, filter_rows,
+ stride_rows_, padding_type_,
+ &output_rows, &pad_rows));
+ OP_REQUIRES_OK(context, GetWindowedOutputSize(conv_input_cols, filter_cols,
+ stride_cols_, padding_type_,
+ &output_cols, &pad_cols));
+ // Initialize the output tensor shape according to data_format_
+ TensorShape output_shape = ShapeFromFormat(
+ data_format_, batch_size, output_rows, output_cols, output_depth);
Tensor* output = nullptr;
- OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
-
- // Bias size should be the same as the size of the channel dimension of
- // output.
- OP_REQUIRES(context, bias_size == out_depth,
- errors::InvalidArgument(
- "bias size should equal the channel "
- "dimension size of output. bias shape: ",
- bias.shape().DebugString() +
- ", output shape: " + output->shape().DebugString()));
+ OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
- VLOG(2) << "FusedConv2DBiasActivation: in_depth = " << in_depth
- << ", input_cols = " << input_cols
+ VLOG(2) << "FusedConv2DBiasActivation: conv_input_cols = "
+ << conv_input_cols << ", conv_input_rows = " << conv_input_rows
<< ", filter_cols = " << filter_cols
- << ", input_rows = " << input_rows
<< ", filter_rows = " << filter_rows
- << ", stride_rows = " << stride_rows
- << ", stride_cols = " << stride_cols
- << ", bias_size = " << bias_size << ", out_depth = " << out_depth;
+ << ", stride_cols = " << stride_cols_
+ << ", stride_rows = " << stride_rows_
+ << ", output_depth = " << output_depth
+ << ", output_cols = " << output_cols
+ << ", output_rows = " << output_rows
+ << ", output_shape.num_elements = " << output_shape.num_elements();
// If there is nothing to compute, return.
- if (out_shape.num_elements() == 0) {
+ if (output_shape.num_elements() == 0) {
return;
}
- launcher_.launch(context, cudnn_use_autotune_, input, filter, stride_rows,
- stride_cols, bias, activation_mode_,
- BrainPadding2EigenPadding(padding_), data_format_, output);
+
+ launcher_.launch(context, cudnn_use_autotune_, conv_input,
+ conv_input_scale_, filter, stride_rows_, stride_cols_,
+ eigen_padding_type_, side_input, side_input_scale_, bias,
+ activation_mode_, data_format_, filter_format_, output);
}
private:
- std::vector<int32> strides_;
- Padding padding_;
+ int32 stride_rows_, stride_cols_;
+ Padding padding_type_;
+ Eigen::PaddingType eigen_padding_type_;
ActivationMode activation_mode_;
TensorFormat data_format_;
- LaunchFusedConv2DBiasActivationOp<Device, T> launcher_;
+ FilterTensorFormat filter_format_;
+ ScaleType conv_input_scale_;
+ ScaleType side_input_scale_;
+ LaunchFusedConv2DBiasActivationOp<Device, T, BiasType, ScaleType> launcher_;
bool cudnn_use_autotune_;
TF_DISALLOW_COPY_AND_ASSIGN(FusedConv2DBiasActivationOp);
@@ -211,67 +236,72 @@ class FusedConv2DBiasActivationOp : public OpKernel {
#if GOOGLE_CUDA
namespace dnn = ::perftools::gputools::dnn;
-dnn::ActivationMode BrainActivationMode2CudnnActivationMode(
- ActivationMode activation_mode) {
- switch (activation_mode) {
- case ActivationMode::SIGMOID:
- return dnn::ActivationMode::kSigmoid;
- case ActivationMode::RELU:
- return dnn::ActivationMode::kRelu;
- case ActivationMode::RELUX:
- return dnn::ActivationMode::kReluX;
- case ActivationMode::RELU6:
- return dnn::ActivationMode::kRelu6;
- case ActivationMode::TANH:
- return dnn::ActivationMode::kTanh;
- case ActivationMode::BANDPASS:
- return dnn::ActivationMode::kBandPass;
- }
- // Prevent compiler warning about missing return
- return dnn::ActivationMode::kRelu;
-}
-
// A dummy type to group forward convolution autotune results together.
struct ConvBiasActivationAutoTuneGroup {
static string name() { return "ConvBiasActivation"; }
};
-typedef AutoTuneSingleton<ConvBiasActivationAutoTuneGroup, ConvParameters,
- perftools::gputools::dnn::AlgorithmConfig>
+typedef AutoTuneSingleton<ConvBiasActivationAutoTuneGroup, FusedConvParameters,
+ dnn::AlgorithmConfig>
AutoTuneConvBiasActivation;
-template <typename T>
-void LaunchFusedConv2DBiasActivationOp<GPUDevice, T>::launch(
- OpKernelContext* ctx, bool cudnn_use_autotune, const Tensor& input_param,
- const Tensor& filter, int32 row_stride, int32 col_stride,
- const Tensor& bias, const ActivationMode& activation_mode,
- const Eigen::PaddingType& padding, TensorFormat data_format,
- Tensor* output) {
- using perftools::gputools::dnn::AlgorithmConfig;
- using perftools::gputools::dnn::AlgorithmType;
- using perftools::gputools::dnn::ProfileResult;
- using perftools::gputools::dnn::kDefaultAlgorithm;
+// Allocates 'transformed_tensor' and transforms 'nhwc_tensor' into it
+// using the specified 'batch_size', 'rows', 'cols', and 'depth' dimensions.
+template <typename T, size_t NDIMS>
+Status TransformNHWCToNCHW(OpKernelContext* ctx, const Tensor& nhwc_tensor,
+ int batch_size, int rows, int cols, int depth,
+ Tensor* transformed_tensor, const Tensor** result) {
+ TensorShape nchw_shape =
+ ShapeFromFormat(FORMAT_NCHW, batch_size, rows, cols, depth);
+ if (depth > 1) {
+ TF_RETURN_IF_ERROR(ctx->allocate_temp(DataTypeToEnum<T>::value, nchw_shape,
+ transformed_tensor));
+ functor::NHWCToNCHW<GPUDevice, T, NDIMS>()(
+ ctx->eigen_device<GPUDevice>(), nhwc_tensor.tensor<T, NDIMS>(),
+ transformed_tensor->tensor<T, NDIMS>());
+ } else {
+ // If depth <= 1, then just reshape.
+ CHECK(transformed_tensor->CopyFrom(nhwc_tensor, nchw_shape));
+ }
+ *result = transformed_tensor;
+ return Status::OK();
+}
+
+template <typename T, typename BiasType, typename ScaleType>
+void LaunchFusedConv2DBiasActivationOp<GPUDevice, T, BiasType, ScaleType>::
+ launch(OpKernelContext* ctx, bool cudnn_use_autotune,
+ const Tensor& conv_input_param, ScaleType conv_input_scale,
+ const Tensor& filter_param, int32 row_stride, int32 col_stride,
+ const Eigen::PaddingType& padding, const Tensor& side_input_param,
+ ScaleType side_input_scale, const Tensor& bias,
+ ActivationMode activation_mode, TensorFormat data_format,
+ FilterTensorFormat filter_format, Tensor* output_param) {
auto* stream = ctx->op_device_context()->stream();
OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available."));
- Tensor input = input_param;
-
- perftools::gputools::dnn::ActivationMode cudnn_activation_mode =
- BrainActivationMode2CudnnActivationMode(activation_mode);
-
// TODO(yangzihao): refactor all the complicated/duplicated code in regular
// conv ops to a shared conv utility.
- int32 padding_rows = 0;
- int32 padding_cols = 0;
- const int64 in_batch = GetTensorDim(input, data_format, 'N');
- int64 in_rows = GetTensorDim(input, data_format, 'H');
- int64 in_cols = GetTensorDim(input, data_format, 'W');
- const int64 in_depths = GetTensorDim(input, data_format, 'C');
- const int64 out_batch = GetTensorDim(*output, data_format, 'N');
- const int64 out_rows = GetTensorDim(*output, data_format, 'H');
- const int64 out_cols = GetTensorDim(*output, data_format, 'W');
- const int64 out_depths = GetTensorDim(*output, data_format, 'C');
- const int64 patch_rows = filter.dim_size(0);
- const int64 patch_cols = filter.dim_size(1);
+
+ // Assuming qint8 <--> NCHW_VECT_C, OIHW_VECT_I (int8x4) here.
+ constexpr bool is_int8x4 = std::is_same<T, qint8>::value;
+ constexpr int rank = is_int8x4 ? 5 : 4;
+ constexpr int vect = is_int8x4 ? 4 : 1;
+
+ const int batch_size = GetTensorDim(conv_input_param, data_format, 'N');
+ int conv_input_rows = GetTensorDim(conv_input_param, data_format, 'H');
+ int conv_input_cols = GetTensorDim(conv_input_param, data_format, 'W');
+
+ const int conv_input_depth =
+ GetTensorDim(conv_input_param, data_format, 'C') * vect;
+ const int output_rows = GetTensorDim(*output_param, data_format, 'H');
+ const int output_cols = GetTensorDim(*output_param, data_format, 'W');
+ const int output_depth = GetFilterDim(filter_param, filter_format, 'O');
+ const int filter_rows = GetFilterDim(filter_param, filter_format, 'H');
+ const int filter_cols = GetFilterDim(filter_param, filter_format, 'W');
+ int padding_rows = 0;
+ int padding_cols = 0;
+ const Tensor* conv_input = &conv_input_param;
+
+ Tensor maybe_padded_conv_input;
if (padding == Eigen::PADDING_SAME) {
// Total padding on rows and cols is
// Pr = (R' - 1) * S + Kr - R
@@ -281,114 +311,152 @@ void LaunchFusedConv2DBiasActivationOp<GPUDevice, T>::launch(
// We pad Pr/2 on the left and Pr - Pr/2 on the right, Pc/2 on the top
// and Pc - Pc/2 on the bottom. When Pr or Pc is odd, this means
// we pad more on the right and bottom than on the top and left.
- padding_rows =
- std::max<int32>(0, (out_rows - 1) * row_stride + patch_rows - in_rows);
- padding_cols =
- std::max<int32>(0, (out_cols - 1) * col_stride + patch_cols - in_cols);
- const int rows_parity = padding_rows & 1;
- const int cols_parity = padding_cols & 1;
- if ((rows_parity | cols_parity) != 0) {
+ padding_rows = std::max<int>(
+ 0, (output_rows - 1) * row_stride + filter_rows - conv_input_rows);
+ padding_cols = std::max<int>(
+ 0, (output_cols - 1) * col_stride + filter_cols - conv_input_cols);
+ const int padding_rows_parity = padding_rows & 1;
+ const int padding_cols_parity = padding_cols & 1;
+ if ((padding_rows_parity | padding_cols_parity) != 0) {
Tensor transformed_input;
- int64 new_in_rows = in_rows + rows_parity;
- int64 new_in_cols = in_cols + cols_parity;
+ const int new_conv_input_rows = conv_input_rows + padding_rows_parity;
+ const int new_conv_input_cols = conv_input_cols + padding_cols_parity;
+
+ using VectT = typename Int8x4ToInt32<typename RawType<T>::type>::type;
+ auto pad_data_format = is_int8x4 ? FORMAT_NCHW : data_format;
+
OP_REQUIRES_OK(
- ctx,
- ctx->allocate_temp(DataTypeToEnum<T>::value,
- ShapeFromFormat(data_format, in_batch, new_in_rows,
- new_in_cols, in_depths),
- &transformed_input));
-
- functor::PadInput<GPUDevice, T, int, 4>()(
- ctx->eigen_device<GPUDevice>(), To32Bit(input_param.tensor<T, 4>()),
- {{0, 0}}, {{rows_parity, cols_parity}},
- To32Bit(transformed_input.tensor<T, 4>()), data_format);
-
- input = transformed_input;
- in_rows = new_in_rows;
- in_cols = new_in_cols;
+ ctx, ctx->allocate_temp(
+ DataTypeToEnum<T>::value,
+ ShapeFromFormat(data_format, batch_size, new_conv_input_rows,
+ new_conv_input_cols, conv_input_depth),
+ &maybe_padded_conv_input));
+
+ auto conv_input_eigen_tensor =
+ To32Bit(conv_input_param.reinterpret_last_dimension<VectT, 4>());
+ auto padded_conv_input_eigen_tensor = To32Bit(
+ maybe_padded_conv_input.reinterpret_last_dimension<VectT, 4>());
+
+ functor::PadInput<GPUDevice, VectT, int, 4>()(
+ ctx->eigen_device<GPUDevice>(), conv_input_eigen_tensor, {{0, 0}},
+ {{padding_rows_parity, padding_cols_parity}},
+ padded_conv_input_eigen_tensor, pad_data_format);
+
+ conv_input = &maybe_padded_conv_input;
+ conv_input_rows = new_conv_input_rows;
+ conv_input_cols = new_conv_input_cols;
}
}
- if (data_format == FORMAT_NHWC) {
- // Convert the input tensor from NHWC to NCHW.
- TensorShape nchw_shape =
- ShapeFromFormat(FORMAT_NCHW, in_batch, in_rows, in_cols, in_depths);
- if (in_depths > 1) {
- Tensor transformed_input;
- OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
- nchw_shape, &transformed_input));
- functor::NHWCToNCHW<GPUDevice, T, 4>()(
- ctx->eigen_device<GPUDevice>(),
- const_cast<const Tensor&>(input).tensor<T, 4>(),
- transformed_input.tensor<T, 4>());
- input = transformed_input;
- } else {
- // If depth <= 1, then just reshape.
- CHECK(input.CopyFrom(input, nchw_shape));
+ Tensor maybe_transformed_conv_input, maybe_transformed_side_input;
+ Tensor maybe_transformed_output;
+ const Tensor* side_input = &side_input_param;
+ Tensor* output = output_param;
+
+ // NOTE: Here and elsewhere, checking 'is_int8x4' may look unnecessary
+ // and inefficient, but it is actually both a time and code size optimization,
+ // since 'is_int8x4' is a constexpr determined by the template parameter.
+ if (!is_int8x4 && data_format == FORMAT_NHWC) {
+ OP_REQUIRES_OK(ctx, (TransformNHWCToNCHW<T, rank>(
+ ctx, *conv_input, batch_size, conv_input_rows,
+ conv_input_cols, conv_input_depth,
+ &maybe_transformed_conv_input, &conv_input)));
+ if (side_input_scale != 0) {
+ OP_REQUIRES_OK(
+ ctx, (TransformNHWCToNCHW<T, rank>(
+ ctx, side_input_param, batch_size, output_rows, output_cols,
+ output_depth, &maybe_transformed_side_input, &side_input)));
+ }
+ if (output_depth > 1) {
+ // Allocate a tensor for the NCHW output of the kernel and point output
+ // to it. Afterwards, we will transform it to NHWC while copying back to
+ // 'output_param'.
+ TensorShape nchw_shape = ShapeFromFormat(
+ FORMAT_NCHW, batch_size, output_rows, output_cols, output_depth);
+ OP_REQUIRES_OK(ctx,
+ ctx->allocate_temp(DataTypeToEnum<T>::value, nchw_shape,
+ &maybe_transformed_output));
+ output = &maybe_transformed_output;
}
}
- CHECK(padding_rows >= 0 && padding_cols >= 0)
- << "Negative row or col paddings: (" << padding_rows << ", "
- << padding_cols << ")";
- perftools::gputools::dnn::BatchDescriptor input_desc;
- input_desc.set_count(in_batch)
- .set_feature_map_count(in_depths)
- .set_height(in_rows)
- .set_width(in_cols)
- .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
- perftools::gputools::dnn::BatchDescriptor output_desc;
- output_desc.set_count(out_batch)
- .set_height(out_rows)
- .set_width(out_cols)
- .set_feature_map_count(out_depths)
- .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
- perftools::gputools::dnn::FilterDescriptor filter_desc;
- filter_desc.set_input_filter_height(filter.dim_size(0))
- .set_input_filter_width(filter.dim_size(1))
- .set_input_feature_map_count(filter.dim_size(2))
- .set_output_feature_map_count(filter.dim_size(3));
- perftools::gputools::dnn::ConvolutionDescriptor conv_desc;
+ constexpr auto data_layout = is_int8x4 ? dnn::DataLayout::kBatchDepthYX4
+ : dnn::DataLayout::kBatchDepthYX;
+ constexpr auto filter_layout = is_int8x4 ? dnn::FilterLayout::kOutputInputYX4
+ : dnn::FilterLayout::kOutputInputYX;
+
+ dnn::BatchDescriptor conv_input_desc;
+ conv_input_desc.set_count(batch_size)
+ .set_feature_map_count(conv_input_depth)
+ .set_height(conv_input_rows)
+ .set_width(conv_input_cols)
+ .set_layout(data_layout);
+ dnn::FilterDescriptor filter_desc;
+ filter_desc.set_input_filter_height(filter_rows)
+ .set_input_filter_width(filter_cols)
+ .set_input_feature_map_count(conv_input_depth)
+ .set_output_feature_map_count(output_depth)
+ .set_layout(filter_layout);
+ dnn::BatchDescriptor side_input_desc;
+ side_input_desc.set_count(batch_size)
+ .set_height(output_rows)
+ .set_width(output_cols)
+ .set_feature_map_count(output_depth)
+ .set_layout(data_layout);
+ dnn::BatchDescriptor bias_desc;
+ bias_desc.set_count(1)
+ .set_height(1)
+ .set_width(1)
+ .set_feature_map_count(output_depth)
+ .set_layout(dnn::DataLayout::kBatchDepthYX);
+ dnn::BatchDescriptor output_desc;
+ output_desc.set_count(batch_size)
+ .set_height(output_rows)
+ .set_width(output_cols)
+ .set_feature_map_count(output_depth)
+ .set_layout(data_layout);
+ dnn::ConvolutionDescriptor conv_desc;
conv_desc.set_vertical_filter_stride(row_stride)
.set_horizontal_filter_stride(col_stride)
.set_zero_padding_height(padding_rows / 2)
.set_zero_padding_width(padding_cols / 2);
- // Shuffles a filter tensor from:
- // [<spatial_dims>, in, out]
- // to:
- // [out, in, <spatial_dims>]
- // TODO(yangzihao): Support a data layout tag for the filter weights, and only
- // do the transform if the weights are not already in the correct layout.
- Tensor transformed_filter;
- OP_REQUIRES_OK(ctx, ctx->allocate_temp(
- DataTypeToEnum<T>::value,
- TensorShape({filter.dim_size(3), filter.dim_size(2),
- filter.dim_size(0), filter.dim_size(1)}),
- &transformed_filter));
-
- functor::TransformFilter<GPUDevice, T, int, 4>()(
- ctx->eigen_device<GPUDevice>(), To32Bit(filter.tensor<T, 4>()),
- To32Bit(transformed_filter.tensor<T, 4>()));
-
- Tensor transformed_output;
- OP_REQUIRES_OK(
- ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
- ShapeFromFormat(FORMAT_NCHW, out_batch, out_rows,
- out_cols, out_depths),
- &transformed_output));
-
- auto input_ptr = AsDeviceMemory(input.template flat<T>().data(),
- input.template flat<T>().size());
+ Tensor maybe_transformed_filter;
+ const Tensor* filter;
+ if (is_int8x4) {
+ // We have already checked filter is OIHW_VECT_I in the constructor.
+ filter = &filter_param;
+ } else if (filter_format == FORMAT_HWIO) {
+ // Shuffle filter tensor from HWIO to OIHW:
+ OP_REQUIRES_OK(ctx, ctx->allocate_temp(
+ DataTypeToEnum<T>::value,
+ ShapeFromFilterFormat(
+ FORMAT_OIHW, filter_param.shape(), FORMAT_HWIO),
+ &maybe_transformed_filter));
+ functor::TransformFilter<GPUDevice, T, int, 4>()(
+ ctx->eigen_device<GPUDevice>(), To32Bit(filter_param.tensor<T, 4>()),
+ To32Bit(maybe_transformed_filter.tensor<T, 4>()));
+ filter = &maybe_transformed_filter;
+ }
+
+ auto conv_input_ptr =
+ AsDeviceMemory(reinterpret_cast<const typename RawType<T>::type*>(
+ conv_input->template flat<T>().data()),
+ conv_input->template flat<T>().size());
auto filter_ptr =
- AsDeviceMemory(transformed_filter.template flat<T>().data(),
- transformed_filter.template flat<T>().size());
+ AsDeviceMemory(reinterpret_cast<const typename RawType<T>::type*>(
+ filter->template flat<T>().data()),
+ filter->template flat<T>().size());
+ auto side_input_ptr =
+ AsDeviceMemory(reinterpret_cast<const typename RawType<T>::type*>(
+ side_input->template flat<T>().data()),
+ side_input->template flat<T>().size());
auto output_ptr =
- AsDeviceMemory(transformed_output.template flat<T>().data(),
- transformed_output.template flat<T>().size());
-
- auto bias_ptr = AsDeviceMemory(bias.template flat<T>().data(),
- bias.template flat<T>().size());
+ AsDeviceMemory(reinterpret_cast<const typename RawType<T>::type*>(
+ output->template flat<T>().data()),
+ output->template flat<T>().size());
+ auto bias_ptr = AsDeviceMemory(bias.template flat<BiasType>().data(),
+ bias.template flat<BiasType>().size());
static int64 ConvolveScratchSize = GetCudnnWorkspaceLimit(
// default value is in bytes despite the name of the environment variable
@@ -396,38 +464,42 @@ void LaunchFusedConv2DBiasActivationOp<GPUDevice, T>::launch(
);
int device_id = stream->parent()->device_ordinal();
- DataType dtype = input.dtype();
- ConvParameters conv_parameters = {
- in_batch,
- in_depths,
- {{in_rows, in_cols}},
- out_depths,
- {{patch_rows, patch_cols}},
+ FusedConvParameters fused_conv_parameters = {
+ batch_size,
+ conv_input_depth,
+ {{conv_input_rows, conv_input_cols}},
+ output_depth,
+ {{filter_rows, filter_cols}},
{{row_stride, col_stride}},
{{padding_rows, padding_cols}},
- dtype,
+ conv_input->dtype(),
device_id,
+ (side_input_scale != 0),
+ activation_mode,
};
- AlgorithmConfig algorithm_config;
+ dnn::AlgorithmConfig algorithm_config;
if (cudnn_use_autotune && !AutoTuneConvBiasActivation::GetInstance()->Find(
- conv_parameters, &algorithm_config)) {
- std::vector<AlgorithmType> algorithms;
+ fused_conv_parameters, &algorithm_config)) {
+ std::vector<dnn::AlgorithmType> algorithms;
CHECK(stream->parent()->GetConvolveAlgorithms(
- conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(), &algorithms));
- ProfileResult best_result;
- ProfileResult best_result_no_scratch;
+ fused_conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(),
+ &algorithms));
+ dnn::ProfileResult best_result;
+ dnn::ProfileResult best_result_no_scratch;
for (auto profile_algorithm : algorithms) {
// TODO(zhengxq): profile each algorithm multiple times to better
// accuracy.
CudnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
- ProfileResult profile_result;
+ dnn::ProfileResult profile_result;
bool cudnn_launch_status =
stream
- ->ThenConvolveWithAlgorithm(
- input_desc, input_ptr, filter_desc, filter_ptr, conv_desc,
- bias_ptr, cudnn_activation_mode, output_desc, &output_ptr,
- &scratch_allocator, AlgorithmConfig(profile_algorithm),
+ ->ThenFusedConvolveWithAlgorithm(
+ conv_input_desc, conv_input_ptr, conv_input_scale,
+ filter_desc, filter_ptr, conv_desc, side_input_ptr,
+ side_input_scale, bias_desc, bias_ptr,
+ dnn::ActivationMode::kRelu, output_desc, &output_ptr,
+ &scratch_allocator, dnn::AlgorithmConfig(profile_algorithm),
&profile_result)
.ok();
if (cudnn_launch_status) {
@@ -454,42 +526,68 @@ void LaunchFusedConv2DBiasActivationOp<GPUDevice, T>::launch(
algorithm_config.set_algorithm_no_scratch(
best_result_no_scratch.algorithm());
}
- AutoTuneConvBiasActivation::GetInstance()->Insert(conv_parameters,
+ AutoTuneConvBiasActivation::GetInstance()->Insert(fused_conv_parameters,
algorithm_config);
}
CudnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
bool cudnn_launch_status =
stream
- ->ThenConvolveWithAlgorithm(
- input_desc, input_ptr, filter_desc, filter_ptr, conv_desc,
- bias_ptr, cudnn_activation_mode, output_desc, &output_ptr,
- &scratch_allocator, algorithm_config,
+ ->ThenFusedConvolveWithAlgorithm(
+ conv_input_desc, conv_input_ptr, conv_input_scale, filter_desc,
+ filter_ptr, conv_desc, side_input_ptr, side_input_scale,
+ bias_desc, bias_ptr, dnn::ActivationMode::kRelu, output_desc,
+ &output_ptr, &scratch_allocator, algorithm_config,
/*output_profile_result=*/nullptr)
.ok();
if (!cudnn_launch_status) {
- ctx->SetStatus(errors::Internal(
- "cuDNN launch failure : input shape(", input.shape().DebugString(),
- ") filter shape(", filter.shape().DebugString(), ")"));
+ ctx->SetStatus(errors::Internal("cuDNN launch failure : conv_input shape(",
+ conv_input->shape().DebugString(),
+ ") filter shape(",
+ filter->shape().DebugString(), ")"));
}
- // Convert the output tensor back from NCHW to NHWC.
- if (data_format == FORMAT_NHWC) {
+ // Convert the output tensor back from NCHW to NHWC if necessary.
+ if (!is_int8x4 && (data_format == FORMAT_NHWC) && (output_depth > 1)) {
functor::NCHWToNHWC<GPUDevice, T, 4>()(
ctx->eigen_device<GPUDevice>(),
- const_cast<const Tensor&>(transformed_output).tensor<T, 4>(),
- output->tensor<T, 4>());
- } else {
- *output = transformed_output;
+ const_cast<const Tensor*>(output)->tensor<T, 4>(),
+ output_param->tensor<T, 4>());
}
}
+// Forward declarations of the functor specializations for GPU used above.
+namespace functor {
+#define DECLARE_GPU_SPEC(T) \
+ template <> \
+ void PadInput<GPUDevice, T, int, 4>::operator()( \
+ const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in, \
+ const std::array<int, 2>& padding_left, \
+ const std::array<int, 2>& padding_right, \
+ typename TTypes<T, 4, int>::Tensor out, TensorFormat data_format); \
+ extern template struct PadInput<GPUDevice, T, int, 4>;
+
+DECLARE_GPU_SPEC(float);
+DECLARE_GPU_SPEC(int32);
+#undef DECLARE_GPU_SPEC
+} // namespace functor
+
// Registration of the GPU implementations.
-REGISTER_KERNEL_BUILDER(Name("FusedConv2DBiasActivation")
- .Device(DEVICE_GPU)
- .TypeConstraint<float>("T"),
- FusedConv2DBiasActivationOp<GPUDevice, float>);
+
+REGISTER_KERNEL_BUILDER(
+ Name("FusedConv2DBiasActivation")
+ .Device(DEVICE_GPU)
+ .TypeConstraint<float>("T")
+ .TypeConstraint<float>("Tbias"),
+ FusedConv2DBiasActivationOp<GPUDevice, float, float, float>);
+
+REGISTER_KERNEL_BUILDER(
+ Name("FusedConv2DBiasActivation")
+ .Device(DEVICE_GPU)
+ .TypeConstraint<qint8>("T")
+ .TypeConstraint<float>("Tbias"),
+ FusedConv2DBiasActivationOp<GPUDevice, qint8, float, float>);
#endif // GOOGLE_CUDA
diff --git a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.h b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.h
index d71b26cf1d..7534f5797c 100644
--- a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.h
+++ b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.h
@@ -24,7 +24,7 @@ limitations under the License.
#if GOOGLE_CUDA
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "tensorflow/core/kernels/conv_ops_gpu.h"
+#include "tensorflow/contrib/fused_conv/kernels/fused_conv_ops_gpu.h"
#include "tensorflow/core/platform/stream_executor.h"
#endif // GOOGLE_CUDA
@@ -33,27 +33,30 @@ namespace tensorflow {
// Forward declaration.
class OpKernelContext;
-template <typename Device, typename T>
+template <typename Device, typename T, typename BiasType, typename ScaleType>
class LaunchFusedConv2DBiasActivationOp {
public:
void launch(OpKernelContext* ctx, bool cudnn_use_autotune,
- const Tensor& input, const Tensor& filter, int row_stride,
- int col_stride, const Tensor& bias,
- const ActivationMode& activation_mode,
- const Eigen::PaddingType& padding, TensorFormat data_format,
- Tensor* output);
+ const Tensor& conv_input, ScaleType conv_input_scale,
+ const Tensor& filter, int32 row_stride, int32 col_stride,
+ const Eigen::PaddingType& padding, const Tensor& side_input,
+ ScaleType side_input_scale, const Tensor& bias,
+ ActivationMode activation_mode, TensorFormat data_format,
+ FilterTensorFormat filter_format, Tensor* output);
};
#ifdef GOOGLE_CUDA
-template <typename T>
-class LaunchFusedConv2DBiasActivationOp<Eigen::GpuDevice, T> {
+template <typename T, typename BiasType, typename ScaleType>
+class LaunchFusedConv2DBiasActivationOp<Eigen::GpuDevice, T, BiasType,
+ ScaleType> {
public:
void launch(OpKernelContext* ctx, bool cudnn_use_autotune,
- const Tensor& input, const Tensor& filter, int32 row_stride,
- int32 col_stride, const Tensor& bias,
- const ActivationMode& activation_mode,
- const Eigen::PaddingType& padding, TensorFormat data_format,
- Tensor* output);
+ const Tensor& conv_input, ScaleType conv_input_scale,
+ const Tensor& filter, int32 row_stride, int32 col_stride,
+ const Eigen::PaddingType& padding, const Tensor& side_input,
+ ScaleType side_input_scale, const Tensor& bias,
+ ActivationMode activation_mode, TensorFormat data_format,
+ FilterTensorFormat filter_format, Tensor* output);
};
#endif // GOOGLE_CUDA
diff --git a/tensorflow/contrib/fused_conv/kernels/fused_conv_ops_gpu.h b/tensorflow/contrib/fused_conv/kernels/fused_conv_ops_gpu.h
new file mode 100644
index 0000000000..dc43af1158
--- /dev/null
+++ b/tensorflow/contrib/fused_conv/kernels/fused_conv_ops_gpu.h
@@ -0,0 +1,74 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_FUSED_CONV_KERNELS_FUSED_CONV_OPS_GPU_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_FUSED_CONV_KERNELS_FUSED_CONV_OPS_GPU_H_
+
+#if GOOGLE_CUDA
+
+#include "tensorflow/core/kernels/conv_ops_gpu.h"
+#include "tensorflow/core/util/activation_mode.h"
+
+// TODO(pauldonnelly): Merge this file into core/kernels/conv_ops_gpu.h.
+
+namespace tensorflow {
+
+// Add additional parameters specific to fused convolutions.
+class FusedConvParameters : public ConvParameters {
+ public:
+ FusedConvParameters(int64 batch, int64 in_depths, const SpatialArray& in,
+ int64 out_depths, const SpatialArray& filter,
+ const SpatialArray& stride, const SpatialArray& padding,
+ DataType dtype, int device_id, bool has_side_input,
+ ActivationMode activation_mode)
+ : ConvParameters(batch, in_depths, in, out_depths, filter, stride,
+ padding, dtype, device_id),
+ activation_mode_(activation_mode),
+ has_side_input_(has_side_input) {
+ hash_code_ = Hash64Combine(hash_code_, has_side_input);
+ hash_code_ = Hash64Combine(hash_code_, activation_mode);
+ }
+
+ bool operator==(const FusedConvParameters& other) const {
+ return this->get_data_as_tuple() == other.get_data_as_tuple();
+ }
+
+ bool operator!=(const FusedConvParameters& other) const {
+ return !(*this == other);
+ }
+
+ string ToString() const {
+ return strings::StrCat(ConvParameters::ToString(), ", ", has_side_input_,
+ ", ", activation_mode_, ", ");
+ }
+
+ private:
+ using ParameterDataType =
+ std::tuple<ConvParameters::ParameterDataType, bool, ActivationMode>;
+
+ ParameterDataType get_data_as_tuple() const {
+ return std::make_tuple(ConvParameters::get_data_as_tuple(), has_side_input_,
+ activation_mode_);
+ }
+
+ ActivationMode activation_mode_;
+ bool has_side_input_;
+};
+
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_FUSED_CONV_KERNELS_FUSED_CONV_OPS_GPU_H_
diff --git a/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc b/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc
index 6134c5c699..48f058b4c5 100644
--- a/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc
+++ b/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc
@@ -33,40 +33,73 @@ string GetAllActivationModeAttrString() { return "activation_mode: {'Relu'}"; }
} // namespace
// --------------------------------------------------------------------------
+
+// TODO(pauldonnelly): Add support for double inputs and scales to this Op,
+// (currently Attr does not support double).
+
REGISTER_OP("FusedConv2DBiasActivation")
- .Input("input: T")
+ .Input("conv_input: T")
.Input("filter: T")
- .Input("bias: T")
+ .Input("bias: Tbias")
+ .Input("side_input: T")
.Output("output: T")
- .Attr("T: {float}")
+ .Attr("T: {float, half, qint8}")
+ .Attr("Tbias: {float, half}")
+ .Attr("conv_input_scale: float = 1.0")
+ .Attr("side_input_scale: float = 0.0")
.Attr("strides: list(int)")
.Attr(GetPaddingAttrString())
- .Attr(GetConvnetDataFormatAttrString())
- .Attr(GetAllActivationModeAttrString())
+ .Attr("data_format: {'NHWC', 'NCHW', 'NCHW_VECT_C'} = 'NHWC'")
+ .Attr("filter_format: {'HWIO', 'OIHW', 'OIHW_VECT_I'} = 'HWIO'")
+ .Attr("activation_mode: {'Relu'} = 'Relu'")
.SetShapeFn(shape_inference::FusedConvBiasActivationShape)
.Doc(R"doc(
- Computes a fused 2-D convolution, adds bias, and applies an activation function
- on the output given 4-D `input`, 4-D `filter`, 1-D `bias` tensors and an activation mode.
+ Computes a fused kernel which implements: 2-D convolution, adds side input,
+ with separate scaling on convolution and side inputs, then adds bias and
+ applies the RELU activation function to the result. Supports both float and
+ qint8 data formats. In the case of qint8, the output is clipped to [0..127].
- input: A 4-D tensor. The dimension order is interpreted according to the value
- of `data_format`, see below for details.
- filter: A 4-D tensor of shape
- `[filter_height, filter_width, in_channels, out_channels]`
- bias: 1-D with size of the `out_channels` dimension in filter.
- output: A 4-D tensor. The dimension order is determined by the value of
- `data_format`, see below for details.
- T: The data type for the elements of input, filter, bias, and output Tensors.
+ conv_input: A tensor with format as specified by `data_format` (see below).
+ filter: A tensor with format depending on `data_format` as follows:
+ "NHWC", "NCHW":
+ `float [ filter_height, filter_width, in_channels, out_channels ]`
+ "NCHW_VECT_C":
+ `qint8 [ out_channels, in_channels, filter_height, filter_width ]`
+ bias: 1-D float tensor with size matching the `out_channels` dimension of
+ `filter`.
+ Note: this tensor is still float, even if other inputs are qint8.
+ side_input: A tensor with format as specified by `data_format` (see below).
+ This tensor will be ignored and can be [] if side_input_scale == 0.
+ Otherwise, the size of each dimension must match the `output` tensor.
+ output: A tensor with format as specified by `data_format` (see below).
+ The dimension sizes are determined automatically based on other inputs
+ and attributes.
+ T: The element data type of `conv_input`, `side_input` and `output` tensors.
+ Note: must match with the `data_format`.
+ Tbias: The element data type of `bias`.
+ conv_input_scale: scalar float value to be multiplied by `conv_input`.
+ (conceptually.. in reality it is applied after convolution).
+ side_input_scale: scalar float value to be multiplied by `side_input`.
strides: 1-D tensor of length 4. The stride of the sliding window for each
dimension of `input`. The dimension order is determined by the value of
`data_format`, see below for details.
+ Note: the stride for batch and channel dimensions must be 1.
padding: The type of padding algorithm to use.
- data_format: Specify the data format of the input and output data. With the
- default format "NHWC", the data is stored in the order of:
- [batch, height, width, channels].
- Alternatively, the format could be "NCHW", the data storage order of:
- [batch, channels, height, width].
- activation_mode: Specify the activation function to apply to the output tensor
- of bias add. Currently only supports "Relu".
+ data_format: A string specifying the data format of `conv_input`,
+ `side_input` and `output` tensors with the following options:
+ "NHWC": `float [ batch, height, width, channels ]`
+ "NCHW": `float [ batch, channels, height, width ]`
+ "NCHW_VECT_C":
+ `qint8 [ batch, channels / 4, height, width, channels % 4 ]`
+ Note: for "NCHW_VECT_C", `channels` must be a multiple of 4.
+ filter_format: A string specifying the data format of `filter`,
+ "HWIO": `float [ kernel_height, kernel_width, input_channels,
+ output_channels ]`
+ "OIHW_VECT_I":
+ `qint8 [ output_channels, input_channels / 4,
+ kernel_height, kernel_width, input_channels % 4 ]`
+ activation_mode: The activation applied to the output.
+ Currently must be "Relu".
)doc");
} // namespace tensorflow
diff --git a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op.py b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op.py
index 41f986dd07..8f3f31bad0 100644
--- a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op.py
+++ b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op.py
@@ -26,62 +26,83 @@ _fused_conv2d_bias_activation_op_so = loader.load_op_library(
resource_loader.get_path_to_datafile("_fused_conv2d_bias_activation_op.so"))
-def fused_conv2d_bias_activation(input_tensor,
- filter_tensor,
+# pylint: disable=redefined-builtin
+def fused_conv2d_bias_activation(conv_input,
+ filter,
bias,
- strides,
- padding,
- activation_mode,
+ strides=None,
+ padding=None,
+ conv_input_scale=1.0,
+ side_input_scale=0.0,
+ side_input=None,
+ activation_mode="Relu",
data_format=None,
+ filter_format=None,
name=None):
- """Computes a fused 2-D convolution, adds bias, and applies relu.
+ """Fused 2D conv, bias and activation with optional side input.
- input_tensor: A 4-D tensor. The dimension order is interpreted
- according to the value of `data_format`, see below for details.
- filter_tensor: A 4-D tensor of shape
- `[filter_height, filter_width, in_channels, out_channels]`
- bias: 1-D with size of the `out_channels` dimension in filter.
- output: A 4-D tensor. The dimension order is determined by the value of
- `data_format`, see below for details.
- T: The data type for the elements of input, filter, bias, and output
- Tensors.
- strides: 1-D tensor of length 4. The stride of the sliding window for
- each
- dimension of `input`. The dimension order is determined by the value
- of
- `data_format`, see below for details.
- padding: The type of padding algorithm to use.
- data_format: Specify the data format of the input and output data. With
- the
- default format "NHWC", the data is stored in the order of:
- [batch, height, width, channels].
- Alternatively, the format could be "NCHW", the data storage order of:
- [batch, channels, height, width].
- activation_mode: Specify the activation function to apply to the output
- tensor
- of bias add. Currently only supports "Relu".
+ Computes a fused 2-D convolution scaled by conv_input_scale,
+ adds an optional side input scaled by side_input_scale, adds biases,
+ and applies ReLU. As an equation:
+ output = ReLU(conv_input_scale * Conv(conv_input, filter) +
+ side_input_scale * side_input + bias)
+ Note: In int8 mode, The ReLU will clip the output to the range [0..127].
Args:
- input_tensor: A `Tensor`. Must be one of the following types: `float32`.
- filter_tensor: A `Tensor`. Must have the same type as `input`.
- bias: A `Tensor`. Must have the same type as `input`.
- strides: A list of `ints`.
+ conv_input: A `Tensor` of the format specified by `data_format`.
+ filter: A `Tensor` whose format depends on `data_format`:
+ if `data_format` is "NCHW_VECT_C", filter should be "OIHW_VECT_I"
+ otherwise, it should be "HWIO" format.
+ bias: A 1-D `Tensor` of type `float32`, and dimensions equal to the
+ number of output channels.
+ strides: A list of 4 `ints` specifying convolution strides.
+ if `data_format` is "NCHW" or "NCHW_VECT_C", the order should be NCHW.
+ if `data_format` is "NHWC", the order should be NHWC.
padding: A `string` from: `"SAME", "VALID"`.
- activation_mode: A `string` from: `"Sigmoid", "Relu", "Relu6", "ReluX",
- "Tanh", "BandPass"`.
- data_format: An optional `string` from: `"NHWC", "NCHW"`. Defaults to
- `"NHWC"`.
+ conv_input_scale: A scalar `float32` that will be multiplied by conv_input.
+ This is optional and defaults to 1. However it should be set to
+ specify the quantization scale when `data_format` is "NCHW_VECT_C".
+ side_input_scale: A scalar `float32` that will be multiplied by side_input.
+ This is optional and defaults to 0.
+ side_input: A `Tensor` of the format specified by `data_format`.
+ This is useful for imlementing ResNet blocks.
+ activation_mode: (optional) currently must be the default "Relu".
+ Note that in qint8 mode, it also clips to 127, so acts like ReluX.
+ data_format: Specifies the data format.
+ Possible values are:
+ "NHWC" float [batch, height, width, channels]
+ "NCHW" float [batch, channels, height, width]
+ "NCHW_VECT_C" qint8 [batch, channels / 4, height, width, channels % 4]
+ Defaults to `"NHWC"`.
+ Performance is worst for `"NHWC"` and best for `"NCHW_VECT_C"`.
+ filter_format: Specifies the filter format.
+ Possible values are:
+ "HWIO" float [kernel_height, kernel_width, input_channels,
+ output_channels ]
+ "OIHW" float [output_channels, input_channels, kernel_height,
+ kernel_width ]
+ "OIHW_VECT_I" qint8 [ output_channels, input_channels / 4,
+ kernel_height, kernel_width, input_channels % 4 ]
+ Defaults to `"HWIO"`.
name: A name for the operation (optional).
Returns:
- A `Tensor`. Has the same type as `input`.
+ A `Tensor` of the format specified by `data_format`.
"""
+ if strides is None:
+ strides = [1, 1, 1, 1]
+ if side_input is None:
+ side_input = []
return gen_fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
- input=input_tensor,
- filter=filter_tensor,
- bias=bias,
- strides=strides,
+ conv_input,
+ filter,
+ bias,
padding=padding,
+ strides=strides,
+ conv_input_scale=conv_input_scale,
+ side_input_scale=side_input_scale,
+ side_input=side_input,
activation_mode=activation_mode,
data_format=data_format,
+ filter_format=filter_format,
name=name)
diff --git a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py
index 5d6a2fa3b8..3b8f7d6ed7 100644
--- a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py
+++ b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py
@@ -19,13 +19,16 @@ from __future__ import division
from __future__ import print_function
import numpy as np
+
from tensorflow.contrib.fused_conv.python.ops import fused_conv2d_bias_activation_op
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import random_ops
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging
@@ -484,7 +487,8 @@ class FusedConv2DBiasActivationTest(test.TestCase):
with self.test_session() as sess:
# Illegal strides.
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
- "strides in the batch and depth"):
+ "Convolutional strides are not supported in "
+ "the batch or depth dimensions."):
sess.run(
fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
array_ops.placeholder(dtypes.float32),
@@ -494,7 +498,8 @@ class FusedConv2DBiasActivationTest(test.TestCase):
padding="SAME",
activation_mode="Relu"))
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
- "strides in the batch and depth"):
+ "Convolutional strides are not supported in "
+ "the batch or depth dimensions."):
sess.run(
fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
array_ops.placeholder(dtypes.float32),
@@ -552,6 +557,286 @@ def GetInceptionFwdTest(input_size, filter_size, stride, padding,
return Test
+def CalculateCovolvedOutputDim(input_dim, filter_dim, stride, padding_type):
+ """Calculates the size of an output dimension of a strided convolution.
+
+ Given the sizes of the corresponding dimension of the input and filter shapes,
+ and the stride and padding_types, calculates the size of the output dimension.
+ This function can be called separately for each input dimension.
+
+ Args:
+ input_dim: An `int` specifying the size of the input dimension.
+ filter_dim: An `int` specifying the size of the filter dimension.
+ stride: An `int` specifying the step size of the convolution along the
+ input dimension.
+ padding_type: either 'VALID' or 'SAME'.
+
+ Returns:
+ The size of the output dimension.
+ """
+ if padding_type == "VALID":
+ return (input_dim - filter_dim + stride) // stride
+ else: # padding_type == 'SAME'
+ return (input_dim + stride - 1) // stride
+
+
+def NchwVectCToNchw(in_tensor):
+ # [N, C / 4, H, W, 4] => [N, C / 4, 4, H, W] == [N, C, H, W]
+ t = array_ops.transpose(in_tensor, [0, 1, 4, 2, 3])
+ n = in_tensor.shape.dims[0].value
+ c = in_tensor.shape.dims[1].value * in_tensor.shape.dims[4].value
+ h = in_tensor.shape.dims[2].value
+ w = in_tensor.shape.dims[3].value
+ return array_ops.reshape(t, [n, c, h, w])
+
+
+def OihwVectIToHwio(in_tensor):
+ # [O, I / 4, H, W, 4] => [O, I / 4, 4, H, W] == [O, I, H, W]
+ t = array_ops.transpose(in_tensor, [2, 3, 1, 4, 0])
+ o = in_tensor.shape.dims[0].value
+ i = in_tensor.shape.dims[1].value * in_tensor.shape.dims[4].value
+ h = in_tensor.shape.dims[2].value
+ w = in_tensor.shape.dims[3].value
+ return array_ops.reshape(t, [h, w, i, o])
+
+
+def NchwToNchwVectC(in_tensor):
+ n, c, h, w = in_tensor.shape.as_list()
+ assert c % 4 == 0
+ t = array_ops.reshape(in_tensor, [n, c // 4, 4, h, w])
+ return array_ops.transpose(t, [0, 1, 3, 4, 2])
+
+
+def SimulateFusedConv2dBiasActivationInt8(conv_input_scale, conv_input, kernel,
+ padding, strides, side_input_scale,
+ side_input, biases):
+ """Simulates the int8 fused 2-D convolution op using separate float ops.
+
+ The arguments and return values have the same format, meanings and
+ restrictions as the actual op.
+ Args:
+ conv_input_scale: A scalar 'float'.
+ conv_input: A `Tensor` of type `qint8` in NCHW_VECT_C layout.
+ kernel: A `Tensor` of type `qint8` in OIHW_VECT_I layout.
+ padding: A `string` from: `"SAME", "VALID"`.
+ strides: A list of `ints`.
+ side_input_scale: A scalar 'float'.
+ side_input: A `Tensor` of type `qint8` in NCHW_VECT_C layout.
+ biases: A `Tensor` of type `float32` in NCHW layout.
+ Returns:
+ A `Tensor` of type `qint8` in NCHW_VECT_C layout.
+ """
+ conv_result = nn_ops.conv2d(
+ NchwVectCToNchw(gen_array_ops.dequantize(conv_input, -128, 127)),
+ OihwVectIToHwio(gen_array_ops.dequantize(kernel, -128, 127)),
+ strides=strides,
+ padding=padding,
+ data_format="NCHW") * conv_input_scale
+
+ conv_and_side_inputs = conv_result + side_input_scale * NchwVectCToNchw(
+ gen_array_ops.dequantize(side_input, -128, 127))
+
+ logit = nn_ops.bias_add(conv_and_side_inputs, biases, data_format="NCHW")
+
+ result, _, _ = gen_array_ops.quantize_v2(
+ NchwToNchwVectC(nn_ops.relu(logit)), -128, 127, dtypes.qint8)
+ return result
+
+
+class FusedConvInt8Tests(test.TestCase):
+ _test_params = [
+ {
+ "batch_size": 2,
+ "input_channels": 8,
+ "output_channels": 16,
+ "input_height": 8,
+ "input_width": 8,
+ "filter_height": 3,
+ "filter_width": 3,
+ "vertical_stride": 2,
+ "horizontal_stride": 2,
+ "conv_input_scale": 0.002,
+ "side_input_scale": 0.0,
+ "bias_scale": 1,
+ "padding_type": "VALID"
+ },
+ {
+ "batch_size": 2,
+ "input_channels": 8,
+ "output_channels": 16,
+ "input_height": 8,
+ "input_width": 8,
+ "filter_height": 3,
+ "filter_width": 3,
+ "vertical_stride": 2,
+ "horizontal_stride": 2,
+ "conv_input_scale": 0.002,
+ "side_input_scale": 0.0,
+ "bias_scale": 1,
+ "padding_type": "SAME"
+ },
+ {
+ "batch_size": 2,
+ "input_channels": 8,
+ "output_channels": 16,
+ "input_height": 8,
+ "input_width": 8,
+ "filter_height": 3,
+ "filter_width": 3,
+ "vertical_stride": 2,
+ "horizontal_stride": 2,
+ "conv_input_scale": 0.002,
+ "side_input_scale": 0.5,
+ "bias_scale": 1,
+ "padding_type": "VALID"
+ },
+ {
+ "batch_size": 2,
+ "input_channels": 16,
+ "output_channels": 16,
+ "input_height": 9,
+ "input_width": 9,
+ "filter_height": 3,
+ "filter_width": 3,
+ "vertical_stride": 1,
+ "horizontal_stride": 1,
+ "conv_input_scale": 0.001,
+ "side_input_scale": 0.5,
+ "bias_scale": 1,
+ "padding_type": "SAME"
+ },
+ {
+ "batch_size": 3,
+ "input_channels": 8,
+ "output_channels": 8,
+ "input_height": 9,
+ "input_width": 9,
+ "filter_height": 5,
+ "filter_width": 5,
+ "vertical_stride": 1,
+ "horizontal_stride": 1,
+ "conv_input_scale": 0.001,
+ "side_input_scale": 0.5,
+ "bias_scale": 1,
+ "padding_type": "SAME"
+ },
+ {
+ "batch_size": 3,
+ "input_channels": 8,
+ "output_channels": 8,
+ "input_height": 9,
+ "input_width": 9,
+ "filter_height": 7,
+ "filter_width": 1,
+ "vertical_stride": 2,
+ "horizontal_stride": 1,
+ "conv_input_scale": 0.002,
+ "side_input_scale": 0.5,
+ "bias_scale": 1,
+ "padding_type": "SAME"
+ },
+ {
+ "batch_size": 3,
+ "input_channels": 8,
+ "output_channels": 8,
+ "input_height": 9,
+ "input_width": 9,
+ "filter_height": 1,
+ "filter_width": 7,
+ "vertical_stride": 1,
+ "horizontal_stride": 1,
+ "conv_input_scale": 0.002,
+ "side_input_scale": 0.5,
+ "bias_scale": 1,
+ "padding_type": "SAME"
+ },
+ ]
+
+ def runTest(self, test_param):
+ batch_size = test_param["batch_size"]
+ input_channels = test_param["input_channels"]
+ output_channels = test_param["output_channels"]
+ input_height = test_param["input_height"]
+ input_width = test_param["input_width"]
+ filter_height = test_param["filter_height"]
+ filter_width = test_param["filter_width"]
+ vertical_stride = test_param["vertical_stride"]
+ horizontal_stride = test_param["horizontal_stride"]
+ conv_input_scale = test_param["conv_input_scale"]
+ side_input_scale = test_param["side_input_scale"]
+ bias_scale = test_param["bias_scale"]
+ padding_type = test_param["padding_type"]
+
+ conv_input, _, _ = gen_array_ops.quantize_v2(
+ random_ops.random_uniform(
+ [batch_size, input_channels // 4, input_height, input_width, 4],
+ minval=-0.0,
+ maxval=1.0,
+ dtype=dtypes.float32), -1.0, 1.0, dtypes.qint8)
+
+ kernel, _, _ = gen_array_ops.quantize_v2(
+ random_ops.random_uniform(
+ [
+ output_channels, input_channels // 4, filter_height,
+ filter_width, 4
+ ],
+ minval=-1.0,
+ maxval=1.0,
+ dtype=dtypes.float32), -1.0, 1.0, dtypes.qint8)
+
+ output_height = CalculateCovolvedOutputDim(input_height, filter_height,
+ vertical_stride, padding_type)
+ output_width = CalculateCovolvedOutputDim(input_width, filter_width,
+ horizontal_stride, padding_type)
+ print("output_height=", output_height, ", output_width=", output_width)
+
+ side_input, _, _ = gen_array_ops.quantize_v2(
+ random_ops.random_uniform(
+ [batch_size, output_channels // 4, output_height, output_width, 4],
+ minval=0.0,
+ maxval=1.0,
+ dtype=dtypes.float32), -1.0, 1.0, dtypes.qint8)
+
+ biases = random_ops.random_uniform(
+ [output_channels],
+ minval=-10 * bias_scale,
+ maxval=20 * bias_scale,
+ dtype=dtypes.float32)
+
+ strides = [1, 1, vertical_stride, horizontal_stride]
+
+ actual = fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
+ conv_input,
+ kernel,
+ biases,
+ strides=strides,
+ padding=padding_type,
+ conv_input_scale=conv_input_scale,
+ side_input_scale=side_input_scale,
+ side_input=side_input,
+ data_format="NCHW_VECT_C",
+ filter_format="OIHW_VECT_I")
+
+ expected = SimulateFusedConv2dBiasActivationInt8(
+ conv_input_scale, conv_input, kernel, padding_type, strides,
+ side_input_scale, side_input, biases)
+
+ with self.test_session(use_gpu=True) as sess:
+ actual_y, expected_y = sess.run([actual, expected])
+ print("actual_y = ", actual_y)
+ print("expected_y = ", expected_y)
+ self.assertTrue(np.array_equal(actual_y, expected_y))
+
+ def testFusedConvInt8(self):
+ if not test.is_gpu_available(
+ cuda_only=True, min_cuda_compute_capability=(6, 1)):
+ tf_logging.info("int8 test skipped because not run with --config=cuda or "
+ "no GPUs with compute capability >= 6.1 are available.")
+ return
+ for test_param in self._test_params:
+ self.runTest(test_param)
+
+
if __name__ == "__main__":
for index, (input_size_, filter_size_, output_size_, stride_,
padding_) in enumerate(GetShrunkInceptionShapes()):