diff options
Diffstat (limited to 'tensorflow/core/kernels/bias_op.cc')
-rw-r--r-- | tensorflow/core/kernels/bias_op.cc | 55 |
1 file changed, 40 insertions, 15 deletions
diff --git a/tensorflow/core/kernels/bias_op.cc b/tensorflow/core/kernels/bias_op.cc index 10f5d4ce85..b3a77d1caa 100644 --- a/tensorflow/core/kernels/bias_op.cc +++ b/tensorflow/core/kernels/bias_op.cc @@ -35,14 +35,13 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; +#ifdef TENSORFLOW_USE_SYCL +typedef Eigen::SyclDevice SYCLDevice; +#endif // TENSORFLOW_USE_SYCL template <typename Device, typename T> -class BiasOp; - -template <typename T> -class BiasOp<CPUDevice, T> : public BinaryOp<T> { +class BiasOp : public BinaryOp<T> { public: - typedef CPUDevice Device; explicit BiasOp(OpKernelConstruction* context) : BinaryOp<T>(context) { string data_format; if (context->GetAttr("data_format", &data_format).ok()) { @@ -52,7 +51,8 @@ class BiasOp<CPUDevice, T> : public BinaryOp<T> { data_format_ = FORMAT_NHWC; } OP_REQUIRES(context, data_format_ == FORMAT_NHWC, - errors::InvalidArgument("CPU BiasOp only supports NHWC.")); + errors::InvalidArgument(context->device()->attributes().name() + + " BiasOp only supports NHWC.")); } void Compute(OpKernelContext* context) override { @@ -122,6 +122,21 @@ class BiasOp<CPUDevice, T> : public BinaryOp<T> { TF_CALL_NUMBER_TYPES(REGISTER_KERNEL); #undef REGISTER_KERNEL +#ifdef TENSORFLOW_USE_SYCL +#define REGISTER_KERNEL(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("BiasAdd").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \ + BiasOp<SYCLDevice, type>); \ + REGISTER_KERNEL_BUILDER( \ + Name("BiasAddV1").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \ + BiasOp<SYCLDevice, type>); + +TF_CALL_INTEGRAL_TYPES(REGISTER_KERNEL); +REGISTER_KERNEL(float); +REGISTER_KERNEL(double); +#undef REGISTER_KERNEL +#endif // TENSORFLOW_USE_SYCL + namespace { void GetBiasValueDims(const Tensor& value_tensor, TensorFormat data_format, @@ -165,12 +180,8 @@ struct AccumulatorType<Eigen::half> { } // namespace template <typename Device, typename T> -class BiasGradOp; - -template <typename T> -class 
BiasGradOp<CPUDevice, T> : public OpKernel { +class BiasGradOp : public OpKernel { public: - typedef CPUDevice Device; explicit BiasGradOp(OpKernelConstruction* context) : OpKernel(context) { string data_format; if (context->GetAttr("data_format", &data_format).ok()) { @@ -180,7 +191,8 @@ class BiasGradOp<CPUDevice, T> : public OpKernel { data_format_ = FORMAT_NHWC; } OP_REQUIRES(context, data_format_ == FORMAT_NHWC, - errors::InvalidArgument("CPU BiasGradOp only supports NHWC.")); + errors::InvalidArgument(context->device()->attributes().name() + + " BiasGradOp only supports NHWC.")); } void Compute(OpKernelContext* context) override { @@ -192,8 +204,9 @@ class BiasGradOp<CPUDevice, T> : public OpKernel { output_backprop.shape().DebugString())); OP_REQUIRES( - context, FastBoundsCheck(output_backprop.NumElements(), - std::numeric_limits<int32>::max()), + context, + FastBoundsCheck(output_backprop.NumElements(), + std::numeric_limits<int32>::max()), errors::InvalidArgument("BiasGrad requires tensor size <= int32 max")); int32 batch, height, width, channel; @@ -215,7 +228,7 @@ class BiasGradOp<CPUDevice, T> : public OpKernel { #else Eigen::array<int, 1> reduction_axis = {0}; #endif - output->template flat<T>().device(context->eigen_device<CPUDevice>()) = + output->template flat<T>().device(context->eigen_device<Device>()) = output_backprop.flat<T>() .template cast<typename AccumulatorType<T>::type>() .reshape(two_dims) @@ -237,6 +250,18 @@ class BiasGradOp<CPUDevice, T> : public OpKernel { TF_CALL_NUMBER_TYPES(REGISTER_KERNEL); #undef REGISTER_KERNEL +#ifdef TENSORFLOW_USE_SYCL +#define REGISTER_KERNEL(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("BiasAddGrad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \ + BiasGradOp<SYCLDevice, type>); + +TF_CALL_INTEGRAL_TYPES(REGISTER_KERNEL); +REGISTER_KERNEL(float); +REGISTER_KERNEL(double); +#undef REGISTER_KERNEL +#endif // TENSORFLOW_USE_SYCL + #if GOOGLE_CUDA template <typename T> class BiasOp<GPUDevice, T> : public 
BinaryOp<T> { |