diff options
Diffstat (limited to 'tensorflow/core/kernels')
-rw-r--r-- | tensorflow/core/kernels/BUILD | 28 | ||||
-rw-r--r-- | tensorflow/core/kernels/dense_update_functor.h | 29 | ||||
-rw-r--r-- | tensorflow/core/kernels/group_by_window_dataset_op.cc | 2 | ||||
-rw-r--r-- | tensorflow/core/kernels/in_topk_op.cc | 52 | ||||
-rw-r--r-- | tensorflow/core/kernels/quantization_utils.h | 4 | ||||
-rw-r--r-- | tensorflow/core/kernels/relu_op.cc | 62 | ||||
-rw-r--r-- | tensorflow/core/kernels/relu_op.h | 42 | ||||
-rw-r--r-- | tensorflow/core/kernels/relu_op_functor.h | 40 | ||||
-rw-r--r-- | tensorflow/core/kernels/relu_op_gpu.cu.cc | 4 |
9 files changed, 242 insertions, 21 deletions
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index a493452777..f45bb72c38 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -4307,9 +4307,11 @@ filegroup( "cwise_op_invert.cc", "cwise_op_isfinite.cc", "cwise_op_less.cc", + "cwise_op_less_equal.cc", "cwise_op_log.cc", "cwise_op_logical_and.cc", "cwise_op_logical_not.cc", + "cwise_op_logical_or.cc", "cwise_op_maximum.cc", "cwise_op_minimum.cc", "cwise_op_mul_1.cc", @@ -4534,6 +4536,32 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "android_tensorflow_image_op", + srcs = [ + "decode_image_op.cc", + ], + copts = tf_copts(), + linkopts = select({ + "//tensorflow:android": [ + "-ldl", + ], + "//conditions:default": [], + }), + tags = [ + "manual", + "notap", + ], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core:android_gif_internal", + "//tensorflow/core:android_jpeg_internal", + "//tensorflow/core:android_png_internal", + "//tensorflow/core:android_tensorflow_lib_lite", + ], + alwayslink = 1, +) + # Quantization-specific OpKernels tf_kernel_library( diff --git a/tensorflow/core/kernels/dense_update_functor.h b/tensorflow/core/kernels/dense_update_functor.h index 54b080c83b..4aefe26c54 100644 --- a/tensorflow/core/kernels/dense_update_functor.h +++ b/tensorflow/core/kernels/dense_update_functor.h @@ -24,6 +24,9 @@ limitations under the License. namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; +#ifdef TENSORFLOW_USE_SYCL +typedef Eigen::SyclDevice SYCLDevice; +#endif // TENSORFLOW_USE_SYCL enum DenseUpdateType { ADD, SUB, ASSIGN }; @@ -59,6 +62,32 @@ struct DenseUpdate<CPUDevice, T, ASSIGN> { } }; +#ifdef TENSORFLOW_USE_SYCL +template <typename T> +struct DenseUpdate<SYCLDevice, T, ADD> { + void operator()(const SYCLDevice& d, typename TTypes<T>::Flat params, + typename TTypes<T>::ConstFlat update) { + params.device(d) += update; + } +}; + +template <typename T> +struct DenseUpdate<SYCLDevice, T, SUB> { + void operator()(const SYCLDevice& d, typename TTypes<T>::Flat params, + typename TTypes<T>::ConstFlat update) { + params.device(d) -= update; + } +}; + +template <typename T> +struct DenseUpdate<SYCLDevice, T, ASSIGN> { + void operator()(const SYCLDevice& d, typename TTypes<T>::Flat params, + typename TTypes<T>::ConstFlat update) { + params.device(d) = update; + } +}; +#endif // TENSORFLOW_USE_SYCL + } // end namespace functor } // end namespace tensorflow diff --git a/tensorflow/core/kernels/group_by_window_dataset_op.cc b/tensorflow/core/kernels/group_by_window_dataset_op.cc index 948e83390e..94591a26af 100644 --- a/tensorflow/core/kernels/group_by_window_dataset_op.cc +++ b/tensorflow/core/kernels/group_by_window_dataset_op.cc @@ -42,7 +42,7 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel { void MakeDataset(OpKernelContext* ctx, DatasetBase* input, DatasetBase** output) override { - int64 window_size; + int64 window_size = 0; OP_REQUIRES_OK( ctx, ParseScalarArgument<int64>(ctx, "window_size", &window_size)); OP_REQUIRES( diff --git a/tensorflow/core/kernels/in_topk_op.cc b/tensorflow/core/kernels/in_topk_op.cc index 13890e5b7f..e2861ae090 100644 --- a/tensorflow/core/kernels/in_topk_op.cc +++ b/tensorflow/core/kernels/in_topk_op.cc @@ -17,11 +17,11 @@ limitations under the License. #define EIGEN_USE_THREADS -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/kernels/bounds_check.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" namespace tensorflow { @@ -29,12 +29,29 @@ template <typename T, typename TARGET_T> class InTopK : public OpKernel { public: explicit InTopK(OpKernelConstruction* context) : OpKernel(context) { - OP_REQUIRES_OK(context, context->GetAttr("k", &k_)); + if (context->num_inputs() == 2) { + OP_REQUIRES_OK(context, context->GetAttr("k", &k_)); + } } void Compute(OpKernelContext* context) override { const auto& predictions_in = context->input(0); const auto& targets_in = context->input(1); + int64 k_val = k_; + if (context->num_inputs() == 3) { + const auto& k_in = context->input(2); + + OP_REQUIRES(context, TensorShapeUtils::IsScalar(k_in.shape()), + errors::InvalidArgument("k must be 0-D, got shape ", + k_in.shape().DebugString())); + + if (k_in.dtype() == DT_INT32) { + k_val = k_in.scalar<int32>()(); + } else { + k_val = k_in.scalar<int64>()(); + } + } + OP_REQUIRES(context, predictions_in.dims() == 2, errors::InvalidArgument("predictions must be 2-dimensional")); OP_REQUIRES(context, targets_in.dims() == 1, @@ -73,7 +90,7 @@ class InTopK : public OpKernel { } } } - out(b) = cannot_say ? false : (more_probable_classes < k_); + out(b) = cannot_say ? false : (more_probable_classes < k_val); } } @@ -82,10 +99,35 @@ class InTopK : public OpKernel { }; REGISTER_KERNEL_BUILDER( - Name("InTopK").Device(DEVICE_CPU).TypeConstraint<int32>("T"), + Name("InTopK").Device(DEVICE_CPU) + .HostMemory("predictions") + .HostMemory("targets") + .HostMemory("precision") + .TypeConstraint<int32>("T"), + InTopK<float, int32>); +REGISTER_KERNEL_BUILDER( + Name("InTopK").Device(DEVICE_CPU) + .HostMemory("predictions") + .HostMemory("targets") + .HostMemory("precision") + .TypeConstraint<int64>("T"), + InTopK<float, int64>); + +REGISTER_KERNEL_BUILDER( + Name("InTopKV2").Device(DEVICE_CPU) + .HostMemory("predictions") + .HostMemory("targets") + .HostMemory("k") + .HostMemory("precision") + .TypeConstraint<int32>("T"), InTopK<float, int32>); REGISTER_KERNEL_BUILDER( - Name("InTopK").Device(DEVICE_CPU).TypeConstraint<int64>("T"), + Name("InTopKV2").Device(DEVICE_CPU) + .HostMemory("predictions") + .HostMemory("targets") + .HostMemory("k") + .HostMemory("precision") + .TypeConstraint<int64>("T"), InTopK<float, int64>); } // namespace tensorflow diff --git a/tensorflow/core/kernels/quantization_utils.h b/tensorflow/core/kernels/quantization_utils.h index cb4fcbd788..c5dc2e7194 100644 --- a/tensorflow/core/kernels/quantization_utils.h +++ b/tensorflow/core/kernels/quantization_utils.h @@ -823,9 +823,9 @@ void QuantizedAddUsingEigen(const Eigen::ThreadPoolDevice& device, const int64 input_element_count = input.NumElements(); const int64 smaller_input_element_count = smaller_input.NumElements(); - QuantizedToFloatStruct<T1> smaller_input_q2f(smaller_input_min, + QuantizedToFloatStruct<T1> input_q2f(input_min, input_max); + QuantizedToFloatStruct<T2> smaller_input_q2f(smaller_input_min, smaller_input_max); - QuantizedToFloatStruct<T2> input_q2f(input_min, input_max); FloatToQuantizedStruct<T3> f2q(*output_min, *output_max); auto smaller_input_float = diff --git a/tensorflow/core/kernels/relu_op.cc b/tensorflow/core/kernels/relu_op.cc index d8d30e87e2..afad288cc0 100644 --- a/tensorflow/core/kernels/relu_op.cc +++ b/tensorflow/core/kernels/relu_op.cc @@ -50,15 +50,21 @@ typedef Eigen::SyclDevice SYCLDevice; TF_CALL_REAL_NUMBER_TYPES(REGISTER_RELU_KERNELS); #undef REGISTER_RELU_KERNELS -#define REGISTER_ELU_KERNELS(type) \ - REGISTER_KERNEL_BUILDER( \ - Name("Elu").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ - EluOp<CPUDevice, type>); \ - REGISTER_KERNEL_BUILDER( \ - Name("EluGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ - EluGradOp<CPUDevice, type>) - -// Elu only makes sense with float or double. +#define REGISTER_ELU_KERNELS(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Elu").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ + EluOp<CPUDevice, type>); \ + REGISTER_KERNEL_BUILDER( \ + Name("EluGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ + EluGradOp<CPUDevice, type>); \ + REGISTER_KERNEL_BUILDER( \ + Name("Selu").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ + SeluOp<CPUDevice, type>); \ + REGISTER_KERNEL_BUILDER( \ + Name("SeluGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ + SeluGradOp<CPUDevice, type>) + +// Elu and Selu only make sense with float or double. TF_CALL_GPU_NUMBER_TYPES(REGISTER_ELU_KERNELS); #undef REGISTER_ELU_KERNELS @@ -103,7 +109,23 @@ namespace functor { const GPUDevice& d, typename TTypes<T>::ConstTensor gradients, \ typename TTypes<T>::ConstTensor activations, \ typename TTypes<T>::Tensor backprops); \ - extern template struct EluGrad<GPUDevice, T>; + extern template struct EluGrad<GPUDevice, T>; \ + \ + template <> \ + void Selu<GPUDevice, T>::operator()( \ + const GPUDevice& d, \ + typename TTypes<T>::ConstTensor features, \ + typename TTypes<T>::Tensor activations); \ + extern template struct Selu<GPUDevice, T>; \ + \ + template <> \ + void SeluGrad<GPUDevice, T>::operator()( \ + const GPUDevice& d, typename TTypes<T>::ConstTensor gradients, \ + typename TTypes<T>::ConstTensor activations, \ + typename TTypes<T>::Tensor backprops); \ + extern template struct SeluGrad<GPUDevice, T>; + + TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC); } // namespace functor @@ -127,7 +149,15 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC); EluOp<GPUDevice, type>); \ REGISTER_KERNEL_BUILDER( \ Name("EluGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ - EluGradOp<GPUDevice, type>) + EluGradOp<GPUDevice, type>); \ + REGISTER_KERNEL_BUILDER( \ + Name("Selu").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ + SeluOp<GPUDevice, type>); \ + REGISTER_KERNEL_BUILDER( \ + Name("SeluGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ + SeluGradOp<GPUDevice, type>) + + TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS); #undef REGISTER_GPU_KERNELS @@ -154,7 +184,15 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS); EluOp<SYCLDevice, type>); \ REGISTER_KERNEL_BUILDER( \ Name("EluGrad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \ - EluGradOp<SYCLDevice, type>) + EluGradOp<SYCLDevice, type>); \ + REGISTER_KERNEL_BUILDER( \ + Name("Selu").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \ + SeluOp<SYCLDevice, type>); \ + REGISTER_KERNEL_BUILDER( \ + Name("SeluGrad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \ + SeluGradOp<SYCLDevice, type>) + + TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL_KERNELS); #undef REGISTER_SYCL_KERNELS diff --git a/tensorflow/core/kernels/relu_op.h b/tensorflow/core/kernels/relu_op.h index 365c6201a5..e712b02bd7 100644 --- a/tensorflow/core/kernels/relu_op.h +++ b/tensorflow/core/kernels/relu_op.h @@ -173,6 +173,48 @@ void EluGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context, output->flat<T>()); } +template <typename Device, typename T> +class SeluOp : public UnaryElementWiseOp<T, SeluOp<Device, T>> { + public: + using UnaryElementWiseOp<T, SeluOp<Device, T>>::UnaryElementWiseOp; + + void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) { + functor::Selu<Device, T> functor; + functor(context->eigen_device<Device>(), input.flat<T>(), + output->flat<T>()); + } +}; + +template <typename Device, typename T> +class SeluGradOp : public BinaryElementWiseOp<T, SeluGradOp<Device, T>> { + public: + using BinaryElementWiseOp<T, SeluGradOp<Device, T>>::BinaryElementWiseOp; + + void OperateNoTemplate(OpKernelContext* context, const Tensor& g, + const Tensor& a, Tensor* output); + + // INPUTS: + // g (gradients): backpropagated gradients + // a (outputs): outputs of the SeluOp() + // OUTPUT: + // gradients to backprop + template <int NDIMS> + void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a, + Tensor* output) { + OperateNoTemplate(context, g, a, output); + } +}; + +template <typename Device, typename T> +void SeluGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context, + const Tensor& g, const Tensor& a, + Tensor* output) { + if (!ReluHelpers::ValidateSameSize(context, g, a)) return; + functor::SeluGrad<Device, T> functor; + functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(), + output->flat<T>()); +} + } // namespace tensorflow #undef EIGEN_USE_THREADS diff --git a/tensorflow/core/kernels/relu_op_functor.h b/tensorflow/core/kernels/relu_op_functor.h index 633522920c..9577b963c6 100644 --- a/tensorflow/core/kernels/relu_op_functor.h +++ b/tensorflow/core/kernels/relu_op_functor.h @@ -125,6 +125,46 @@ struct EluGrad { } }; +// Functor used by SeluOp to do the computations. +template <typename Device, typename T> +struct Selu { + // Computes Selu activation. + // + // features: any shape. + // activations: same shape as "features". + void operator()(const Device& d, typename TTypes<T>::ConstTensor features, + typename TTypes<T>::Tensor activations) { + // features.constant(?) + const auto scale = static_cast<T>(1.0507009873554804934193349852946); + const auto scale_alpha = static_cast<T>(1.7580993408473768599402175208123); + const auto one = static_cast<T>(1); + const auto zero = static_cast<T>(0); + activations.device(d) = + (features < zero) + .select(scale_alpha * (features.exp() - features.constant(one)), + scale * features); + } +}; + +// Functor used by SeluGradOp to do the computations. +template <typename Device, typename T> +struct SeluGrad { + // Computes SeluGrad backprops. + // + // gradients: gradients backpropagated to the Selu op. + // activations: outputs of the Selu op. + // backprops: gradients to backpropagate to the Selu inputs. + void operator()(const Device& d, typename TTypes<T>::ConstTensor gradients, + typename TTypes<T>::ConstTensor activations, + typename TTypes<T>::Tensor backprops) { + const auto scale = static_cast<T>(1.0507009873554804934193349852946); + const auto scale_alpha = static_cast<T>(1.7580993408473768599402175208123); + backprops.device(d) = + (activations < static_cast<T>(0)).select( + gradients * (activations + scale_alpha), gradients * scale); + } +}; + } // namespace functor } // namespace tensorflow diff --git a/tensorflow/core/kernels/relu_op_gpu.cu.cc b/tensorflow/core/kernels/relu_op_gpu.cu.cc index 30c4a289f7..ec09d8dfea 100644 --- a/tensorflow/core/kernels/relu_op_gpu.cu.cc +++ b/tensorflow/core/kernels/relu_op_gpu.cu.cc @@ -35,7 +35,9 @@ typedef Eigen::GpuDevice GPUDevice; template struct functor::Relu6<GPUDevice, T>; \ template struct functor::Relu6Grad<GPUDevice, T>; \ template struct functor::Elu<GPUDevice, T>; \ - template struct functor::EluGrad<GPUDevice, T>; + template struct functor::EluGrad<GPUDevice, T>; \ + template struct functor::Selu<GPUDevice, T>; \ + template struct functor::SeluGrad<GPUDevice, T>; TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS); |