aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/bias_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/bias_op.cc')
-rw-r--r--tensorflow/core/kernels/bias_op.cc55
1 files changed, 40 insertions, 15 deletions
diff --git a/tensorflow/core/kernels/bias_op.cc b/tensorflow/core/kernels/bias_op.cc
index 10f5d4ce85..b3a77d1caa 100644
--- a/tensorflow/core/kernels/bias_op.cc
+++ b/tensorflow/core/kernels/bias_op.cc
@@ -35,14 +35,13 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
+#ifdef TENSORFLOW_USE_SYCL
+typedef Eigen::SyclDevice SYCLDevice;
+#endif // TENSORFLOW_USE_SYCL
template <typename Device, typename T>
-class BiasOp;
-
-template <typename T>
-class BiasOp<CPUDevice, T> : public BinaryOp<T> {
+class BiasOp : public BinaryOp<T> {
public:
- typedef CPUDevice Device;
explicit BiasOp(OpKernelConstruction* context) : BinaryOp<T>(context) {
string data_format;
if (context->GetAttr("data_format", &data_format).ok()) {
@@ -52,7 +51,8 @@ class BiasOp<CPUDevice, T> : public BinaryOp<T> {
data_format_ = FORMAT_NHWC;
}
OP_REQUIRES(context, data_format_ == FORMAT_NHWC,
- errors::InvalidArgument("CPU BiasOp only supports NHWC."));
+ errors::InvalidArgument(context->device()->attributes().name() +
+ " BiasOp only supports NHWC."));
}
void Compute(OpKernelContext* context) override {
@@ -122,6 +122,21 @@ class BiasOp<CPUDevice, T> : public BinaryOp<T> {
TF_CALL_NUMBER_TYPES(REGISTER_KERNEL);
#undef REGISTER_KERNEL
+#ifdef TENSORFLOW_USE_SYCL
+#define REGISTER_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("BiasAdd").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
+ BiasOp<SYCLDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("BiasAddV1").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
+ BiasOp<SYCLDevice, type>);
+
+TF_CALL_INTEGRAL_TYPES(REGISTER_KERNEL);
+REGISTER_KERNEL(float);
+REGISTER_KERNEL(double);
+#undef REGISTER_KERNEL
+#endif // TENSORFLOW_USE_SYCL
+
namespace {
void GetBiasValueDims(const Tensor& value_tensor, TensorFormat data_format,
@@ -165,12 +180,8 @@ struct AccumulatorType<Eigen::half> {
} // namespace
template <typename Device, typename T>
-class BiasGradOp;
-
-template <typename T>
-class BiasGradOp<CPUDevice, T> : public OpKernel {
+class BiasGradOp : public OpKernel {
public:
- typedef CPUDevice Device;
explicit BiasGradOp(OpKernelConstruction* context) : OpKernel(context) {
string data_format;
if (context->GetAttr("data_format", &data_format).ok()) {
@@ -180,7 +191,8 @@ class BiasGradOp<CPUDevice, T> : public OpKernel {
data_format_ = FORMAT_NHWC;
}
OP_REQUIRES(context, data_format_ == FORMAT_NHWC,
- errors::InvalidArgument("CPU BiasGradOp only supports NHWC."));
+ errors::InvalidArgument(context->device()->attributes().name() +
+ " BiasGradOp only supports NHWC."));
}
void Compute(OpKernelContext* context) override {
@@ -192,8 +204,9 @@ class BiasGradOp<CPUDevice, T> : public OpKernel {
output_backprop.shape().DebugString()));
OP_REQUIRES(
- context, FastBoundsCheck(output_backprop.NumElements(),
- std::numeric_limits<int32>::max()),
+ context,
+ FastBoundsCheck(output_backprop.NumElements(),
+ std::numeric_limits<int32>::max()),
errors::InvalidArgument("BiasGrad requires tensor size <= int32 max"));
int32 batch, height, width, channel;
@@ -215,7 +228,7 @@ class BiasGradOp<CPUDevice, T> : public OpKernel {
#else
Eigen::array<int, 1> reduction_axis = {0};
#endif
- output->template flat<T>().device(context->eigen_device<CPUDevice>()) =
+ output->template flat<T>().device(context->eigen_device<Device>()) =
output_backprop.flat<T>()
.template cast<typename AccumulatorType<T>::type>()
.reshape(two_dims)
@@ -237,6 +250,18 @@ class BiasGradOp<CPUDevice, T> : public OpKernel {
TF_CALL_NUMBER_TYPES(REGISTER_KERNEL);
#undef REGISTER_KERNEL
+#ifdef TENSORFLOW_USE_SYCL
+#define REGISTER_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("BiasAddGrad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
+ BiasGradOp<SYCLDevice, type>);
+
+TF_CALL_INTEGRAL_TYPES(REGISTER_KERNEL);
+REGISTER_KERNEL(float);
+REGISTER_KERNEL(double);
+#undef REGISTER_KERNEL
+#endif // TENSORFLOW_USE_SYCL
+
#if GOOGLE_CUDA
template <typename T>
class BiasOp<GPUDevice, T> : public BinaryOp<T> {