author     TensorFlower Gardener <gardener@tensorflow.org>  2018-10-08 11:29:04 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-10-08 11:29:04 -0700
commit     96237f7b7ae6b7b8a2cbcf6d64312906b96f060b (patch)
tree       a96bb853e59dc37e90e4f8fde229f4d88b3f225a /tensorflow/core
parent     3f0155133d668cf6cee1f1fb362d2a75c04836e3 (diff)
parent     96eec07af06f4dfc75cee57b74ba4b5347619634 (diff)
Merge pull request #21658 from lowintelligence:master
PiperOrigin-RevId: 216217509
Diffstat (limited to 'tensorflow/core')
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_LeakyRelu.pbtxt       5
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_LeakyReluGrad.pbtxt   24
-rw-r--r--  tensorflow/core/kernels/relu_op.cc                             153
-rw-r--r--  tensorflow/core/kernels/relu_op.h                              61
-rw-r--r--  tensorflow/core/kernels/relu_op_functor.h                      30
-rw-r--r--  tensorflow/core/kernels/relu_op_gpu.cu.cc                      18
-rw-r--r--  tensorflow/core/ops/nn_ops.cc                                  15
-rw-r--r--  tensorflow/core/ops/ops.pbtxt                                  68
8 files changed, 305 insertions, 69 deletions
diff --git a/tensorflow/core/api_def/base_api/api_def_LeakyRelu.pbtxt b/tensorflow/core/api_def/base_api/api_def_LeakyRelu.pbtxt
new file mode 100644
index 0000000000..280148e032
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_LeakyRelu.pbtxt
@@ -0,0 +1,5 @@
+op {
+ graph_op_name: "LeakyRelu"
+ visibility: HIDDEN
+ summary: "Computes rectified linear: `max(features, features * alpha)`."
+}
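
For reference, the summary above is the entire forward computation: each element x maps to max(x, alpha * x). Below is a minimal standalone sketch of that math on plain floats; it is for illustration only and is not part of the commit.

    #include <algorithm>
    #include <cstdio>

    // Elementwise LeakyRelu on a plain float, mirroring max(features, features * alpha).
    float LeakyReluScalar(float x, float alpha) { return std::max(x, x * alpha); }

    int main() {
      const float alpha = 0.2f;  // same value as the op's default "alpha" attr
      const float inputs[] = {-3.0f, -0.5f, 0.0f, 2.0f};
      for (float x : inputs) {
        std::printf("leaky_relu(%5.1f) = %5.2f\n", x, LeakyReluScalar(x, alpha));
      }
      return 0;
    }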
diff --git a/tensorflow/core/api_def/base_api/api_def_LeakyReluGrad.pbtxt b/tensorflow/core/api_def/base_api/api_def_LeakyReluGrad.pbtxt
new file mode 100644
index 0000000000..e427526602
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_LeakyReluGrad.pbtxt
@@ -0,0 +1,24 @@
+op {
+ graph_op_name: "LeakyReluGrad"
+ visibility: HIDDEN
+ in_arg {
+ name: "gradients"
+ description: <<END
+The backpropagated gradients to the corresponding LeakyRelu operation.
+END
+ }
+ in_arg {
+ name: "features"
+ description: <<END
+The features passed as input to the corresponding LeakyRelu operation,
+OR the outputs of that operation (both work equivalently).
+END
+ }
+ out_arg {
+ name: "backprops"
+ description: <<END
+`gradients * (features > 0) + alpha * gradients * (features <= 0)`.
+END
+ }
+ summary: "Computes rectified linear gradients for a LeakyRelu operation."
+}
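
The backprops formula above reduces to a simple rule: pass the incoming gradient through where the feature is positive, and scale it by alpha otherwise. A scalar sketch of that rule (illustrative only, not code from the commit):

    #include <cstdio>

    // Elementwise LeakyReluGrad on plain floats:
    // gradients * (features > 0) + alpha * gradients * (features <= 0).
    float LeakyReluGradScalar(float gradient, float feature, float alpha) {
      return feature > 0.0f ? gradient : gradient * alpha;
    }

    int main() {
      const float alpha = 0.2f;
      std::printf("%.2f\n", LeakyReluGradScalar(1.0f, 3.0f, alpha));   // 1.00
      std::printf("%.2f\n", LeakyReluGradScalar(1.0f, -3.0f, alpha));  // 0.20
      return 0;
    }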
diff --git a/tensorflow/core/kernels/relu_op.cc b/tensorflow/core/kernels/relu_op.cc
index 173fea37ed..e67695d54a 100644
--- a/tensorflow/core/kernels/relu_op.cc
+++ b/tensorflow/core/kernels/relu_op.cc
@@ -33,19 +33,25 @@ typedef Eigen::GpuDevice GPUDevice;
typedef Eigen::SyclDevice SYCLDevice;
#endif // TENSORFLOW_USE_SYCL
-#define REGISTER_RELU_KERNELS(type) \
- REGISTER_KERNEL_BUILDER( \
- Name("Relu").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
- ReluOp<CPUDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("ReluGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
- ReluGradOp<CPUDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("Relu6").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
- Relu6Op<CPUDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("Relu6Grad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
- Relu6GradOp<CPUDevice, type>)
+#define REGISTER_RELU_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Relu").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ ReluOp<CPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("ReluGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ ReluGradOp<CPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Relu6").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ Relu6Op<CPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Relu6Grad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ Relu6GradOp<CPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("LeakyRelu").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ LeakyReluOp<CPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("LeakyReluGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ LeakyReluGradOp<CPUDevice, type>);
TF_CALL_REAL_NUMBER_TYPES(REGISTER_RELU_KERNELS);
#undef REGISTER_RELU_KERNELS
@@ -99,6 +105,19 @@ namespace functor {
extern template struct Relu6Grad<GPUDevice, T>; \
\
template <> \
+ void LeakyRelu<GPUDevice, T>::operator()( \
+ const GPUDevice& d, typename TTypes<T>::ConstTensor features, T alpha, \
+ typename TTypes<T>::Tensor activations); \
+ extern template struct LeakyRelu<GPUDevice, T>; \
+ \
+ template <> \
+ void LeakyReluGrad<GPUDevice, T>::operator()( \
+ const GPUDevice& d, typename TTypes<T>::ConstTensor gradients, \
+ typename TTypes<T>::ConstTensor features, T alpha, \
+ typename TTypes<T>::Tensor backprops); \
+ extern template struct LeakyReluGrad<GPUDevice, T>; \
+ \
+ template <> \
void Elu<GPUDevice, T>::operator()(const GPUDevice& d, \
typename TTypes<T>::ConstTensor features, \
typename TTypes<T>::Tensor activations); \
@@ -134,30 +153,36 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
} // namespace functor
// Registration of the GPU implementations.
-#define REGISTER_GPU_KERNELS(type) \
- REGISTER_KERNEL_BUILDER( \
- Name("Relu").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
- ReluOp<GPUDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("ReluGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
- ReluGradOp<GPUDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("Relu6").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
- Relu6Op<GPUDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("Relu6Grad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
- Relu6GradOp<GPUDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("Elu").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
- EluOp<GPUDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("EluGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
- EluGradOp<GPUDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("Selu").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
- SeluOp<GPUDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("SeluGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+#define REGISTER_GPU_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Relu").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+ ReluOp<GPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("ReluGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+ ReluGradOp<GPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Relu6").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+ Relu6Op<GPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Relu6Grad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+ Relu6GradOp<GPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("LeakyRelu").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+ LeakyReluOp<GPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("LeakyReluGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+ LeakyReluGradOp<GPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Elu").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+ EluOp<GPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("EluGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+ EluGradOp<GPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Selu").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+ SeluOp<GPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("SeluGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
SeluGradOp<GPUDevice, type>)
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
@@ -188,30 +213,36 @@ REGISTER_KERNEL_BUILDER(
#ifdef TENSORFLOW_USE_SYCL
// Registration of the GPU implementations.
-#define REGISTER_SYCL_KERNELS(type) \
- REGISTER_KERNEL_BUILDER( \
- Name("Relu").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
- ReluOp<SYCLDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("ReluGrad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
- ReluGradOp<SYCLDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("Relu6").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
- Relu6Op<SYCLDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("Relu6Grad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
- Relu6GradOp<SYCLDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("Elu").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
- EluOp<SYCLDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("EluGrad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
- EluGradOp<SYCLDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("Selu").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
- SeluOp<SYCLDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("SeluGrad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
+#define REGISTER_SYCL_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Relu").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
+ ReluOp<SYCLDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("ReluGrad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
+ ReluGradOp<SYCLDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Relu6").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
+ Relu6Op<SYCLDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Relu6Grad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
+ Relu6GradOp<SYCLDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("LeakyRelu").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
+ LeakyReluOp<SYCLDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("LeakyReluGrad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
+ LeakyReluGradOp<SYCLDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Elu").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
+ EluOp<SYCLDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("EluGrad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
+ EluGradOp<SYCLDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Selu").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
+ SeluOp<SYCLDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("SeluGrad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
SeluGradOp<SYCLDevice, type>)
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL_KERNELS);
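
A note on the macro changes in this file: the per-device registration macros are stamped out once per element type by the TF_CALL_* wrappers, so adding the two LeakyRelu entries registers the new kernels for every supported dtype on each device. The snippet below is a self-contained sketch of that X-macro pattern with illustrative names; it is not the TensorFlow macros themselves.

    #include <cstdio>

    // Illustrative stand-in for TF_CALL_REAL_NUMBER_TYPES(REGISTER_RELU_KERNELS):
    // the outer macro applies the inner one once per element type.
    #define CALL_FOR_EACH_TYPE(m) m(float) m(double) m(int)

    #define DEFINE_FAKE_REGISTRATION(type) \
      void RegisterLeakyReluFor_##type() { std::printf("LeakyRelu registered for " #type "\n"); }

    CALL_FOR_EACH_TYPE(DEFINE_FAKE_REGISTRATION)

    int main() {
      RegisterLeakyReluFor_float();
      RegisterLeakyReluFor_double();
      RegisterLeakyReluFor_int();
      return 0;
    }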
diff --git a/tensorflow/core/kernels/relu_op.h b/tensorflow/core/kernels/relu_op.h
index 4775deeb61..a4638c70c2 100644
--- a/tensorflow/core/kernels/relu_op.h
+++ b/tensorflow/core/kernels/relu_op.h
@@ -132,6 +132,67 @@ void Relu6GradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
}
template <typename Device, typename T>
+class LeakyReluOp : public UnaryElementWiseOp<T, LeakyReluOp<Device, T>> {
+ public:
+ explicit LeakyReluOp(OpKernelConstruction* context)
+ : UnaryElementWiseOp<T, LeakyReluOp<Device, T>>(context) {
+ float alpha_tmp;
+ OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha_tmp));
+ alpha_ = T(alpha_tmp);
+ }
+
+ void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) {
+ functor::LeakyRelu<Device, T> functor;
+ functor(context->eigen_device<Device>(), input.flat<T>(), alpha_,
+ output->flat<T>());
+ }
+
+ private:
+ T alpha_;
+};
+
+template <typename Device, typename T>
+class LeakyReluGradOp
+ : public BinaryElementWiseOp<T, LeakyReluGradOp<Device, T>> {
+ public:
+ explicit LeakyReluGradOp(OpKernelConstruction* context)
+ : BinaryElementWiseOp<T, LeakyReluGradOp<Device, T>>(context) {
+ float alpha_tmp;
+ OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha_tmp));
+ alpha_ = T(alpha_tmp);
+ }
+
+ void OperateNoTemplate(OpKernelContext* context, const Tensor& g,
+ const Tensor& a, T alpha, Tensor* output);
+
+ // INPUTS:
+ // g (gradients): backpropagated gradients
+ // a (inputs): either the inputs that were passed to LeakyReluOp(), or its
+ // outputs (using either one yields the same result here).
+ // OUTPUT:
+ // gradients to backprop
+ template <int NDIMS>
+ void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a,
+ Tensor* output) {
+ OperateNoTemplate(context, g, a, alpha_, output);
+ }
+
+ private:
+ T alpha_;
+};
+
+template <typename Device, typename T>
+void LeakyReluGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
+ const Tensor& g,
+ const Tensor& a, T alpha,
+ Tensor* output) {
+ if (!ReluHelpers::ValidateSameSize(context, g, a)) return;
+ functor::LeakyReluGrad<Device, T> functor;
+ functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(), alpha,
+ output->flat<T>());
+}
+
+template <typename Device, typename T>
class EluOp : public UnaryElementWiseOp<T, EluOp<Device, T>> {
public:
using UnaryElementWiseOp<T, EluOp<Device, T>>::UnaryElementWiseOp;
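
One detail worth noting in the new LeakyReluOp/LeakyReluGradOp classes above: the "alpha" attr is read once as a float in the constructor and stored converted to the element type T, so half-precision kernels do not reconvert it per element. A minimal sketch of that construct-once pattern outside TensorFlow (names here are illustrative, not the real classes):

    #include <algorithm>
    #include <cstdio>

    // Store the attribute converted to the element type at construction time,
    // then reuse it for every element (cf. alpha_ = T(alpha_tmp) in the kernels).
    template <typename T>
    class LeakyReluLike {
     public:
      explicit LeakyReluLike(float alpha_attr) : alpha_(static_cast<T>(alpha_attr)) {}
      T operator()(T x) const { return std::max(x, x * alpha_); }

     private:
      T alpha_;
    };

    int main() {
      LeakyReluLike<double> op(0.2f);
      std::printf("%.2f %.2f\n", op(2.0), op(-2.0));  // 2.00 -0.40
      return 0;
    }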
diff --git a/tensorflow/core/kernels/relu_op_functor.h b/tensorflow/core/kernels/relu_op_functor.h
index e564da335a..f917142a12 100644
--- a/tensorflow/core/kernels/relu_op_functor.h
+++ b/tensorflow/core/kernels/relu_op_functor.h
@@ -91,6 +91,36 @@ struct Relu6Grad {
}
};
+// Functor used by LeakyReluOp to do the computations.
+template <typename Device, typename T>
+struct LeakyRelu {
+ // Computes LeakyRelu activation.
+ //
+ // features: any shape.
+ // activations: same shape as "features".
+ void operator()(const Device& d, typename TTypes<T>::ConstTensor features,
+ T alpha, typename TTypes<T>::Tensor activations) {
+ activations.device(d) = features.cwiseMax(features * alpha);
+ }
+};
+
+// Functor used by LeakyReluGradOp to do the computations.
+template <typename Device, typename T>
+struct LeakyReluGrad {
+ // Computes LeakyReluGrad backprops.
+ //
+ // gradients: gradients backpropagated to the LeakyRelu op.
+ // features: either the inputs that were passed to the LeakyRelu op, or its
+ // outputs (using either one yields the same result here).
+ // backprops: gradients to backpropagate to the LeakyRelu inputs.
+ void operator()(const Device& d, typename TTypes<T>::ConstTensor gradients,
+ typename TTypes<T>::ConstTensor features, T alpha,
+ typename TTypes<T>::Tensor backprops) {
+ backprops.device(d) =
+ (features > static_cast<T>(0)).select(gradients, gradients * alpha);
+ }
+};
+
// Functor used by EluOp to do the computations.
template <typename Device, typename T>
struct Elu {
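
The two functors above are plain Eigen tensor expressions, which is why the same code path serves the CPU, GPU, and SYCL devices. The snippet below reproduces both expressions with the standalone (unsupported) Eigen Tensor module on the default device; it is a sketch for illustration, not code from the commit.

    #include <iostream>
    #include <unsupported/Eigen/CXX11/Tensor>

    int main() {
      const float alpha = 0.2f;
      Eigen::Tensor<float, 1> features(4);
      features.setValues({-3.0f, -0.5f, 0.0f, 2.0f});
      Eigen::Tensor<float, 1> gradients(4);
      gradients.setValues({1.0f, 1.0f, 1.0f, 1.0f});

      // Forward expression, as in functor::LeakyRelu.
      Eigen::Tensor<float, 1> activations = features.cwiseMax(features * alpha);

      // Backward expression, as in functor::LeakyReluGrad.
      Eigen::Tensor<float, 1> backprops =
          (features > 0.0f).select(gradients, gradients * alpha);

      for (int i = 0; i < 4; ++i) {
        std::cout << activations(i) << " " << backprops(i) << "\n";
      }
      return 0;
    }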
diff --git a/tensorflow/core/kernels/relu_op_gpu.cu.cc b/tensorflow/core/kernels/relu_op_gpu.cu.cc
index b9391517c1..dd5f9495e2 100644
--- a/tensorflow/core/kernels/relu_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/relu_op_gpu.cu.cc
@@ -145,14 +145,16 @@ struct Relu<Device, qint8> {
} // namespace functor
// Definition of the GPU implementations declared in relu_op.cc.
-#define DEFINE_GPU_KERNELS(T) \
- template struct functor::Relu<GPUDevice, T>; \
- template struct functor::ReluGrad<GPUDevice, T>; \
- template struct functor::Relu6<GPUDevice, T>; \
- template struct functor::Relu6Grad<GPUDevice, T>; \
- template struct functor::Elu<GPUDevice, T>; \
- template struct functor::EluGrad<GPUDevice, T>; \
- template struct functor::Selu<GPUDevice, T>; \
+#define DEFINE_GPU_KERNELS(T) \
+ template struct functor::Relu<GPUDevice, T>; \
+ template struct functor::ReluGrad<GPUDevice, T>; \
+ template struct functor::Relu6<GPUDevice, T>; \
+ template struct functor::Relu6Grad<GPUDevice, T>; \
+ template struct functor::LeakyRelu<GPUDevice, T>; \
+ template struct functor::LeakyReluGrad<GPUDevice, T>; \
+ template struct functor::Elu<GPUDevice, T>; \
+ template struct functor::EluGrad<GPUDevice, T>; \
+ template struct functor::Selu<GPUDevice, T>; \
template struct functor::SeluGrad<GPUDevice, T>;
TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS);
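
These explicit `template struct functor::...` lines pair with the `extern template struct ...` declarations added to relu_op.cc: the CPU translation unit suppresses implicit instantiation and the CUDA file provides the single definition. Below is a compressed single-file sketch of that split, with an illustrative functor name; the real code spans the header and the two translation units.

    // Template definition (normally in a header such as relu_op_functor.h).
    template <typename T>
    struct LeakyReluLikeFunctor {
      T operator()(T x, T alpha) const { return x > T(0) ? x : x * alpha; }
    };

    // In the consuming file (cf. relu_op.cc): declaration only, no instantiation here.
    extern template struct LeakyReluLikeFunctor<float>;

    // In the defining file (cf. relu_op_gpu.cu.cc): the one explicit instantiation.
    template struct LeakyReluLikeFunctor<float>;

    int main() {
      LeakyReluLikeFunctor<float> f;
      return f(-1.0f, 0.2f) < 0.0f ? 0 : 1;  // -0.2 < 0, so exit code 0
    }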
diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc
index d1d81b27cc..a9ca69ad86 100644
--- a/tensorflow/core/ops/nn_ops.cc
+++ b/tensorflow/core/ops/nn_ops.cc
@@ -983,6 +983,21 @@ REGISTER_OP("Relu6Grad")
.Attr("T: realnumbertype")
.SetShapeFn(shape_inference::MergeBothInputsShapeFn);
+REGISTER_OP("LeakyRelu")
+ .Input("features: T")
+ .Output("activations: T")
+ .Attr("alpha: float = 0.2")
+ .Attr("T: {half, float, double} = DT_FLOAT")
+ .SetShapeFn(shape_inference::UnchangedShape);
+
+REGISTER_OP("LeakyReluGrad")
+ .Input("gradients: T")
+ .Input("features: T")
+ .Output("backprops: T")
+ .Attr("alpha: float = 0.2")
+ .Attr("T: {half, float, double} = DT_FLOAT")
+ .SetShapeFn(shape_inference::MergeBothInputsShapeFn);
+
REGISTER_OP("Elu")
.Input("features: T")
.Output("activations: T")
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index 14cc9df9a2..2048ad26ac 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -14296,6 +14296,74 @@ op {
}
}
op {
+ name: "LeakyRelu"
+ input_arg {
+ name: "features"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "activations"
+ type_attr: "T"
+ }
+ attr {
+ name: "alpha"
+ type: "float"
+ default_value {
+ f: 0.2
+ }
+ }
+ attr {
+ name: "T"
+ type: "type"
+ default_value {
+ type: DT_FLOAT
+ }
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+}
+op {
+ name: "LeakyReluGrad"
+ input_arg {
+ name: "gradients"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "features"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "backprops"
+ type_attr: "T"
+ }
+ attr {
+ name: "alpha"
+ type: "float"
+ default_value {
+ f: 0.2
+ }
+ }
+ attr {
+ name: "T"
+ type: "type"
+ default_value {
+ type: DT_FLOAT
+ }
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+}
+op {
name: "LearnedUnigramCandidateSampler"
input_arg {
name: "true_classes"