about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/core/kernels
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels')
-rw-r--r-- tensorflow/core/kernels/BUILD 28
-rw-r--r-- tensorflow/core/kernels/dense_update_functor.h 29
-rw-r--r-- tensorflow/core/kernels/group_by_window_dataset_op.cc 2
-rw-r--r-- tensorflow/core/kernels/in_topk_op.cc 52
-rw-r--r-- tensorflow/core/kernels/quantization_utils.h 4
-rw-r--r-- tensorflow/core/kernels/relu_op.cc 62
-rw-r--r-- tensorflow/core/kernels/relu_op.h 42
-rw-r--r-- tensorflow/core/kernels/relu_op_functor.h 40
-rw-r--r-- tensorflow/core/kernels/relu_op_gpu.cu.cc 4
9 files changed, 242 insertions, 21 deletions
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index a493452777..f45bb72c38 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -4307,9 +4307,11 @@ filegroup(
"cwise_op_invert.cc",
"cwise_op_isfinite.cc",
"cwise_op_less.cc",
+ "cwise_op_less_equal.cc",
"cwise_op_log.cc",
"cwise_op_logical_and.cc",
"cwise_op_logical_not.cc",
+ "cwise_op_logical_or.cc",
"cwise_op_maximum.cc",
"cwise_op_minimum.cc",
"cwise_op_mul_1.cc",
@@ -4534,6 +4536,32 @@ cc_library(
alwayslink = 1,
)
+cc_library(  # Builds decode_image_op.cc as its own publicly visible library target.
+ name = "android_tensorflow_image_op",
+ srcs = [
+ "decode_image_op.cc",
+ ],
+ copts = tf_copts(),
+ linkopts = select({
+ "//tensorflow:android": [  # extra linker flag applied only under this config
+ "-ldl",
+ ],
+ "//conditions:default": [],  # every other configuration: no extra link options
+ }),
+ tags = [
+ "manual",
+ "notap",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [  # gif/jpeg/png support plus the lite lib, all android_* targets in //tensorflow/core
+ "//tensorflow/core:android_gif_internal",
+ "//tensorflow/core:android_jpeg_internal",
+ "//tensorflow/core:android_png_internal",
+ "//tensorflow/core:android_tensorflow_lib_lite",
+ ],
+ alwayslink = 1,
+)
+
# Quantization-specific OpKernels
tf_kernel_library(
diff --git a/tensorflow/core/kernels/dense_update_functor.h b/tensorflow/core/kernels/dense_update_functor.h
index 54b080c83b..4aefe26c54 100644
--- a/tensorflow/core/kernels/dense_update_functor.h
+++ b/tensorflow/core/kernels/dense_update_functor.h
@@ -24,6 +24,9 @@ limitations under the License.
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
+#ifdef TENSORFLOW_USE_SYCL
+typedef Eigen::SyclDevice SYCLDevice;
+#endif // TENSORFLOW_USE_SYCL
enum DenseUpdateType { ADD, SUB, ASSIGN };
@@ -59,6 +62,32 @@ struct DenseUpdate<CPUDevice, T, ASSIGN> {
}
};
+#ifdef TENSORFLOW_USE_SYCL
+template <typename T>  // SYCL counterpart of the CPUDevice DenseUpdate specializations; T is the element type.
+struct DenseUpdate<SYCLDevice, T, ADD> {
+ void operator()(const SYCLDevice& d, typename TTypes<T>::Flat params,
+ typename TTypes<T>::ConstFlat update) {
+ params.device(d) += update;  // ADD: accumulate `update` into `params` in place, via device `d`
+ }
+};
+
+template <typename T>  // SYCL counterpart of the CPUDevice DenseUpdate specializations; T is the element type.
+struct DenseUpdate<SYCLDevice, T, SUB> {
+ void operator()(const SYCLDevice& d, typename TTypes<T>::Flat params,
+ typename TTypes<T>::ConstFlat update) {
+ params.device(d) -= update;  // SUB: subtract `update` from `params` in place, via device `d`
+ }
+};
+
+template <typename T>  // SYCL counterpart of the CPUDevice DenseUpdate specializations; T is the element type.
+struct DenseUpdate<SYCLDevice, T, ASSIGN> {
+ void operator()(const SYCLDevice& d, typename TTypes<T>::Flat params,
+ typename TTypes<T>::ConstFlat update) {
+ params.device(d) = update;  // ASSIGN: overwrite `params` with `update`, via device `d`
+ }
+};
+#endif // TENSORFLOW_USE_SYCL
+
} // end namespace functor
} // end namespace tensorflow
diff --git a/tensorflow/core/kernels/group_by_window_dataset_op.cc b/tensorflow/core/kernels/group_by_window_dataset_op.cc
index 948e83390e..94591a26af 100644
--- a/tensorflow/core/kernels/group_by_window_dataset_op.cc
+++ b/tensorflow/core/kernels/group_by_window_dataset_op.cc
@@ -42,7 +42,7 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
- int64 window_size;
+ int64 window_size = 0;
OP_REQUIRES_OK(
ctx, ParseScalarArgument<int64>(ctx, "window_size", &window_size));
OP_REQUIRES(
diff --git a/tensorflow/core/kernels/in_topk_op.cc b/tensorflow/core/kernels/in_topk_op.cc
index 13890e5b7f..e2861ae090 100644
--- a/tensorflow/core/kernels/in_topk_op.cc
+++ b/tensorflow/core/kernels/in_topk_op.cc
@@ -17,11 +17,11 @@ limitations under the License.
#define EIGEN_USE_THREADS
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/bounds_check.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
namespace tensorflow {
@@ -29,12 +29,29 @@ template <typename T, typename TARGET_T>
class InTopK : public OpKernel {
public:
explicit InTopK(OpKernelConstruction* context) : OpKernel(context) {
- OP_REQUIRES_OK(context, context->GetAttr("k", &k_));
+ if (context->num_inputs() == 2) {
+ OP_REQUIRES_OK(context, context->GetAttr("k", &k_));
+ }
}
void Compute(OpKernelContext* context) override {
const auto& predictions_in = context->input(0);
const auto& targets_in = context->input(1);
+ int64 k_val = k_;
+ if (context->num_inputs() == 3) {
+ const auto& k_in = context->input(2);
+
+ OP_REQUIRES(context, TensorShapeUtils::IsScalar(k_in.shape()),
+ errors::InvalidArgument("k must be 0-D, got shape ",
+ k_in.shape().DebugString()));
+
+ if (k_in.dtype() == DT_INT32) {
+ k_val = k_in.scalar<int32>()();
+ } else {
+ k_val = k_in.scalar<int64>()();
+ }
+ }
+
OP_REQUIRES(context, predictions_in.dims() == 2,
errors::InvalidArgument("predictions must be 2-dimensional"));
OP_REQUIRES(context, targets_in.dims() == 1,
@@ -73,7 +90,7 @@ class InTopK : public OpKernel {
}
}
}
- out(b) = cannot_say ? false : (more_probable_classes < k_);
+ out(b) = cannot_say ? false : (more_probable_classes < k_val);
}
}
@@ -82,10 +99,35 @@ class InTopK : public OpKernel {
};
REGISTER_KERNEL_BUILDER(
- Name("InTopK").Device(DEVICE_CPU).TypeConstraint<int32>("T"),
+ Name("InTopK").Device(DEVICE_CPU)
+ .HostMemory("predictions")
+ .HostMemory("targets")
+ .HostMemory("precision")
+ .TypeConstraint<int32>("T"),
+ InTopK<float, int32>);
+REGISTER_KERNEL_BUILDER(
+ Name("InTopK").Device(DEVICE_CPU)
+ .HostMemory("predictions")
+ .HostMemory("targets")
+ .HostMemory("precision")
+ .TypeConstraint<int64>("T"),
+ InTopK<float, int64>);
+
+REGISTER_KERNEL_BUILDER(
+ Name("InTopKV2").Device(DEVICE_CPU)
+ .HostMemory("predictions")
+ .HostMemory("targets")
+ .HostMemory("k")
+ .HostMemory("precision")
+ .TypeConstraint<int32>("T"),
InTopK<float, int32>);
REGISTER_KERNEL_BUILDER(
- Name("InTopK").Device(DEVICE_CPU).TypeConstraint<int64>("T"),
+ Name("InTopKV2").Device(DEVICE_CPU)
+ .HostMemory("predictions")
+ .HostMemory("targets")
+ .HostMemory("k")
+ .HostMemory("precision")
+ .TypeConstraint<int64>("T"),
InTopK<float, int64>);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/quantization_utils.h b/tensorflow/core/kernels/quantization_utils.h
index cb4fcbd788..c5dc2e7194 100644
--- a/tensorflow/core/kernels/quantization_utils.h
+++ b/tensorflow/core/kernels/quantization_utils.h
@@ -823,9 +823,9 @@ void QuantizedAddUsingEigen(const Eigen::ThreadPoolDevice& device,
const int64 input_element_count = input.NumElements();
const int64 smaller_input_element_count = smaller_input.NumElements();
- QuantizedToFloatStruct<T1> smaller_input_q2f(smaller_input_min,
+ QuantizedToFloatStruct<T1> input_q2f(input_min, input_max);
+ QuantizedToFloatStruct<T2> smaller_input_q2f(smaller_input_min,
smaller_input_max);
- QuantizedToFloatStruct<T2> input_q2f(input_min, input_max);
FloatToQuantizedStruct<T3> f2q(*output_min, *output_max);
auto smaller_input_float =
diff --git a/tensorflow/core/kernels/relu_op.cc b/tensorflow/core/kernels/relu_op.cc
index d8d30e87e2..afad288cc0 100644
--- a/tensorflow/core/kernels/relu_op.cc
+++ b/tensorflow/core/kernels/relu_op.cc
@@ -50,15 +50,21 @@ typedef Eigen::SyclDevice SYCLDevice;
TF_CALL_REAL_NUMBER_TYPES(REGISTER_RELU_KERNELS);
#undef REGISTER_RELU_KERNELS
-#define REGISTER_ELU_KERNELS(type) \
- REGISTER_KERNEL_BUILDER( \
- Name("Elu").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
- EluOp<CPUDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("EluGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
- EluGradOp<CPUDevice, type>)
-
-// Elu only makes sense with float or double.
+#define REGISTER_ELU_KERNELS(type) /* CPU kernels for Elu, EluGrad, Selu and SeluGrad, constrained to T == type */ \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Elu").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ EluOp<CPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("EluGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ EluGradOp<CPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Selu").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ SeluOp<CPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("SeluGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ SeluGradOp<CPUDevice, type>)
+
+// Elu and Selu only make sense with floating-point types (not integers).
TF_CALL_GPU_NUMBER_TYPES(REGISTER_ELU_KERNELS);
#undef REGISTER_ELU_KERNELS
@@ -103,7 +109,23 @@ namespace functor {
const GPUDevice& d, typename TTypes<T>::ConstTensor gradients, \
typename TTypes<T>::ConstTensor activations, \
typename TTypes<T>::Tensor backprops); \
- extern template struct EluGrad<GPUDevice, T>;
+ extern template struct EluGrad<GPUDevice, T>; \
+ \
+ template <> \
+ void Selu<GPUDevice, T>::operator()( \
+ const GPUDevice& d, \
+ typename TTypes<T>::ConstTensor features, \
+ typename TTypes<T>::Tensor activations); \
+ extern template struct Selu<GPUDevice, T>; \
+ \
+ template <> \
+ void SeluGrad<GPUDevice, T>::operator()( \
+ const GPUDevice& d, typename TTypes<T>::ConstTensor gradients, \
+ typename TTypes<T>::ConstTensor activations, \
+ typename TTypes<T>::Tensor backprops); \
+ extern template struct SeluGrad<GPUDevice, T>;
+
+
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
} // namespace functor
@@ -127,7 +149,15 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
EluOp<GPUDevice, type>); \
REGISTER_KERNEL_BUILDER( \
Name("EluGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
- EluGradOp<GPUDevice, type>)
+ EluGradOp<GPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Selu").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+ SeluOp<GPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("SeluGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+ SeluGradOp<GPUDevice, type>)
+
+
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
#undef REGISTER_GPU_KERNELS
@@ -154,7 +184,15 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
EluOp<SYCLDevice, type>); \
REGISTER_KERNEL_BUILDER( \
Name("EluGrad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
- EluGradOp<SYCLDevice, type>)
+ EluGradOp<SYCLDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Selu").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
+ SeluOp<SYCLDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("SeluGrad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
+ SeluGradOp<SYCLDevice, type>)
+
+
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL_KERNELS);
#undef REGISTER_SYCL_KERNELS
diff --git a/tensorflow/core/kernels/relu_op.h b/tensorflow/core/kernels/relu_op.h
index 365c6201a5..e712b02bd7 100644
--- a/tensorflow/core/kernels/relu_op.h
+++ b/tensorflow/core/kernels/relu_op.h
@@ -173,6 +173,48 @@ void EluGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
output->flat<T>());
}
+template <typename Device, typename T>  // Device: Eigen device type; T: element type
+class SeluOp : public UnaryElementWiseOp<T, SeluOp<Device, T>> {  // op kernel that applies functor::Selu to its input
+ public:
+ using UnaryElementWiseOp<T, SeluOp<Device, T>>::UnaryElementWiseOp;  // inherit the base-class constructors
+
+ void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) {
+ functor::Selu<Device, T> functor;
+ functor(context->eigen_device<Device>(), input.flat<T>(),  // input and output are both passed as flat views
+ output->flat<T>());
+ }
+};
+
+template <typename Device, typename T>  // Device: Eigen device type; T: element type
+class SeluGradOp : public BinaryElementWiseOp<T, SeluGradOp<Device, T>> {  // op kernel for the Selu gradient
+ public:
+ using BinaryElementWiseOp<T, SeluGradOp<Device, T>>::BinaryElementWiseOp;  // inherit the base-class constructors
+
+ void OperateNoTemplate(OpKernelContext* context, const Tensor& g,
+ const Tensor& a, Tensor* output);  // rank-independent implementation, defined out of line
+
+ // INPUTS:
+ // g (gradients): backpropagated gradients
+ // a (outputs): outputs of the SeluOp()
+ // OUTPUT:
+ // gradients to backprop to the SeluOp() inputs
+ template <int NDIMS>
+ void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a,
+ Tensor* output) {
+ OperateNoTemplate(context, g, a, output);  // NDIMS is unused here: every rank forwards to the same code
+ }
+};
+
+template <typename Device, typename T>  // out-of-line body of SeluGradOp<Device, T>::OperateNoTemplate
+void SeluGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
+ const Tensor& g, const Tensor& a,
+ Tensor* output) {
+ if (!ReluHelpers::ValidateSameSize(context, g, a)) return;  // stop if the g/a check fails; NOTE(review): presumably the helper reports the error on `context` - confirm
+ functor::SeluGrad<Device, T> functor;
+ functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(),  // argument order: gradients, activations, then backprops
+ output->flat<T>());
+}
+
} // namespace tensorflow
#undef EIGEN_USE_THREADS
diff --git a/tensorflow/core/kernels/relu_op_functor.h b/tensorflow/core/kernels/relu_op_functor.h
index 633522920c..9577b963c6 100644
--- a/tensorflow/core/kernels/relu_op_functor.h
+++ b/tensorflow/core/kernels/relu_op_functor.h
@@ -125,6 +125,46 @@ struct EluGrad {
}
};
+// Functor used by SeluOp: element-wise Selu activation, evaluated through device `d`.
+template <typename Device, typename T>
+struct Selu {
+ // Computes Selu: scale_alpha * (exp(x) - 1) where x < 0, scale * x otherwise.
+ //
+ // features: input values x; any shape.
+ // activations: output; same shape as "features".
+ void operator()(const Device& d, typename TTypes<T>::ConstTensor features,
+ typename TTypes<T>::Tensor activations) {
+ // Selu constants cast to T. NOTE(review): presumably lambda and lambda * alpha from the SELU paper - confirm.
+ const auto scale = static_cast<T>(1.0507009873554804934193349852946);
+ const auto scale_alpha = static_cast<T>(1.7580993408473768599402175208123);
+ const auto one = static_cast<T>(1);
+ const auto zero = static_cast<T>(0);
+ activations.device(d) =
+ (features < zero)  // element-wise mask choosing between the two select() arguments
+ .select(scale_alpha * (features.exp() - features.constant(one)),  // used where features < 0
+ scale * features);  // used everywhere else
+ }
+};
+
+// Functor used by SeluGradOp: element-wise Selu gradient, evaluated through device `d`.
+template <typename Device, typename T>
+struct SeluGrad {
+ // Computes SeluGrad backprops from the forward outputs rather than the inputs.
+ //
+ // gradients: gradients backpropagated to the Selu op.
+ // activations: outputs of the Selu op.
+ // backprops: gradients to backpropagate to the Selu inputs.
+ void operator()(const Device& d, typename TTypes<T>::ConstTensor gradients,
+ typename TTypes<T>::ConstTensor activations,
+ typename TTypes<T>::Tensor backprops) {
+ const auto scale = static_cast<T>(1.0507009873554804934193349852946);  // same literal as in the Selu functor
+ const auto scale_alpha = static_cast<T>(1.7580993408473768599402175208123);  // same literal as in the Selu functor
+ backprops.device(d) =
+ (activations < static_cast<T>(0)).select(  // negative output selects the first argument
+ gradients * (activations + scale_alpha), gradients * scale);  // activations + scale_alpha == scale_alpha * exp(x) on the negative branch
+ }
+};
+
} // namespace functor
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/relu_op_gpu.cu.cc b/tensorflow/core/kernels/relu_op_gpu.cu.cc
index 30c4a289f7..ec09d8dfea 100644
--- a/tensorflow/core/kernels/relu_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/relu_op_gpu.cu.cc
@@ -35,7 +35,9 @@ typedef Eigen::GpuDevice GPUDevice;
template struct functor::Relu6<GPUDevice, T>; \
template struct functor::Relu6Grad<GPUDevice, T>; \
template struct functor::Elu<GPUDevice, T>; \
- template struct functor::EluGrad<GPUDevice, T>;
+ template struct functor::EluGrad<GPUDevice, T>; \
+ template struct functor::Selu<GPUDevice, T>; \
+ template struct functor::SeluGrad<GPUDevice, T>;
TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS);