/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#ifdef GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA

#include "tensorflow/core/kernels/fake_quant_ops_functor.h"

#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/protobuf.h"

using tensorflow::BinaryElementWiseOp;
using tensorflow::DEVICE_CPU;
#if GOOGLE_CUDA
using tensorflow::DEVICE_GPU;
#endif
using tensorflow::OpKernel;
using tensorflow::OpKernelConstruction;
using tensorflow::OpKernelContext;
using tensorflow::Tensor;
using tensorflow::TensorShape;
using tensorflow::TTypes;  // NOLINT This is needed in CUDA mode, do not remove.
using tensorflow::UnaryElementWiseOp;
using tensorflow::errors::InvalidArgument;

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

namespace {
bool IsNumBitsValid(int num_bits) { return num_bits >= 2 && num_bits <= 16; }
}  // namespace

// -----------------------------------------------------------------------------
// Implementation of FakeQuantWithMinMaxArgsOp, see its documentation in
// core/ops/array_ops.cc.
template <typename Device>
class FakeQuantWithMinMaxArgsOp
    : public UnaryElementWiseOp<float, FakeQuantWithMinMaxArgsOp<Device>> {
 public:
  typedef UnaryElementWiseOp<float, FakeQuantWithMinMaxArgsOp<Device>> Base;
  explicit FakeQuantWithMinMaxArgsOp(OpKernelConstruction* context)
      : Base::UnaryElementWiseOp(context) {
    OP_REQUIRES_OK(context, context->GetAttr("min", &min_));
    OP_REQUIRES_OK(context, context->GetAttr("max", &max_));
    OP_REQUIRES(context, min_ < max_,
                InvalidArgument("min has to be smaller than max, was: ", min_,
                                " >= ", max_));
    int num_bits;
    OP_REQUIRES_OK(context, context->GetAttr("num_bits", &num_bits));
    OP_REQUIRES(
        context, IsNumBitsValid(num_bits),
        InvalidArgument("num_bits must be between 2 and 16, inclusive"));
    bool narrow_range;
    OP_REQUIRES_OK(context, context->GetAttr("narrow_range", &narrow_range));
    quant_min_ = narrow_range ? 1 : 0;
    quant_max_ = (1 << num_bits) - 1;
  }

  void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) {
    FakeQuantWithMinMaxArgsFunctor<Device> functor;
    functor(context->eigen_device<Device>(), input.flat<float>(), min_, max_,
            quant_min_, quant_max_, output->flat<float>());
  }

 private:
  float min_;
  float max_;
  int quant_min_;
  int quant_max_;
};

// Implementation of FakeQuantWithMinMaxArgsGradientOp, see its documentation
// in core/ops/array_ops.cc.
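//
// Note: the gradient functor implements the usual straight-through estimator:
// the incoming gradient is passed through unchanged wherever the corresponding
// input falls inside the [min, max] range, and is zeroed outside of it.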
template <typename Device>
class FakeQuantWithMinMaxArgsGradientOp
    : public BinaryElementWiseOp<float,
                                 FakeQuantWithMinMaxArgsGradientOp<Device>> {
 public:
  typedef BinaryElementWiseOp<float, FakeQuantWithMinMaxArgsGradientOp<Device>>
      Base;
  explicit FakeQuantWithMinMaxArgsGradientOp(OpKernelConstruction* context)
      : Base::BinaryElementWiseOp(context) {
    OP_REQUIRES_OK(context, context->GetAttr("min", &min_));
    OP_REQUIRES_OK(context, context->GetAttr("max", &max_));
    OP_REQUIRES(context, min_ < max_,
                InvalidArgument("min has to be smaller than max, was: ", min_,
                                " >= ", max_));
    int num_bits;
    OP_REQUIRES_OK(context, context->GetAttr("num_bits", &num_bits));
    OP_REQUIRES(
        context, IsNumBitsValid(num_bits),
        InvalidArgument("num_bits must be between 2 and 16, inclusive"));
    bool narrow_range;
    OP_REQUIRES_OK(context, context->GetAttr("narrow_range", &narrow_range));
    quant_min_ = narrow_range ? 1 : 0;
    quant_max_ = (1 << num_bits) - 1;
  }

  template <int NDIMS>
  void Operate(OpKernelContext* context, const Tensor& gradient,
               const Tensor& input, Tensor* output) {
    OperateNoTemplate(context, gradient, input, output);
  }

  void OperateNoTemplate(OpKernelContext* context, const Tensor& gradient,
                         const Tensor& input, Tensor* output) {
    OP_REQUIRES(context, input.IsSameSize(gradient),
                InvalidArgument("gradient and input must be the same size"));
    FakeQuantWithMinMaxArgsGradientFunctor<Device> functor;
    functor(context->eigen_device<Device>(), gradient.flat<float>(),
            input.flat<float>(), min_, max_, quant_min_, quant_max_,
            output->flat<float>());
  }

 private:
  float min_;
  float max_;
  int quant_min_;
  int quant_max_;
};

REGISTER_KERNEL_BUILDER(Name("FakeQuantWithMinMaxArgs").Device(DEVICE_CPU),
                        FakeQuantWithMinMaxArgsOp<CPUDevice>);
REGISTER_KERNEL_BUILDER(
    Name("FakeQuantWithMinMaxArgsGradient").Device(DEVICE_CPU),
    FakeQuantWithMinMaxArgsGradientOp<CPUDevice>);

#if GOOGLE_CUDA
typedef Eigen::GpuDevice GPUDevice;

// Forward declarations for functor specializations for GPU.
template <>
void FakeQuantWithMinMaxArgsFunctor<GPUDevice>::operator()(
    const GPUDevice& d, typename TTypes<float>::ConstFlat inputs,
    const float min, const float max, const int quant_min, const int quant_max,
    typename TTypes<float>::Flat outputs);
extern template struct FakeQuantWithMinMaxArgsFunctor<GPUDevice>;
REGISTER_KERNEL_BUILDER(Name("FakeQuantWithMinMaxArgs").Device(DEVICE_GPU),
                        FakeQuantWithMinMaxArgsOp<GPUDevice>);

template <>
void FakeQuantWithMinMaxArgsGradientFunctor<GPUDevice>::operator()(
    const GPUDevice& d, typename TTypes<float>::ConstFlat gradients,
    typename TTypes<float>::ConstFlat inputs, const float min, const float max,
    const int quant_min, const int quant_max,
    typename TTypes<float>::Flat backprops);
REGISTER_KERNEL_BUILDER(
    Name("FakeQuantWithMinMaxArgsGradient").Device(DEVICE_GPU),
    FakeQuantWithMinMaxArgsGradientOp<GPUDevice>);
#endif  // GOOGLE_CUDA

// -----------------------------------------------------------------------------
// Implementation of FakeQuantWithMinMaxVarsOp, see its documentation in
// core/ops/array_ops.cc.
template <typename Device>
class FakeQuantWithMinMaxVarsOp : public OpKernel {
 public:
  explicit FakeQuantWithMinMaxVarsOp(OpKernelConstruction* context)
      : OpKernel::OpKernel(context) {
    int num_bits;
    OP_REQUIRES_OK(context, context->GetAttr("num_bits", &num_bits));
    OP_REQUIRES(
        context, IsNumBitsValid(num_bits),
        InvalidArgument("num_bits must be between 2 and 16, inclusive"));
    bool narrow_range;
    OP_REQUIRES_OK(context, context->GetAttr("narrow_range", &narrow_range));
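    // With narrow_range the quantized domain is [1, 2^num_bits - 1] instead of
    // [0, 2^num_bits - 1], e.g. [1, 255] rather than [0, 255] for num_bits = 8.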
    quant_min_ = narrow_range ? 1 : 0;
    quant_max_ = (1 << num_bits) - 1;
  }

  void Compute(OpKernelContext* context) override {
    CHECK_EQ(3, context->num_inputs());
    const Tensor& input = context->input(0);
    const Tensor& min = context->input(1);
    const Tensor& max = context->input(2);

    Tensor* output;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input.shape(), &output));

    FakeQuantWithMinMaxVarsFunctor<Device> functor;
    functor(context->eigen_device<Device>(), input.flat<float>(),
            min.scalar<float>(), max.scalar<float>(), quant_min_, quant_max_,
            output->flat<float>());
  }

 private:
  int quant_min_;
  int quant_max_;
};

// Implementation of FakeQuantWithMinMaxVarsGradientOp, see its documentation
// in core/ops/array_ops.cc.
template <typename Device>
class FakeQuantWithMinMaxVarsGradientOp : public OpKernel {
 public:
  explicit FakeQuantWithMinMaxVarsGradientOp(OpKernelConstruction* context)
      : OpKernel::OpKernel(context) {
    int num_bits;
    OP_REQUIRES_OK(context, context->GetAttr("num_bits", &num_bits));
    OP_REQUIRES(
        context, IsNumBitsValid(num_bits),
        InvalidArgument("num_bits must be between 2 and 16, inclusive"));
    bool narrow_range;
    OP_REQUIRES_OK(context, context->GetAttr("narrow_range", &narrow_range));
    quant_min_ = narrow_range ? 1 : 0;
    quant_max_ = (1 << num_bits) - 1;
  }

  void Compute(OpKernelContext* context) override {
    CHECK_EQ(4, context->num_inputs());
    const Tensor& gradient = context->input(0);
    const Tensor& input = context->input(1);
    OP_REQUIRES(context, input.IsSameSize(gradient),
                InvalidArgument("gradient and input must be the same size"));
    const Tensor& min = context->input(2);
    const Tensor& max = context->input(3);

    Tensor* grad_wrt_input;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input.shape(), &grad_wrt_input));

    TensorShape scalar_shape;
    Tensor* grad_wrt_min;
    OP_REQUIRES_OK(context,
                   context->allocate_output(1, scalar_shape, &grad_wrt_min));

    Tensor* grad_wrt_max;
    OP_REQUIRES_OK(context,
                   context->allocate_output(2, scalar_shape, &grad_wrt_max));

    FakeQuantWithMinMaxVarsGradientFunctor<Device> functor;
    functor(context->eigen_device<Device>(), gradient.flat<float>(),
            input.flat<float>(), min.scalar<float>(), max.scalar<float>(),
            quant_min_, quant_max_, grad_wrt_input->flat<float>(),
            grad_wrt_min->scalar<float>(), grad_wrt_max->scalar<float>());
  }

 private:
  int quant_min_;
  int quant_max_;
};

REGISTER_KERNEL_BUILDER(Name("FakeQuantWithMinMaxVars").Device(DEVICE_CPU),
                        FakeQuantWithMinMaxVarsOp<CPUDevice>);
REGISTER_KERNEL_BUILDER(
    Name("FakeQuantWithMinMaxVarsGradient").Device(DEVICE_CPU),
    FakeQuantWithMinMaxVarsGradientOp<CPUDevice>);

#if GOOGLE_CUDA
template <>
void FakeQuantWithMinMaxVarsFunctor<GPUDevice>::operator()(
    const GPUDevice& d, typename TTypes<float>::ConstFlat inputs,
    typename TTypes<float>::ConstScalar min,
    typename TTypes<float>::ConstScalar max, const int quant_min,
    const int quant_max, typename TTypes<float>::Flat output);
extern template struct FakeQuantWithMinMaxVarsFunctor<GPUDevice>;
REGISTER_KERNEL_BUILDER(Name("FakeQuantWithMinMaxVars")
                            .Device(DEVICE_GPU)
                            .HostMemory("min")
                            .HostMemory("max"),
                        FakeQuantWithMinMaxVarsOp<GPUDevice>);

template <>
void FakeQuantWithMinMaxVarsGradientFunctor<GPUDevice>::operator()(
    const GPUDevice& d, typename TTypes<float>::ConstFlat gradients,
    typename TTypes<float>::ConstFlat inputs,
    typename TTypes<float>::ConstScalar min,
    typename TTypes<float>::ConstScalar max, const int quant_min,
    const int quant_max, typename TTypes<float>::Flat backprops_wrt_input,
    typename TTypes<float>::Scalar backprop_wrt_min,
    typename TTypes<float>::Scalar backprop_wrt_max);
extern template struct FakeQuantWithMinMaxVarsGradientFunctor<GPUDevice>;
REGISTER_KERNEL_BUILDER(Name("FakeQuantWithMinMaxVarsGradient")
                            .Device(DEVICE_GPU)
                            .HostMemory("min")
                            .HostMemory("max"),
                        FakeQuantWithMinMaxVarsGradientOp<GPUDevice>);
#endif  // GOOGLE_CUDA
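// Note: the GPU kernels above pin "min" and "max" to HostMemory; the functors
// read these scalar values on the host when computing the quantization range,
// presumably to avoid a device-to-host copy on every invocation.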
// -----------------------------------------------------------------------------
// Implementation of FakeQuantWithMinMaxVarsPerChannelOp, see its documentation
// in core/ops/array_ops.cc.
template <typename Device>
class FakeQuantWithMinMaxVarsPerChannelOp : public OpKernel {
 public:
  explicit FakeQuantWithMinMaxVarsPerChannelOp(OpKernelConstruction* context)
      : OpKernel::OpKernel(context) {
    int num_bits;
    OP_REQUIRES_OK(context, context->GetAttr("num_bits", &num_bits));
    OP_REQUIRES(
        context, IsNumBitsValid(num_bits),
        InvalidArgument("num_bits must be between 2 and 16, inclusive"));
    bool narrow_range;
    OP_REQUIRES_OK(context, context->GetAttr("narrow_range", &narrow_range));
    quant_min_ = narrow_range ? 1 : 0;
    quant_max_ = (1 << num_bits) - 1;
  }

  void Compute(OpKernelContext* context) override {
    CHECK_EQ(3, context->num_inputs());
    const Tensor& input = context->input(0);
    const int depth = input.dim_size(input.dims() - 1);  // last dimension size.
    const Tensor& min = context->input(1);
    OP_REQUIRES(context, min.dim_size(0) == depth,
                InvalidArgument("min has incorrect size, expected ", depth,
                                " was ", min.dim_size(0)));
    const Tensor& max = context->input(2);
    OP_REQUIRES(context, max.dim_size(0) == depth,
                InvalidArgument("max has incorrect size, expected ", depth,
                                " was ", max.dim_size(0)));

    Tensor* output;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input.shape(), &output));

    FakeQuantWithMinMaxVarsPerChannelFunctor<Device> functor;
    functor(context->eigen_device<Device>(), input.flat_inner_dims<float, 2>(),
            min.vec<float>(), max.vec<float>(), quant_min_, quant_max_,
            output->flat_inner_dims<float, 2>());
  }

 private:
  int quant_min_;
  int quant_max_;
};

// Implementation of FakeQuantWithMinMaxVarsPerChannelGradientOp, see its
// documentation in core/ops/array_ops.cc.
template <typename Device>
class FakeQuantWithMinMaxVarsPerChannelGradientOp : public OpKernel {
 public:
  explicit FakeQuantWithMinMaxVarsPerChannelGradientOp(
      OpKernelConstruction* context)
      : OpKernel::OpKernel(context) {
    int num_bits;
    OP_REQUIRES_OK(context, context->GetAttr("num_bits", &num_bits));
    OP_REQUIRES(
        context, IsNumBitsValid(num_bits),
        InvalidArgument("num_bits must be between 2 and 16, inclusive"));
    bool narrow_range;
    OP_REQUIRES_OK(context, context->GetAttr("narrow_range", &narrow_range));
    quant_min_ = narrow_range ? 1 : 0;
    quant_max_ = (1 << num_bits) - 1;
  }

  void Compute(OpKernelContext* context) override {
    CHECK_EQ(4, context->num_inputs());
    const Tensor& gradient = context->input(0);
    const Tensor& input = context->input(1);
    OP_REQUIRES(context, input.IsSameSize(gradient),
                InvalidArgument("gradient and input must be the same size"));
    const int depth = input.dim_size(input.dims() - 1);  // last dimension size.
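    // min and max are expected to be vectors of length `depth`, i.e. one
    // clamping range per channel along the innermost dimension.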
    const Tensor& min = context->input(2);
    OP_REQUIRES(context, min.dim_size(0) == depth,
                InvalidArgument("min has incorrect size, expected ", depth,
                                " was ", min.dim_size(0)));
    const Tensor& max = context->input(3);
    OP_REQUIRES(context, max.dim_size(0) == depth,
                InvalidArgument("max has incorrect size, expected ", depth,
                                " was ", max.dim_size(0)));

    Tensor* grad_wrt_input;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input.shape(), &grad_wrt_input));

    TensorShape min_max_shape({input.dim_size(input.dims() - 1)});
    Tensor* grad_wrt_min;
    OP_REQUIRES_OK(context,
                   context->allocate_output(1, min_max_shape, &grad_wrt_min));

    Tensor* grad_wrt_max;
    OP_REQUIRES_OK(context,
                   context->allocate_output(2, min_max_shape, &grad_wrt_max));

    FakeQuantWithMinMaxVarsPerChannelGradientFunctor<Device> functor;
    functor(
        context->eigen_device<Device>(), gradient.flat_inner_dims<float, 2>(),
        input.flat_inner_dims<float, 2>(), min.vec<float>(), max.vec<float>(),
        quant_min_, quant_max_, grad_wrt_input->flat_inner_dims<float, 2>(),
        grad_wrt_min->vec<float>(), grad_wrt_max->vec<float>());
  }

 private:
  int quant_min_;
  int quant_max_;
};

REGISTER_KERNEL_BUILDER(
    Name("FakeQuantWithMinMaxVarsPerChannel").Device(DEVICE_CPU),
    FakeQuantWithMinMaxVarsPerChannelOp<CPUDevice>);
REGISTER_KERNEL_BUILDER(
    Name("FakeQuantWithMinMaxVarsPerChannelGradient").Device(DEVICE_CPU),
    FakeQuantWithMinMaxVarsPerChannelGradientOp<CPUDevice>);

#if GOOGLE_CUDA
template <>
void FakeQuantWithMinMaxVarsPerChannelFunctor<GPUDevice>::operator()(
    const GPUDevice& d, typename TTypes<float>::ConstMatrix inputs,
    typename TTypes<float>::ConstFlat min,
    typename TTypes<float>::ConstFlat max, const int quant_min,
    const int quant_max, typename TTypes<float>::Matrix outputs);
extern template struct FakeQuantWithMinMaxVarsPerChannelFunctor<GPUDevice>;
REGISTER_KERNEL_BUILDER(Name("FakeQuantWithMinMaxVarsPerChannel")
                            .Device(DEVICE_GPU)
                            .HostMemory("min")
                            .HostMemory("max"),
                        FakeQuantWithMinMaxVarsPerChannelOp<GPUDevice>);

template <>
void FakeQuantWithMinMaxVarsPerChannelGradientFunctor<GPUDevice>::operator()(
    const GPUDevice& d, typename TTypes<float>::ConstMatrix gradients,
    typename TTypes<float>::ConstMatrix inputs,
    typename TTypes<float>::ConstVec min, typename TTypes<float>::ConstVec max,
    const int quant_min, const int quant_max,
    typename TTypes<float>::Matrix backprops_wrt_input,
    typename TTypes<float>::Vec backprop_wrt_min,
    typename TTypes<float>::Vec backprop_wrt_max);
extern template struct FakeQuantWithMinMaxVarsPerChannelGradientFunctor<
    GPUDevice>;
REGISTER_KERNEL_BUILDER(Name("FakeQuantWithMinMaxVarsPerChannelGradient")
                            .Device(DEVICE_GPU)
                            .HostMemory("min")
                            .HostMemory("max"),
                        FakeQuantWithMinMaxVarsPerChannelGradientOp<GPUDevice>);
#endif  // GOOGLE_CUDA

}  // namespace tensorflow