author    A. Unique TensorFlower <gardener@tensorflow.org>  2017-08-20 22:25:23 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2017-08-20 22:29:05 -0700
commit ec0e1e580c1eb46afd5a81af8f925d8813e7ab50 (patch)
tree   4fdf8397f5eb7166e351e9fe2b8d4f79dd4bb452
parent badd5456977e2b981a08cd5d6e41a292ea6eafda (diff)
Automated g4 rollback of changelist 165773305
PiperOrigin-RevId: 165887626
-rw-r--r--  tensorflow/contrib/cmake/tf_tests.cmake                      2
-rw-r--r--  tensorflow/core/BUILD                                        2
-rw-r--r--  tensorflow/core/kernels/BUILD                                9
-rw-r--r--  tensorflow/core/kernels/l2loss_op.cc                        39
-rw-r--r--  tensorflow/core/kernels/l2loss_op.h                         16
-rw-r--r--  tensorflow/core/kernels/l2loss_op_gpu.cu.cc                 49
-rw-r--r--  tensorflow/core/kernels/reduction_ops.h                      3
-rw-r--r--  tensorflow/core/kernels/reduction_ops_common.h              15
-rw-r--r--  tensorflow/core/kernels/reduction_ops_gpu.cu.cc            210
-rw-r--r--  tensorflow/core/kernels/reduction_ops_gpu_kernels.h        697
-rw-r--r--  tensorflow/core/kernels/reduction_ops_test.cc              163
-rw-r--r--  tensorflow/core/util/permutation_input_iterator.h          134
-rw-r--r--  tensorflow/core/util/transform_output_iterator.h           149
-rw-r--r--  tensorflow/python/kernel_tests/BUILD                        20
-rw-r--r--  tensorflow/python/kernel_tests/reduction_ops_test.py        18
-rw-r--r--  tensorflow/python/kernel_tests/reduction_ops_test_big.py    75
16 files changed, 115 insertions, 1486 deletions
diff --git a/tensorflow/contrib/cmake/tf_tests.cmake b/tensorflow/contrib/cmake/tf_tests.cmake
index 6507a9a5e0..25f00de81d 100644
--- a/tensorflow/contrib/cmake/tf_tests.cmake
+++ b/tensorflow/contrib/cmake/tf_tests.cmake
@@ -289,8 +289,6 @@ if (tensorflow_BUILD_PYTHON_TESTS)
# Failing with TF 1.3 (TODO)
"${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/estimator_test.py"
"${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_test.py"
- # Test should only be run manually
- "${tensorflow_source_dir}/tensorflow/python/kernel_tests/reduction_ops_test_big.py"
)
endif()
list(REMOVE_ITEM tf_test_src_py ${tf_test_src_py_exclude})
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 49b1589929..1f7eb87f18 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -412,7 +412,6 @@ tf_cuda_library(
"util/guarded_philox_random.h",
"util/mirror_pad_mode.h",
"util/padding.h",
- "util/permutation_input_iterator.h",
"util/port.h",
"util/saved_tensor_slice_util.h",
"util/sparse/group_iterator.h",
@@ -424,7 +423,6 @@ tf_cuda_library(
"util/tensor_slice_reader.h",
"util/tensor_slice_reader_cache.h",
"util/tensor_slice_writer.h",
- "util/transform_output_iterator.h",
"util/use_cudnn.h",
"util/matmul_autotune.h",
"util/util.h",
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 7dd56247f4..9f638eebee 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -2532,9 +2532,8 @@ tf_kernel_library(
tf_kernel_library(
name = "reduction_ops",
- srcs = ["reduction_ops_gpu_kernels.h"],
prefix = "reduction_ops",
- deps = MATH_DEPS + if_cuda(["@cub_archive//:cub"]),
+ deps = MATH_DEPS,
)
tf_kernel_library(
@@ -2995,16 +2994,14 @@ tf_kernel_library(
tf_kernel_library(
name = "l2loss_op",
prefix = "l2loss_op",
- #srcs = ["reduction_ops_gpu_kernels.h"],
deps = [
- ":reduction_ops",
- "//third_party/eigen3",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:nn_grad",
"//tensorflow/core:nn_ops_op_lib",
- ] + if_cuda(["@cub_archive//:cub"]),
+ "//third_party/eigen3",
+ ],
)
tf_cuda_cc_test(
diff --git a/tensorflow/core/kernels/l2loss_op.cc b/tensorflow/core/kernels/l2loss_op.cc
index f8ed935157..9875cd027d 100644
--- a/tensorflow/core/kernels/l2loss_op.cc
+++ b/tensorflow/core/kernels/l2loss_op.cc
@@ -27,9 +27,10 @@ limitations under the License.
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
-template <typename T>
-class L2LossOp<CPUDevice, T> : public OpKernel {
+template <typename Device, typename T>
+class L2LossOp : public OpKernel {
public:
explicit L2LossOp(OpKernelConstruction* context) : OpKernel(context) {}
@@ -41,9 +42,8 @@ class L2LossOp<CPUDevice, T> : public OpKernel {
Tensor* output = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output(0, TensorShape({}), &output));
- const CPUDevice& d = context->eigen_device<CPUDevice>();
- output->scalar<T>().device(d) =
- (input.flat<T>().square() * static_cast<T>(0.5)).sum();
+ functor::L2Loss<Device, T>()(context->eigen_device<Device>(),
+ input.flat<T>(), output->scalar<T>());
}
};
@@ -57,4 +57,33 @@ REGISTER_KERNEL(double);
REGISTER_KERNEL(Eigen::half);
#undef REGISTER_KERNEL
+#if GOOGLE_CUDA
+// Forward declarations of the functor specializations for GPU.
+namespace functor {
+#define DECLARE_GPU_SPEC(T) \
+ template <> \
+ void L2Loss<GPUDevice, T>::operator()(const GPUDevice& d, \
+ typename TTypes<T>::ConstTensor input, \
+ typename TTypes<T>::Scalar output); \
+ extern template struct L2Loss<GPUDevice, T>;
+
+DECLARE_GPU_SPEC(float);
+DECLARE_GPU_SPEC(double);
+DECLARE_GPU_SPEC(Eigen::half);
+#undef DECLARE_GPU_SPEC
+} // namespace functor
+
+// Registration of the GPU implementations.
+#define REGISTER_GPU_KERNEL(T) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("L2Loss").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
+ L2LossOp<GPUDevice, T>);
+
+REGISTER_GPU_KERNEL(float);
+REGISTER_GPU_KERNEL(double);
+REGISTER_GPU_KERNEL(Eigen::half);
+#undef REGISTER_GPU_KERNEL
+
+#endif // GOOGLE_CUDA
+
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/l2loss_op.h b/tensorflow/core/kernels/l2loss_op.h
index 4953aa237c..f7204cefdd 100644
--- a/tensorflow/core/kernels/l2loss_op.h
+++ b/tensorflow/core/kernels/l2loss_op.h
@@ -15,19 +15,25 @@ limitations under the License.
#ifndef TENSORFLOW_KERNELS_L2LOSS_OP_H_
#define TENSORFLOW_KERNELS_L2LOSS_OP_H_
+// Functor definition for L2LossOp, must be compilable by nvcc.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
namespace tensorflow {
+namespace functor {
+// Functor used by L2LossOp to do the computations.
template <typename Device, typename T>
-struct L2LossOp : public OpKernel {
- explicit L2LossOp(OpKernelConstruction* context) : OpKernel(context) {}
-
- void Compute(OpKernelContext* context) {}
+struct L2Loss {
+ void operator()(const Device& d, typename TTypes<T>::ConstTensor input,
+ typename TTypes<T>::Scalar output) {
+ // We flatten the input tensor and reduce on dimension 0, producing
+ // a single number which is Mul(Sum(x^2), 0.5).
+ output.device(d) = (input.square() * static_cast<T>(0.5)).sum();
+ }
};
+} // namespace functor
} // namespace tensorflow
#endif // TENSORFLOW_KERNELS_L2LOSS_OP_H_
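
The functor above reduces the flattened input to 0.5 * sum(x_i^2). A minimal standalone sketch of the same arithmetic in plain C++ (for illustration only, independent of the TensorFlow/Eigen headers):

// Reference computation for L2Loss: output = Mul(Sum(x^2), 0.5).
#include <cstdio>
#include <vector>

static float L2LossReference(const std::vector<float>& x) {
  float sum = 0.f;
  for (float v : x) sum += v * v;  // Sum(x^2)
  return 0.5f * sum;               // multiply by 0.5
}

int main() {
  const std::vector<float> x = {1.f, 2.f, 3.f};
  std::printf("%f\n", L2LossReference(x));  // prints 7.000000
  return 0;
}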
diff --git a/tensorflow/core/kernels/l2loss_op_gpu.cu.cc b/tensorflow/core/kernels/l2loss_op_gpu.cu.cc
index 73b6472254..420df37086 100644
--- a/tensorflow/core/kernels/l2loss_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/l2loss_op_gpu.cu.cc
@@ -21,55 +21,12 @@ limitations under the License.
#include "tensorflow/core/framework/register_types.h"
-#include "tensorflow/core/kernels/reduction_ops_common.h"
-#include "tensorflow/core/kernels/reduction_ops_gpu_kernels.h"
-
namespace tensorflow {
typedef Eigen::GpuDevice GPUDevice;
-
-// TODO(eriche): can add specialization for half2
-template <typename T>
-struct squareHalf {
- __host__ __device__ T operator()(const T& x) const {
- return static_cast<T>(0.5) * x * x;
- }
-};
-
-template <typename T>
-class L2LossOp<GPUDevice, T> : public OpKernel {
- public:
- explicit L2LossOp(OpKernelConstruction* context) : OpKernel(context) {}
-
- void Compute(OpKernelContext* context) override {
- // The input tensor can be of any number of dimensions, even though it's
- // 2D in most typical applications.
- const Tensor& input = context->input(0);
- // The output is a single number.
- Tensor* output = nullptr;
- OP_REQUIRES_OK(context,
- context->allocate_output(0, TensorShape({}), &output));
- typedef cub::TransformInputIterator<T, squareHalf<T>, T*> inputIterType;
- inputIterType input_itr((T*)input.flat<T>().data(), squareHalf<T>());
- typedef const Eigen::array<TTypes<float>::Tensor::Index, 1>& ReductionAxes;
-
- Constants<GPUDevice> constants;
- functor::ReduceImpl<T, cub::Sum, T*, inputIterType, ReductionAxes>(
- context, (T*)output->flat<T>().data(), input_itr, 1,
- input.flat<T>().size(), 1, 1, 0, constants.kZero, cub::Sum(), T(0));
- }
-};
-
-// Registration of the GPU implementations.
-#define REGISTER_GPU_KERNEL(T) \
- REGISTER_KERNEL_BUILDER( \
- Name("L2Loss").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
- L2LossOp<GPUDevice, T>);
-
-REGISTER_GPU_KERNEL(float);
-REGISTER_GPU_KERNEL(double);
-REGISTER_GPU_KERNEL(Eigen::half);
-#undef REGISTER_GPU_KERNEL
+template struct functor::L2Loss<GPUDevice, float>;
+template struct functor::L2Loss<GPUDevice, double>;
+template struct functor::L2Loss<GPUDevice, Eigen::half>;
} // namespace tensorflow
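
After the rollback, this GPU translation unit only instantiates the shared functor for each supported type, while l2loss_op.cc declares the matching "extern template" specializations. A single-file sketch of the explicit-instantiation mechanism this relies on (hypothetical names, not the actual TensorFlow build layout):

// Explicit instantiation forces the compiler to emit the template definition
// in this translation unit, so other units can declare it "extern template".
#include <cstdio>

template <typename T>
struct L2LossLike {
  T operator()(const T* x, int n) const {
    T s = T(0);
    for (int i = 0; i < n; ++i) s += x[i] * x[i];
    return T(0.5) * s;
  }
};

template struct L2LossLike<float>;  // explicit instantiation, as in the .cu.cc

int main() {
  const float x[] = {1.f, 2.f};
  std::printf("%f\n", L2LossLike<float>()(x, 2));  // prints 2.500000
  return 0;
}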
diff --git a/tensorflow/core/kernels/reduction_ops.h b/tensorflow/core/kernels/reduction_ops.h
index e43d2828f3..5db9e6032e 100644
--- a/tensorflow/core/kernels/reduction_ops.h
+++ b/tensorflow/core/kernels/reduction_ops.h
@@ -20,7 +20,6 @@ limitations under the License.
#include <iostream>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
namespace tensorflow {
@@ -68,7 +67,7 @@ void FillIdentityEigenImpl(const Device& d, OUT_T out, const Reducer& reducer) {
template <typename Device, typename Reducer>
struct ReduceFunctor {
template <typename OUT_T, typename IN_T, typename ReductionAxes>
- static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in,
+ static void Reduce(const Device& d, OUT_T out, IN_T in,
const ReductionAxes& reduction_axes,
const Reducer& reducer);
diff --git a/tensorflow/core/kernels/reduction_ops_common.h b/tensorflow/core/kernels/reduction_ops_common.h
index 71af9d88dc..553f889523 100644
--- a/tensorflow/core/kernels/reduction_ops_common.h
+++ b/tensorflow/core/kernels/reduction_ops_common.h
@@ -190,24 +190,24 @@ class ReductionOp : public OpKernel {
Functor::FillIdentity(d, tmp_out.flat<T>(), reducer);
} else if ((helper.ndims() == 1) && helper.reduce_first_axis()) {
// Reduce to a scalar.
- Functor::Reduce(ctx, helper.out<T, 0>(&tmp_out), helper.in<T, 1>(data),
+ Functor::Reduce(d, helper.out<T, 0>(&tmp_out), helper.in<T, 1>(data),
constants.kZero, reducer);
} else if ((helper.ndims() == 2) && helper.reduce_first_axis()) {
// Can be viewed as a reduction of a matrix along 1st dimension.
- Functor::Reduce(ctx, helper.out<T, 1>(&tmp_out), helper.in<T, 2>(data),
+ Functor::Reduce(d, helper.out<T, 1>(&tmp_out), helper.in<T, 2>(data),
constants.kZero, reducer);
} else if ((helper.ndims() == 2) && !helper.reduce_first_axis()) {
// Can be viewed as a reduction of a matrix along 2nd dimension.
- Functor::Reduce(ctx, helper.out<T, 1>(&tmp_out), helper.in<T, 2>(data),
+ Functor::Reduce(d, helper.out<T, 1>(&tmp_out), helper.in<T, 2>(data),
constants.kOne, reducer);
} else if ((helper.ndims() == 3) && helper.reduce_first_axis()) {
// Can be viewed as a reduction of a 3D tensor along 1st and 3rd
// dimensions.
- Functor::Reduce(ctx, helper.out<T, 1>(&tmp_out), helper.in<T, 3>(data),
+ Functor::Reduce(d, helper.out<T, 1>(&tmp_out), helper.in<T, 3>(data),
constants.kZeroTwo, reducer);
} else if ((helper.ndims() == 3) && !helper.reduce_first_axis()) {
// Can be viewed as a reduction of a 3D tensor along 2nd dimension.
- Functor::Reduce(ctx, helper.out<T, 2>(&tmp_out), helper.in<T, 3>(data),
+ Functor::Reduce(d, helper.out<T, 2>(&tmp_out), helper.in<T, 3>(data),
constants.kOne, reducer);
} else {
// If we don't hit one of the cases above, transpose the data so that
@@ -223,7 +223,7 @@ class ReductionOp : public OpKernel {
const int64 unreduced = tmp_out.NumElements();
const int64 reduced = shuffled.NumElements() / unreduced;
const Tensor& const_shuffled = shuffled;
- Functor::Reduce(ctx, tmp_out.flat<T>(),
+ Functor::Reduce(d, tmp_out.flat<T>(),
const_shuffled.shaped<T, 2>({unreduced, reduced}),
constants.kOne, reducer);
}
@@ -258,10 +258,9 @@ namespace functor {
template <typename Device, typename Reducer>
struct ReduceFunctorBase {
template <typename OUT_T, typename IN_T, typename ReductionAxes>
- static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in,
+ static void Reduce(const Device& d, OUT_T out, IN_T in,
const ReductionAxes& reduction_axes,
const Reducer& reducer) {
- const Device& d = ctx->eigen_device<Device>();
ReduceEigenImpl(d, out, in, reduction_axes, reducer);
}
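
With the OpKernelContext argument gone, Functor::Reduce takes the device directly and forwards to ReduceEigenImpl, which boils down to Eigen's reduce over the requested axes. A small standalone sketch of that call using Eigen's Tensor module outside TensorFlow (assumed example):

// Sum-reduce a 2x3 tensor along its first dimension with an Eigen reducer,
// mirroring what ReduceEigenImpl does for the CPU and (now again) GPU paths.
#include <unsupported/Eigen/CXX11/Tensor>
#include <iostream>

int main() {
  Eigen::Tensor<float, 2> in(2, 3);
  in.setValues({{1, 2, 3}, {4, 5, 6}});
  const Eigen::array<int, 1> axes = {0};  // reduce the first dimension
  Eigen::Tensor<float, 1> out =
      in.reduce(axes, Eigen::internal::SumReducer<float>());
  std::cout << out << std::endl;  // 5 7 9
  return 0;
}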
diff --git a/tensorflow/core/kernels/reduction_ops_gpu.cu.cc b/tensorflow/core/kernels/reduction_ops_gpu.cu.cc
index cff0e95bc1..ec4490db83 100644
--- a/tensorflow/core/kernels/reduction_ops_gpu.cu.cc
+++ b/tensorflow/core/kernels/reduction_ops_gpu.cu.cc
@@ -17,7 +17,8 @@ limitations under the License.
#define EIGEN_USE_GPU
-#include "tensorflow/core/kernels/reduction_ops_gpu_kernels.h"
+#include "tensorflow/core/framework/numeric_types.h"
+#include "tensorflow/core/kernels/reduction_ops.h"
namespace tensorflow {
namespace functor {
@@ -32,27 +33,15 @@ typedef TTypes<float>::Tensor::Index Index;
template <typename Reducer>
struct ReduceFunctor<GPUDevice, Reducer> {
template <typename OUT_T, typename IN_T, typename ReductionAxes>
- static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in,
+ static void Reduce(const GPUDevice& d, OUT_T out, IN_T in,
const ReductionAxes& reduction_axes,
- const Reducer& reducer);
-};
-
-template <typename T>
-struct ReduceFunctor<GPUDevice, Eigen::internal::SumReducer<T>> {
- template <typename OUT_T, typename IN_T, typename ReductionAxes>
- static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in,
- const ReductionAxes& reduction_axes,
- const Eigen::internal::SumReducer<T>& reducer) {
- ReduceImpl<T, cub::Sum, T*, T*, ReductionAxes>(
- ctx, (T*)out.data(), (T*)in.data(), in.rank(), in.dimension(0),
- in.rank() >= 2 ? in.dimension(1) : 1,
- in.rank() >= 3 ? in.dimension(2) : 1, out.rank(), reduction_axes,
- cub::Sum(), T(0));
+ const Reducer& reducer) {
+ ReduceEigenImpl(d, To32Bit(out), To32Bit(in), reduction_axes, reducer);
}
template <typename OUT_T>
static void FillIdentity(const GPUDevice& d, OUT_T out,
- const Eigen::internal::SumReducer<T>& reducer) {
+ const Reducer& reducer) {
FillIdentityEigenImpl(d, To32Bit(out), reducer);
}
};
@@ -60,30 +49,19 @@ struct ReduceFunctor<GPUDevice, Eigen::internal::SumReducer<T>> {
template <typename T>
struct ReduceFunctor<GPUDevice, Eigen::internal::MeanReducer<T>> {
template <typename OUT_T, typename IN_T, typename ReductionAxes>
- static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in,
+ static void Reduce(const GPUDevice& d, OUT_T out, IN_T in,
const ReductionAxes& reduction_axes,
const Eigen::internal::MeanReducer<T>& reducer) {
- int divisor = 1;
- if (out.rank() == 0)
- divisor = in.size();
- else if (out.rank() == 1 && in.rank() == 2 && reduction_axes[0] == 0)
- divisor = in.dimension(0);
- else if (out.rank() == 1 && in.rank() == 2 && reduction_axes[0] == 1)
- divisor = in.dimension(1);
- else if (out.rank() == 1 && in.rank() == 3 && reduction_axes[0] == 0 &&
- reduction_axes[1] == 2)
- divisor = in.dimension(0) * in.dimension(2);
- else if (out.rank() == 2 && in.rank() == 3 && reduction_axes[0] == 1)
- divisor = in.dimension(1);
-
- DividesBy<T> div_op((T)divisor);
- TransformOutputIterator<T, T, DividesBy<T>> itr((T*)out.data(), div_op);
- ReduceImpl<T, cub::Sum, TransformOutputIterator<T, T, DividesBy<T>>, T*,
- ReductionAxes>(ctx, itr, (T*)in.data(), in.rank(),
- in.dimension(0),
- in.rank() >= 2 ? in.dimension(1) : 1,
- in.rank() >= 3 ? in.dimension(2) : 1, out.rank(),
- reduction_axes, cub::Sum(), T(0));
+ typedef typename IN_T::Index Index;
+ // Eigen sum reductions are much faster on GPU than mean reductions:
+ // Simply trigger them by computing the sum of the weighted inputs.
+ Index num_coeffs_to_reduce = 1;
+ for (int i = 0; i < Eigen::internal::array_size<ReductionAxes>::value;
+ ++i) {
+ num_coeffs_to_reduce *= in.dimension(reduction_axes[i]);
+ }
+ T scale = T(1.0 / num_coeffs_to_reduce);
+ out.device(d) = (in * scale).sum(reduction_axes);
}
template <typename OUT_T>
@@ -93,159 +71,15 @@ struct ReduceFunctor<GPUDevice, Eigen::internal::MeanReducer<T>> {
}
};
-template <>
-struct ReduceFunctor<GPUDevice, Eigen::internal::MeanReducer<Eigen::half>> {
- template <typename OUT_T, typename IN_T, typename ReductionAxes>
- static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in,
- const ReductionAxes& reduction_axes,
- const Eigen::internal::MeanReducer<Eigen::half>& reducer) {
- float divisor = 1.f;
- if (out.rank() == 0)
- divisor = in.size();
- else if (out.rank() == 1 && in.rank() == 2 && reduction_axes[0] == 0)
- divisor = in.dimension(0);
- else if (out.rank() == 1 && in.rank() == 2 && reduction_axes[0] == 1)
- divisor = in.dimension(1);
- else if (out.rank() == 1 && in.rank() == 3 && reduction_axes[0] == 0 &&
- reduction_axes[1] == 2)
- divisor = in.dimension(0) * in.dimension(2);
- else if (out.rank() == 2 && in.rank() == 3 && reduction_axes[0] == 1)
- divisor = in.dimension(1);
- DividesBy<float, Eigen::half> div_op(divisor);
-
- typedef cub::TransformInputIterator<float, HalfToFloat, Eigen::half*>
- inputIterType;
- inputIterType input_itr((Eigen::half*)in.data(), HalfToFloat());
-
- typedef TransformOutputIterator<Eigen::half, float,
- DividesBy<float, Eigen::half>>
- outputIterType;
- outputIterType itr((Eigen::half*)out.data(), div_op);
-
- ReduceImpl<float, cub::Sum, outputIterType, inputIterType, ReductionAxes>(
- ctx, itr, input_itr, in.rank(), in.dimension(0),
- in.rank() >= 2 ? in.dimension(1) : 1,
- in.rank() >= 3 ? in.dimension(2) : 1, out.rank(), reduction_axes,
- cub::Sum(), 0.f);
- }
-
- template <typename OUT_T>
- static void FillIdentity(
- const GPUDevice& d, OUT_T out,
- const Eigen::internal::MeanReducer<Eigen::half>& reducer) {
- FillIdentityEigenImpl(d, To32Bit(out), reducer);
- }
-};
-
-template <typename T>
-struct ReduceFunctor<GPUDevice, Eigen::internal::MaxReducer<T>> {
- template <typename OUT_T, typename IN_T, typename ReductionAxes>
- static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in,
- const ReductionAxes& reduction_axes,
- const Eigen::internal::MaxReducer<T>& reducer) {
- ReduceImpl<T, cub::Max, T*, T*, ReductionAxes>(
- ctx, (T*)out.data(), (T*)in.data(), in.rank(), in.dimension(0),
- in.rank() >= 2 ? in.dimension(1) : 1,
- in.rank() >= 3 ? in.dimension(2) : 1, out.rank(), reduction_axes,
- cub::Max(), std::numeric_limits<T>::min());
- }
-
- template <typename OUT_T>
- static void FillIdentity(const GPUDevice& d, OUT_T out,
- const Eigen::internal::MaxReducer<T>& reducer) {
- FillIdentityEigenImpl(d, To32Bit(out), reducer);
- }
-};
-
-template <typename T>
-struct ReduceFunctor<GPUDevice, Eigen::internal::MinReducer<T>> {
- template <typename OUT_T, typename IN_T, typename ReductionAxes>
- static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in,
- const ReductionAxes& reduction_axes,
- const Eigen::internal::MinReducer<T>& reducer) {
- ReduceImpl<T, cub::Min, T*, T*, ReductionAxes>(
- ctx, (T*)out.data(), (T*)in.data(), in.rank(), in.dimension(0),
- in.rank() >= 2 ? in.dimension(1) : 1,
- in.rank() >= 3 ? in.dimension(2) : 1, out.rank(), reduction_axes,
- cub::Min(), std::numeric_limits<T>::max());
- }
-
- template <typename OUT_T>
- static void FillIdentity(const GPUDevice& d, OUT_T out,
- const Eigen::internal::MinReducer<T>& reducer) {
- FillIdentityEigenImpl(d, To32Bit(out), reducer);
- }
-};
-
-template <typename T>
-struct ReduceFunctor<GPUDevice, Eigen::internal::ProdReducer<T>> {
- template <typename OUT_T, typename IN_T, typename ReductionAxes>
- static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in,
- const ReductionAxes& reduction_axes,
- const Eigen::internal::ProdReducer<T>& reducer) {
- ReduceImpl<T, Prod<T>, T*, T*, ReductionAxes>(
- ctx, (T*)out.data(), (T*)in.data(), in.rank(), in.dimension(0),
- in.rank() >= 2 ? in.dimension(1) : 1,
- in.rank() >= 3 ? in.dimension(2) : 1, out.rank(), reduction_axes,
- Prod<T>(), T(1));
- }
-
- template <typename OUT_T>
- static void FillIdentity(const GPUDevice& d, OUT_T out,
- const Eigen::internal::ProdReducer<T>& reducer) {
- FillIdentityEigenImpl(d, To32Bit(out), reducer);
- }
-};
-
-template <>
-struct ReduceFunctor<GPUDevice, Eigen::internal::AndReducer> {
- template <typename OUT_T, typename IN_T, typename ReductionAxes>
- static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in,
- const ReductionAxes& reduction_axes,
- const Eigen::internal::AndReducer& reducer) {
- ReduceImpl<bool, And, bool*, bool*, ReductionAxes>(
- ctx, (bool*)out.data(), (bool*)in.data(), in.rank(), in.dimension(0),
- in.rank() >= 2 ? in.dimension(1) : 1,
- in.rank() >= 3 ? in.dimension(2) : 1, out.rank(), reduction_axes, And(),
- true);
- }
-
- template <typename OUT_T>
- static void FillIdentity(const GPUDevice& d, OUT_T out,
- const Eigen::internal::AndReducer& reducer) {
- FillIdentityEigenImpl(d, To32Bit(out), reducer);
- }
-};
-
-template <>
-struct ReduceFunctor<GPUDevice, Eigen::internal::OrReducer> {
- template <typename OUT_T, typename IN_T, typename ReductionAxes>
- static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in,
- const ReductionAxes& reduction_axes,
- const Eigen::internal::OrReducer& reducer) {
- ReduceImpl<bool, Or, bool*, bool*, ReductionAxes>(
- ctx, (bool*)out.data(), (bool*)in.data(), in.rank(), in.dimension(0),
- in.rank() >= 2 ? in.dimension(1) : 1,
- in.rank() >= 3 ? in.dimension(2) : 1, out.rank(), reduction_axes, Or(),
- false);
- }
-
- template <typename OUT_T>
- static void FillIdentity(const GPUDevice& d, OUT_T out,
- const Eigen::internal::OrReducer& reducer) {
- FillIdentityEigenImpl(d, To32Bit(out), reducer);
- }
-};
-
// T: the data type
// REDUCER: the reducer functor
// NUM_AXES: the number of axes to reduce
// IN_DIMS: the number of dimensions of the input tensor
-#define DEFINE(T, REDUCER, IN_DIMS, NUM_AXES) \
- template void ReduceFunctor<GPUDevice, REDUCER>::Reduce( \
- OpKernelContext* ctx, TTypes<T, IN_DIMS - NUM_AXES>::Tensor out, \
- TTypes<T, IN_DIMS>::ConstTensor in, \
- const Eigen::array<Index, NUM_AXES>& reduction_axes, \
+#define DEFINE(T, REDUCER, IN_DIMS, NUM_AXES) \
+ template void ReduceFunctor<GPUDevice, REDUCER>::Reduce( \
+ const GPUDevice& d, TTypes<T, IN_DIMS - NUM_AXES>::Tensor out, \
+ TTypes<T, IN_DIMS>::ConstTensor in, \
+ const Eigen::array<Index, NUM_AXES>& reduction_axes, \
const REDUCER& reducer);
#define DEFINE_IDENTITY(T, REDUCER) \
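
The MeanReducer specialization above sidesteps Eigen's mean reduction on GPU by using the identity mean(x) = sum(x * 1/N), where N is the number of coefficients being reduced. A CPU-side sketch of the same identity (assumed illustration, not the GPU code path):

// mean over an axis == sum of the inputs pre-scaled by 1/N over that axis.
#include <unsupported/Eigen/CXX11/Tensor>
#include <iostream>

int main() {
  Eigen::Tensor<float, 2> in(2, 4);
  in.setRandom();
  const Eigen::array<int, 1> axes = {1};
  const float scale = 1.0f / static_cast<float>(in.dimension(1));  // 1/N
  Eigen::Tensor<float, 1> mean_a = in.mean(axes);
  Eigen::Tensor<float, 1> mean_b = (in * scale).sum(axes);
  std::cout << mean_a << "\n" << mean_b << std::endl;  // equal up to rounding
  return 0;
}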
diff --git a/tensorflow/core/kernels/reduction_ops_gpu_kernels.h b/tensorflow/core/kernels/reduction_ops_gpu_kernels.h
deleted file mode 100644
index 45a9fdd6d9..0000000000
--- a/tensorflow/core/kernels/reduction_ops_gpu_kernels.h
+++ /dev/null
@@ -1,697 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#if GOOGLE_CUDA
-
-#define EIGEN_USE_GPU
-
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "external/cub_archive/cub/device/device_reduce.cuh"
-#include "external/cub_archive/cub/device/device_segmented_reduce.cuh"
-#include "external/cub_archive/cub/iterator/counting_input_iterator.cuh"
-#include "external/cub_archive/cub/iterator/transform_input_iterator.cuh"
-#include "external/cub_archive/cub/warp/warp_reduce.cuh"
-#include "cuda/include/cuComplex.h"
-#include "tensorflow/core/framework/numeric_types.h"
-#include "tensorflow/core/framework/tensor_types.h"
-#include "tensorflow/core/kernels/reduction_ops.h"
-#include "tensorflow/core/lib/core/bits.h"
-#include "tensorflow/core/platform/macros.h"
-#include "tensorflow/core/platform/types.h"
-#include "tensorflow/core/util/cuda_kernel_helper.h"
-#include "tensorflow/core/util/permutation_input_iterator.h"
-#include "tensorflow/core/util/transform_output_iterator.h"
-
-#include <sstream>
-
-namespace tensorflow {
-namespace functor {
-
-typedef Eigen::GpuDevice GPUDevice;
-
-template <typename T>
-struct Prod {
- __host__ __device__ T operator()(const T& a, const T& b) const {
- return a * b;
- }
-};
-
-// needed to work around a compiler bug in nvcc - it doesn't seem to like
-// the overloaded multiply op for std::complex
-template <>
-struct Prod<std::complex<float>> {
- __host__ __device__ std::complex<float> operator()(
- const std::complex<float>& a, const std::complex<float>& b) const {
- auto result = cuCmulf(make_cuComplex(a.real(), a.imag()),
- make_cuComplex(b.real(), b.imag()));
- return std::complex<float>(result.x, result.y);
- }
-};
-
-template <>
-struct Prod<std::complex<double>> {
- __host__ __device__ std::complex<double> operator()(
- const std::complex<double>& a, const std::complex<double>& b) const {
- auto result = cuCmul(make_cuDoubleComplex(a.real(), a.imag()),
- make_cuDoubleComplex(b.real(), b.imag()));
- return std::complex<double>(result.x, result.y);
- }
-};
-
-template <typename T, typename outT = T>
-struct DividesBy {
- T divisor;
-
- __host__ __device__ explicit DividesBy(T divisor) : divisor(divisor) {}
-
- __host__ __device__ outT operator()(const T& x) const { return x / divisor; }
-};
-
-// needed to work around a compiler bug in nvcc - it doesn't seem to like
-// the overloaded ops for std::complex
-template <>
-struct DividesBy<std::complex<float>> {
- cuFloatComplex divisor;
-
- __host__ __device__ explicit DividesBy(std::complex<float> divisor)
- : divisor(make_cuComplex(divisor.real(), divisor.imag())) {}
-
- // implements
- __host__ __device__ std::complex<float> operator()(
- const std::complex<float>& x) const {
- auto result = cuCdivf(make_cuComplex(x.real(), x.imag()), divisor);
- return std::complex<float>(result.x, result.y);
- }
-};
-
-template <>
-struct DividesBy<std::complex<double>> {
- cuDoubleComplex divisor;
-
- __host__ __device__ explicit DividesBy(std::complex<double> divisor)
- : divisor(make_cuDoubleComplex(divisor.real(), divisor.imag())) {}
-
- // implements
- __host__ __device__ std::complex<double> operator()(
- const std::complex<double>& x) const {
- auto result = cuCdiv(make_cuDoubleComplex(x.real(), x.imag()), divisor);
- return std::complex<double>(result.x, result.y);
- }
-};
-
-template <>
-struct DividesBy<float, Eigen::half> {
- float divisor;
-
- __host__ __device__ explicit DividesBy(float divisor) : divisor(divisor) {}
-
- __host__ __device__ Eigen::half operator()(const float& x) const {
- return Eigen::half(x / divisor);
- }
-};
-
-struct HalfToFloat {
- __host__ __device__ float operator()(const Eigen::half& x) const {
- return Eigen::half_impl::half_to_float(x);
- }
-};
-
-struct FloatToHalf {
- __host__ __device__ Eigen::half operator()(const float& x) const {
- return Eigen::half_impl::float_to_half_rtne(x);
- }
-};
-
-struct And {
- __host__ __device__ bool operator()(const bool& a, const bool& b) const {
- return a && b;
- }
-};
-
-struct Or {
- __host__ __device__ bool operator()(const bool& a, const bool& b) const {
- return a || b;
- }
-};
-
-// each block does a grid strided loop and reduces its values locally
-// the case of one block is used for low latency small reductions to scalars
-template <typename T, typename outT, int num_threads, typename Op>
-__global__ void BlockReduceKernel(T in, outT out, int num_elems, Op op) {
- const int bid = blockIdx.x;
- const int tid = threadIdx.x;
-
- const int gid = bid * blockDim.x + tid;
- const int stride = blockDim.x * gridDim.x;
-
- typedef typename std::iterator_traits<T>::value_type value_type;
-
- value_type sum;
- if (gid < num_elems) {
- sum = in[gid];
- for (int pos = gid + stride; pos < num_elems; pos += stride) {
- sum = op(sum, in[pos]);
- }
- } else
- sum = value_type(); // stop compiler from complaining
-
- typedef cub::BlockReduce<value_type, num_threads> BlockReduce;
-
- __shared__ typename BlockReduce::TempStorage temp_storage;
-
- __syncthreads();
-
- sum = BlockReduce(temp_storage)
- .template Reduce(sum, op, min(num_elems, num_threads));
-
- if (tid == 0) out[bid] = sum;
-}
-
-// maps a warp to each row
-template <typename T, typename outT, typename Op>
-__global__ void RowReduceKernel(T in, outT out, int num_rows, int num_cols,
- Op op) {
- typedef typename std::iterator_traits<T>::value_type value_type;
- const int row = (blockIdx.x * blockDim.x + threadIdx.x) / 32;
- const int lane = threadIdx.x % 32;
-
- if (num_cols == 1) {
- int gid = threadIdx.x + blockIdx.x * blockDim.x;
- if (gid < num_rows) out[gid] = in[gid];
- return;
- }
-
- value_type sum;
- int col = lane;
- if (row < num_rows && col < num_cols) {
- sum = in[row * num_cols + col];
- col += 32;
- for (; col < num_cols; col += 32) {
- sum = op(sum, in[row * num_cols + col]);
- }
- } else {
- sum = value_type(); // stop compiler from complaining
- }
-
- typedef cub::WarpReduce<value_type> WarpReduce;
-
- __shared__ typename WarpReduce::TempStorage temp_storage;
-
- __syncthreads();
-
- sum = WarpReduce(temp_storage).template Reduce(sum, op, min(num_cols, 32));
-
- if (row < num_rows && lane == 0) out[row] = sum;
-}
-
-// Works only if there are <= 16 columns
-// each warps sums over multiple rows at once
-template <typename T, typename outT, typename Op>
-__global__ void ColumnReduceMax16ColumnsKernel(T in, outT out, int num_rows,
- int num_cols, Op op) {
- typedef typename std::iterator_traits<T>::value_type value_type;
- int rows_per_warp = 32 / num_cols;
-
- int lane = threadIdx.x % 32;
- int lane_row = lane / num_cols;
-
- const int start_row_warp =
- rows_per_warp * (blockIdx.y * blockDim.y + threadIdx.y);
- const int start_row_lane = start_row_warp + lane_row;
- int row = start_row_lane;
- int col = lane % num_cols;
-
- value_type sum;
- if (row * num_cols + col < num_rows * num_cols)
- sum = in[row * num_cols + col];
- else
- sum = value_type(); // needed to shut up compiler
-
- __shared__ value_type partial_sums[32][33];
-
- __syncthreads();
-
- row += rows_per_warp * gridDim.y * blockDim.y;
- for (; row < num_rows; row += rows_per_warp * gridDim.y * blockDim.y) {
- int global_pos = row * num_cols + col;
- if (global_pos < (num_rows * num_cols))
- sum = op(sum, in[row * num_cols + col]);
- }
-
- const int rows_in_this_warp = min(rows_per_warp, num_rows - start_row_warp);
- // not the most efficient way to do this sum
- for (int i = 1; i < rows_in_this_warp; ++i) {
- value_type tmp =
- cub::ShuffleIndex(sum, threadIdx.x + i * num_cols, 32, 0xffffffff);
- if (lane < num_cols) sum = op(sum, tmp);
- }
-
- if (lane < num_cols) partial_sums[lane][threadIdx.y] = sum;
-
- __syncthreads();
-
- if (threadIdx.y == 0 && threadIdx.x < num_cols) {
- value_type s = partial_sums[threadIdx.x][0];
-
- if (blockDim.y > 1) {
- for (int row = 1; row < blockDim.y; ++row) {
- s = op(s, partial_sums[threadIdx.x][row]);
- }
- }
-
- out[col * gridDim.y + blockIdx.y] = s;
- }
-}
-
-// Maps each block to a column range 32 wide
-template <typename T, typename outT, typename Op>
-__global__ void ColumnReduceKernel(T in, outT out, int num_rows, int num_cols,
- Op op) {
- typedef typename std::iterator_traits<T>::value_type value_type;
- int row = blockIdx.y * blockDim.y + threadIdx.y;
- int col = blockIdx.x * 32 + threadIdx.x;
-
- value_type sum;
- if (row * num_cols + col < num_rows * num_cols)
- sum = in[row * num_cols + col];
- else
- sum = value_type(); // will never be used, needed to shut up compiler
-
- __shared__ value_type partial_sums[32][33];
-
- __syncthreads();
-
- row += gridDim.y * blockDim.y;
-
- if (col < num_cols) {
- for (; row < num_rows; row += gridDim.y * blockDim.y) {
- sum = op(sum, in[row * num_cols + col]);
- }
- }
-
- partial_sums[threadIdx.x][threadIdx.y] = sum;
-
- __syncthreads();
-
- if (threadIdx.y == 0 && threadIdx.x < 32) {
- value_type s = partial_sums[threadIdx.x][0];
-
- for (int row = 1; row < blockDim.y; ++row) {
- s = op(s, partial_sums[threadIdx.x][row]);
- }
-
- out[col * gridDim.y + blockIdx.y] = s;
- }
-}
-
-// does multiple warp size segmented reductions in parallel
-// segments cannot cross warp boundaries (mainly used for reducing the segments
-// that come from the Max16Columns column reduction kernel)
-template <typename T, typename outT, typename Op>
-__global__ void CleanupSegments(T partial_sums, outT out, int num_rows,
- int num_cols, int segment_size, Op op) {
- typedef typename std::iterator_traits<T>::value_type value_type;
- const int tid = threadIdx.x + blockIdx.x * blockDim.x;
-
- value_type val;
- if (tid < segment_size * num_cols)
- val = partial_sums[tid];
- else
- val = value_type(); // 0s beyond last segment won't be used, so OK
-
- typedef cub::WarpReduce<value_type> WarpReduce;
-
- __shared__ typename WarpReduce::TempStorage temp_storage;
-
- __syncthreads();
-
- bool head_flag = (threadIdx.x % segment_size) == 0;
- value_type sum =
- WarpReduce(temp_storage).HeadSegmentedReduce(val, head_flag, op);
-
- if (head_flag && tid < segment_size * num_cols) {
- out[tid / segment_size] = sum;
- }
-}
-
-// assigns one thread to a column
-template <typename T, typename outT, typename Op>
-__global__ void ColumnReduceSimpleKernel(T in, outT out, int num_planes,
- int num_rows, int num_cols, Op op) {
- typedef typename std::iterator_traits<T>::value_type value_type;
- const int gid = threadIdx.x + blockIdx.x * blockDim.x;
- const int elems_per_plane = num_rows * num_cols;
-
- int plane = gid / num_cols;
- int col = gid % num_cols;
-
- if (plane >= num_planes) return;
-
- if (num_rows == 1) {
- out[plane * elems_per_plane + col] = in[plane * elems_per_plane + col];
- return;
- }
-
- value_type sum = op(in[plane * elems_per_plane + col],
- in[plane * elems_per_plane + num_cols + col]);
- for (int row = 2; row < num_rows; ++row) {
- sum = op(sum, in[plane * elems_per_plane + row * num_cols + col]);
- }
-
- out[plane * num_cols + col] = sum;
-}
-
-struct RowOffset {
- __host__ __device__ explicit RowOffset(const int& cols) : cols_(cols) {}
-
- __host__ __device__ int operator()(const int& x) const { return cols_ * x; }
-
- int cols_;
-};
-
-struct GatherOp {
- __host__ __device__ GatherOp(const int& extent_x, const int& extent_y,
- const int& extent_z, bool kOne)
- : extent_x_(extent_x),
- extent_y_(extent_y),
- extent_z_(extent_z),
- kOne_(kOne) {
- if (kOne_)
- group_size_ = extent_y_;
- else
- group_size_ = extent_x_ * extent_z_;
- }
-
- __host__ __device__ int operator()(const int& ind) const {
- const int group = kOne_ ? ind / group_size_ : ind % group_size_;
- const int offset = kOne_ ? ind % group_size_ : ind / group_size_;
-
- const int x = group / extent_z_;
- const int z = group % extent_z_;
-
- return x * extent_y_ * extent_z_ + z + offset * extent_z_;
- }
-
- int extent_x_;
- int extent_y_;
- int extent_z_;
- bool kOne_;
- int group_size_;
-};
-
-template <typename T, typename Op, typename OUT_T, typename IN_T>
-void LaunchScalarReduction(OpKernelContext* ctx, OUT_T out, IN_T in,
- int in_size, Op op, T init,
- const cudaStream_t& cu_stream) {
- // handle situations where low latency is important better than CUB
- if (in_size <= 4096) {
- const int num_blocks = 1;
- const int num_threads = 256;
- BlockReduceKernel<IN_T, OUT_T, num_threads>
- <<<num_blocks, num_threads, 0, cu_stream>>>(in, out, in_size, op);
- return;
- } else if (in_size <= 1 << 19) {
- const int num_threads = 256;
- const int num_blocks = 32; // it seems like tailoring this to the GPU
- // would be more effective, but all attempts
- // at making this a multiple of the number of
- // multiprocessors have lead to lower perf
- // in general
- // TODO(eriche) investigate this more
-
- Tensor temp_storage;
- OP_REQUIRES_OK(
- ctx,
- ctx->allocate_temp(
- DT_INT8, TensorShape({static_cast<int64>(num_blocks * sizeof(T))}),
- &temp_storage));
-
- BlockReduceKernel<IN_T, T*, num_threads>
- <<<num_blocks, num_threads, 0, cu_stream>>>(
- in, (T*)temp_storage.flat<int8_t>().data(), in_size, op);
-
- CleanupSegments<<<1, num_blocks, 0, cu_stream>>>(
- (T*)temp_storage.flat<int8_t>().data(), out, 1, 1, num_blocks, op);
- return;
- }
- std::size_t temp_storage_bytes = 0;
-
- Tensor temp_storage;
- // written as a loop because it reduces clutter
- // first pass allocates memory, second launches kernel(s)
- for (int i = 0; i < 2; ++i) {
- auto success = cub::DeviceReduce::Reduce(
- i == 0 ? nullptr : temp_storage.flat<int8_t>().data(),
- temp_storage_bytes, in, out, in_size, op, init, cu_stream);
-
- OP_REQUIRES(
- ctx, success == 0,
- errors::Internal("CUB reduce error", cudaGetErrorString(success)));
-
- if (i == 0)
- OP_REQUIRES_OK(
- ctx,
- ctx->allocate_temp(
- DT_INT8, TensorShape({static_cast<int64>(temp_storage_bytes)}),
- &temp_storage));
- }
-}
-
-template <typename T, typename Op, typename OUT_T, typename IN_T>
-void LaunchRowReduction(OpKernelContext* ctx, OUT_T out, IN_T in, int num_rows,
- int num_cols, Op op, T init,
- const cudaStream_t& cu_stream) {
- if (num_cols < 1024) {
- const int threads_per_block = 128;
- const int warps_per_block = threads_per_block / 32;
- int num_blocks = (num_rows + warps_per_block - 1) / warps_per_block;
-
- RowReduceKernel<<<num_blocks, threads_per_block, 0, cu_stream>>>(
- in, out, num_rows, num_cols, op);
- return;
- }
-
- // setup segment offsets with counting and transform iterator
- RowOffset row_offset_op(num_cols);
- cub::CountingInputIterator<int> counting_iter(0);
- cub::TransformInputIterator<int, RowOffset, cub::CountingInputIterator<int>>
- transform_iter(counting_iter, row_offset_op);
-
- std::size_t temp_storage_bytes = 0;
- Tensor temp_storage;
- for (int i = 0; i < 2; ++i) {
- auto success = cub::DeviceSegmentedReduce::Reduce(
- i == 0 ? nullptr : temp_storage.flat<int8_t>().data(),
- temp_storage_bytes, in, out, num_rows, transform_iter,
- transform_iter + 1, op, init, cu_stream);
-
- OP_REQUIRES(ctx, success == 0,
- errors::Internal("CUB segmented reduce error",
- cudaGetErrorString(success)));
-
- if (i == 0)
- OP_REQUIRES_OK(
- ctx,
- ctx->allocate_temp(
- DT_INT8, TensorShape({static_cast<int64>(temp_storage_bytes)}),
- &temp_storage));
- }
-}
-
-template <typename T, typename Op, typename OUT_T, typename IN_T>
-void LaunchColumnReduction_LTE16Cols(OpKernelContext* ctx, OUT_T out, IN_T in,
- int extent_x, int extent_y, Op op, T init,
- const cudaStream_t& cu_stream) {
- int rows_per_warp = 32 / extent_y;
- dim3 block_dim(32, min(Eigen::divup(extent_x, rows_per_warp), 32), 1);
- dim3 grid_dim(1,
- Eigen::divup(static_cast<unsigned int>(extent_x),
- rows_per_warp * block_dim.y),
- 1);
-
- grid_dim.y = min((int)grid_dim.y, 32);
-
- if (grid_dim.y > 2 && grid_dim.y < 32) {
- int log2 = Log2Floor(grid_dim.y);
- grid_dim.y = 1 << log2;
- }
-
- if (grid_dim.y == 1) {
- ColumnReduceMax16ColumnsKernel<<<grid_dim, block_dim, 0, cu_stream>>>(
- in, out, extent_x, extent_y, op);
- } else {
- Tensor temp_storage;
- OP_REQUIRES_OK(ctx,
- ctx->allocate_temp(DT_INT8,
- TensorShape({static_cast<int64>(
- sizeof(T) * extent_y * grid_dim.y)}),
- &temp_storage));
- ColumnReduceMax16ColumnsKernel<<<grid_dim, block_dim, 0, cu_stream>>>(
- in, (T*)temp_storage.flat<int8_t>().data(), extent_x, extent_y, op);
-
- dim3 new_grid_dim((grid_dim.y * extent_y + 31) / 32, 1, 1);
- dim3 num_threads(128, 1, 1);
- CleanupSegments<<<new_grid_dim, block_dim, 0, cu_stream>>>(
- (T*)temp_storage.flat<int8_t>().data(), out, extent_x, extent_y,
- grid_dim.y, op);
- }
-}
-
-template <typename T, typename Op, typename OUT_T, typename IN_T>
-void LaunchColumnReduction_LTE4096Cols(OpKernelContext* ctx, OUT_T out, IN_T in,
- int extent_x, int extent_y, Op op,
- T init, const cudaStream_t& cu_stream) {
- dim3 block_dim(32, min(extent_x, 32), 1);
- dim3 grid_dim((extent_y + 31) / 32, 1, 1);
-
- if (grid_dim.x < 16) grid_dim.y = min((extent_x + 31) / 32, 32);
-
- if (grid_dim.y > 2 && grid_dim.y < 32) {
- int log2 = Log2Floor(grid_dim.y);
- grid_dim.y = 1 << log2;
- }
-
- if (grid_dim.y == 1) {
- ColumnReduceKernel<<<grid_dim, block_dim, 0, cu_stream>>>(in, out, extent_x,
- extent_y, op);
- } else {
- Tensor temp_storage;
- OP_REQUIRES_OK(ctx,
- ctx->allocate_temp(DT_INT8,
- TensorShape({static_cast<int64>(
- sizeof(T) * extent_y * grid_dim.y)}),
- &temp_storage));
-
- ColumnReduceKernel<<<grid_dim, block_dim, 0, cu_stream>>>(
- in, (T*)temp_storage.flat<int8_t>().data(), extent_x, extent_y, op);
-
- dim3 new_grid_dim((grid_dim.y * extent_y + 31) / 32, 1, 1);
- dim3 num_threads(128, 1, 1);
- CleanupSegments<<<new_grid_dim, block_dim, 0, cu_stream>>>(
- (T*)temp_storage.flat<int8_t>().data(), out, extent_x, extent_y,
- grid_dim.y, op);
- }
-}
-
-template <typename T, typename Op, typename OUT_T, typename IN_T>
-void LaunchColumnReduction(OpKernelContext* ctx, OUT_T out, IN_T in,
- int extent_x, int extent_y, Op op, T init,
- const cudaStream_t& cu_stream) {
- if (extent_y <= 16) {
- LaunchColumnReduction_LTE16Cols(ctx, out, in, extent_x, extent_y, op, init,
- cu_stream);
- } else if (extent_y <= 4096) {
- LaunchColumnReduction_LTE4096Cols(ctx, out, in, extent_x, extent_y, op,
- init, cu_stream);
- } else {
- int threads_per_block = 128;
- int num_blocks = Eigen::divup(extent_y, threads_per_block);
-
- ColumnReduceSimpleKernel<<<num_blocks, threads_per_block, 0, cu_stream>>>(
- in, out, 1, extent_x, extent_y, op);
- }
-}
-
-template <typename T, typename Op, typename OUT_T, typename IN_T>
-void Launch3DYReduction(OpKernelContext* ctx, OUT_T out, IN_T in, int extent_x,
- int extent_y, int extent_z, Op op, T init,
- const cudaStream_t& cu_stream) {
- int threads_per_block = 128;
- int num_blocks =
- (extent_x * extent_z + threads_per_block - 1) / threads_per_block;
-
- // TODO (eriche): this won't be very good in the case of small x
- // small z and large y.
- ColumnReduceSimpleKernel<<<num_blocks, threads_per_block, 0, cu_stream>>>(
- in, out, extent_x, extent_y, extent_z, op);
-}
-
-template <typename T, typename Op, typename OUT_T, typename IN_T>
-void Launch3DXZReduction(OpKernelContext* ctx, OUT_T out, IN_T in, int extent_x,
- int extent_y, int extent_z, Op op, T init,
- const cudaStream_t& cu_stream) {
- // setup segment offsets with counting and transform iterator
- RowOffset row_offset_op(extent_x * extent_z);
- cub::CountingInputIterator<int> counting_iter(0);
- cub::TransformInputIterator<int, RowOffset, cub::CountingInputIterator<int>>
- transform_iter(counting_iter, row_offset_op);
-
- GatherOp gather_op(extent_x, extent_y, extent_z, false);
- typedef cub::TransformInputIterator<int, GatherOp,
- cub::CountingInputIterator<int>>
- gatherIterType;
- gatherIterType gather_iter(counting_iter, gather_op);
-
- PermutationInputIterator<T, IN_T, gatherIterType> permute_iter(in,
- gather_iter);
-
- std::size_t temp_storage_bytes = 0;
- Tensor temp_storage;
-
- for (int i = 0; i < 2; ++i) {
- auto success = cub::DeviceSegmentedReduce::Reduce(
- i == 0 ? nullptr : temp_storage.flat<int8_t>().data(),
- temp_storage_bytes, permute_iter, out, extent_y, transform_iter,
- transform_iter + 1, op, init, cu_stream);
-
- OP_REQUIRES(ctx, success == 0,
- errors::Internal("CUB segmented reduce error",
- cudaGetErrorString(success)));
-
- if (i == 0)
- OP_REQUIRES_OK(
- ctx,
- ctx->allocate_temp(
- DT_INT8, TensorShape({static_cast<int64>(temp_storage_bytes)}),
- &temp_storage));
- }
-}
-
-template <typename T, typename Op, typename OUT_T, typename IN_T,
- typename ReductionAxes>
-void ReduceImpl(OpKernelContext* ctx, OUT_T out, IN_T in, int in_rank,
- int in_dim0, int in_dim1, int in_dim2, int out_rank,
- const ReductionAxes& reduction_axes, Op op, T init) {
- const cudaStream_t& cu_stream = GetCudaStream(ctx);
- if (out_rank == 0) {
- const int in_size = in_dim0 * in_dim1 * in_dim2;
- LaunchScalarReduction(ctx, out, in, in_size, op, init, cu_stream);
- } else if (in_rank == 2 && out_rank == 1 &&
- reduction_axes[0] == 1) { // row reduction
- LaunchRowReduction(ctx, out, in, in_dim0, in_dim1, op, init, cu_stream);
- } else if (in_rank == 2 && out_rank == 1 &&
- reduction_axes[0] == 0) { // column reduction
- LaunchColumnReduction(ctx, out, in, in_dim0, in_dim1, op, init, cu_stream);
- } else if (in_rank == 3 && out_rank == 2 && reduction_axes[0] == 1) {
- Launch3DYReduction(ctx, out, in, in_dim0, in_dim1, in_dim2, op, init,
- cu_stream);
- } else if (in_rank == 3 && out_rank == 1 && reduction_axes[0] == 0 &&
- reduction_axes[1] == 2) {
- Launch3DXZReduction(ctx, out, in, in_dim0, in_dim1, in_dim2, op, init,
- cu_stream);
- } else {
- std::stringstream ss;
- ss << "Invalid reduction requested: in_rank, out_rank, axes " << in_rank
- << " " << out_rank;
- if (out_rank == 1) ss << " " << reduction_axes[0];
- if (out_rank == 2) ss << " " << reduction_axes[1];
- LOG(FATAL) << ss.str();
- }
-}
-
-} // namespace functor
-} // namespace tensorflow
-
-#endif
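
For context, the ReduceImpl in the deleted header picked a CUB-based launch strategy from the input rank, output rank, and reduction axes. An illustrative restatement of that dispatch (not the removed code itself):

// Which reduction kernel the removed ReduceImpl would have launched.
#include <cstdio>

const char* PickReduction(int in_rank, int out_rank, int axis0, int axis1) {
  if (out_rank == 0) return "scalar reduction";
  if (in_rank == 2 && out_rank == 1 && axis0 == 1) return "row reduction";
  if (in_rank == 2 && out_rank == 1 && axis0 == 0) return "column reduction";
  if (in_rank == 3 && out_rank == 2 && axis0 == 1) return "3D Y reduction";
  if (in_rank == 3 && out_rank == 1 && axis0 == 0 && axis1 == 2)
    return "3D XZ reduction";
  return "invalid reduction requested";
}

int main() {
  std::printf("%s\n", PickReduction(2, 1, 0, /*axis1=*/-1));  // column reduction
  return 0;
}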
diff --git a/tensorflow/core/kernels/reduction_ops_test.cc b/tensorflow/core/kernels/reduction_ops_test.cc
index 9bbe993a2f..9cdebdd4f2 100644
--- a/tensorflow/core/kernels/reduction_ops_test.cc
+++ b/tensorflow/core/kernels/reduction_ops_test.cc
@@ -15,7 +15,6 @@ limitations under the License.
#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
#include "tensorflow/core/framework/tensor.h"
-#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
@@ -23,59 +22,14 @@ namespace tensorflow {
// Creates a Graph which "reduce"s a 3D float tensor of "num" elements
// into a scalar.
-template <typename T>
-static Graph* ToScalar(const string& reduce, int num_x, int num_y) {
- auto* g = new Graph(OpRegistry::Global());
- Tensor data(DataTypeToEnum<T>::value, TensorShape({num_x, num_y}));
- data.flat<T>().setRandom();
- Tensor axes(DT_INT32, TensorShape({2}));
- axes.flat<int32>()(0) = 0;
- axes.flat<int32>()(1) = 1;
- test::graph::Reduce(g, reduce, test::graph::Constant(g, data),
- test::graph::Constant(g, axes));
- return g;
-}
-
-static Graph* ColReduce(const string& reduce, int num_x, int num_y) {
- auto* g = new Graph(OpRegistry::Global());
- Tensor data(DT_FLOAT, TensorShape({num_x, num_y}));
+static Graph* ToScalar(const string& reduce, int num) {
+ Graph* g = new Graph(OpRegistry::Global());
+ Tensor data(DT_FLOAT, TensorShape({64, 64, num / (64 * 64)}));
data.flat<float>().setRandom();
- Tensor axes(DT_INT32, TensorShape({1}));
+ Tensor axes(DT_INT32, TensorShape({3}));
axes.flat<int32>()(0) = 0;
- test::graph::Reduce(g, reduce, test::graph::Constant(g, data),
- test::graph::Constant(g, axes));
- return g;
-}
-
-static Graph* RowReduce(const string& reduce, int num_x, int num_y) {
- auto* g = new Graph(OpRegistry::Global());
- Tensor data(DT_FLOAT, TensorShape({num_x, num_y}));
- data.flat<float>().setRandom();
- Tensor axes(DT_INT32, TensorShape({1}));
- axes.flat<int32>()(0) = 1;
- test::graph::Reduce(g, reduce, test::graph::Constant(g, data),
- test::graph::Constant(g, axes));
- return g;
-}
-
-static Graph* ThreeDYReduce(const string& reduce, int num_y, int num_z) {
- auto* g = new Graph(OpRegistry::Global());
- Tensor data(DT_FLOAT, TensorShape({4, num_y, num_z}));
- data.flat<float>().setRandom();
- Tensor axes(DT_INT32, TensorShape({1}));
- axes.flat<int32>()(0) = 1;
- test::graph::Reduce(g, reduce, test::graph::Constant(g, data),
- test::graph::Constant(g, axes));
- return g;
-}
-
-static Graph* ThreeDXZReduce(const string& reduce, int num_y, int num_z) {
- auto* g = new Graph(OpRegistry::Global());
- Tensor data(DT_FLOAT, TensorShape({4, num_y, num_z}));
- data.flat<float>().setRandom();
- Tensor axes(DT_INT32, TensorShape({2}));
- axes.flat<int32>()(0) = 0;
- axes.flat<int32>()(1) = 2;
+ axes.flat<int32>()(1) = 1;
+ axes.flat<int32>()(2) = 2;
test::graph::Reduce(g, reduce, test::graph::Constant(g, data),
test::graph::Constant(g, axes));
return g;
@@ -83,100 +37,51 @@ static Graph* ThreeDXZReduce(const string& reduce, int num_y, int num_z) {
// Creates a bench which reduces a 3D tensor with total "num" floats
// into a scalar on a "device". Runs the bench for "iters" times.
-template <typename T>
static void ReduceToScalar(int iters, const string& device,
- const string& reduce, int num_x, int num_y) {
- testing::ItemsProcessed(static_cast<int64>(iters) * num_x * num_y);
- testing::BytesProcessed(static_cast<int64>(iters) * num_x * num_y *
- sizeof(T));
- test::Benchmark(device, ToScalar<T>(reduce, num_x, num_y)).Run(iters);
-}
-
-static void DoRowReduce(int iters, const string& device, const string& reduce,
- int num_x, int num_y) {
- testing::ItemsProcessed(static_cast<int64>(iters) * num_x * num_y);
- testing::BytesProcessed(static_cast<int64>(iters) * num_x * num_y *
- sizeof(float));
- test::Benchmark(device, RowReduce(reduce, num_x, num_y)).Run(iters);
-}
-
-static void DoColReduce(int iters, const string& device, const string& reduce,
- int num_x, int num_y) {
- testing::ItemsProcessed(static_cast<int64>(iters) * num_x * num_y);
- testing::BytesProcessed(static_cast<int64>(iters) * num_x * num_y *
- sizeof(float));
- test::Benchmark(device, ColReduce(reduce, num_x, num_y)).Run(iters);
-}
-
-static void Do3DYReduce(int iters, const string& device, const string& reduce,
- int num_x, int num_y) {
- testing::ItemsProcessed(static_cast<int64>(iters) * num_x * num_y);
- testing::BytesProcessed(static_cast<int64>(iters) * num_x * num_y *
- sizeof(float));
- test::Benchmark(device, ThreeDYReduce(reduce, num_x, num_y)).Run(iters);
-}
-
-static void Do3DXZReduce(int iters, const string& device, const string& reduce,
- int num_x, int num_y) {
- testing::ItemsProcessed(static_cast<int64>(iters) * num_x * num_y);
- testing::BytesProcessed(static_cast<int64>(iters) * num_x * num_y *
- sizeof(float));
- test::Benchmark(device, ThreeDXZReduce(reduce, num_x, num_y)).Run(iters);
-}
-
-static void BM_Sum2DToScalarGPU(int iters, int num_x, int num_y) {
- ReduceToScalar<float>(iters, "gpu", "Sum", num_x, num_y);
-}
-BENCHMARK(BM_Sum2DToScalarGPU)->RangePair(1, 8192, 1, 8192);
-
-static void BM_Sum2DToScalarGPUComplex(int iters, int num_x, int num_y) {
- ReduceToScalar<std::complex<float>>(iters, "gpu", "Sum", num_x, num_y);
-}
-BENCHMARK(BM_Sum2DToScalarGPUComplex)->RangePair(1, 8192, 1, 8192);
-
-static void BM_Sum2DToScalarGPUHalf(int iters, int num_x, int num_y) {
- ReduceToScalar<Eigen::half>(iters, "gpu", "Sum", num_x, num_y);
+ const string& reduce, int num) {
+ testing::ItemsProcessed(static_cast<int64>(iters) * num);
+ testing::BytesProcessed(static_cast<int64>(iters) * num * sizeof(float));
+ test::Benchmark(device, ToScalar(reduce, num)).Run(iters);
}
-BENCHMARK(BM_Sum2DToScalarGPUHalf)->RangePair(1, 8192, 1, 8192);
-static void BM_Sum2DRowReduceGPU(int iters, int num_x, int num_y) {
- DoRowReduce(iters, "gpu", "Sum", num_x, num_y);
+static void BM_Sum3DToScalarCPU(int iters, int num) {
+ ReduceToScalar(iters, "cpu", "Sum", num);
}
-BENCHMARK(BM_Sum2DRowReduceGPU)->RangePair(1, 8192, 1, 8192);
+BENCHMARK(BM_Sum3DToScalarCPU)->Range(1 << 13, 1 << 20);
-static void BM_Sum2DColumnReduceGPU(int iters, int num_x, int num_y) {
- DoColReduce(iters, "gpu", "Sum", num_x, num_y);
+static void BM_Max3DToScalarCPU(int iters, int num) {
+ ReduceToScalar(iters, "cpu", "Max", num);
}
-BENCHMARK(BM_Sum2DColumnReduceGPU)->RangePair(1, 8192, 1, 8192);
+BENCHMARK(BM_Max3DToScalarCPU)->Range(1 << 13, 1 << 20);
-static void BM_Sum3DYReduceGPU(int iters, int num_x, int num_y) {
- Do3DYReduce(iters, "gpu", "Sum", num_x, num_y);
+static void BM_Prod3DToScalarCPU(int iters, int num) {
+ ReduceToScalar(iters, "cpu", "Prod", num);
}
-BENCHMARK(BM_Sum3DYReduceGPU)->RangePair(64, 4096, 64, 4096);
+BENCHMARK(BM_Prod3DToScalarCPU)->Range(1 << 13, 1 << 20);
-static void BM_Sum3DXZReduceGPU(int iters, int num_x, int num_y) {
- Do3DXZReduce(iters, "gpu", "Sum", num_x, num_y);
+static void BM_Mean3DToScalarCPU(int iters, int num) {
+ ReduceToScalar(iters, "cpu", "Mean", num);
}
-BENCHMARK(BM_Sum3DXZReduceGPU)->RangePair(64, 4096, 64, 4096);
+BENCHMARK(BM_Mean3DToScalarCPU)->Range(1 << 13, 1 << 20);
-static void BM_Mean2DToScalarGPU(int iters, int num_x, int num_y) {
- ReduceToScalar<float>(iters, "gpu", "Mean", num_x, num_y);
+static void BM_Sum3DToScalarGPU(int iters, int num) {
+ ReduceToScalar(iters, "gpu", "Sum", num);
}
-BENCHMARK(BM_Mean2DToScalarGPU)->RangePair(2048, 8192, 2048, 8192);
+BENCHMARK(BM_Sum3DToScalarGPU)->Range(1 << 13, 1 << 20);
-static void BM_Max2DToScalarGPU(int iters, int num_x, int num_y) {
- ReduceToScalar<float>(iters, "gpu", "Max", num_x, num_y);
+static void BM_Max3DToScalarGPU(int iters, int num) {
+ ReduceToScalar(iters, "gpu", "Max", num);
}
-BENCHMARK(BM_Max2DToScalarGPU)->RangePair(2048, 8192, 2048, 8192);
+BENCHMARK(BM_Max3DToScalarGPU)->Range(1 << 13, 1 << 20);
-static void BM_Min2DToScalarGPU(int iters, int num_x, int num_y) {
- ReduceToScalar<float>(iters, "gpu", "Min", num_x, num_y);
+static void BM_Prod3DToScalarGPU(int iters, int num) {
+ ReduceToScalar(iters, "gpu", "Prod", num);
}
-BENCHMARK(BM_Min2DToScalarGPU)->RangePair(2048, 8192, 2048, 8192);
+BENCHMARK(BM_Prod3DToScalarGPU)->Range(1 << 13, 1 << 20);
-static void BM_Bool2DToScalarGPU(int iters, int num_x, int num_y) {
- ReduceToScalar<bool>(iters, "gpu", "All", num_x, num_y);
+static void BM_Mean3DToScalarGPU(int iters, int num) {
+ ReduceToScalar(iters, "gpu", "Mean", num);
}
-BENCHMARK(BM_Bool2DToScalarGPU)->RangePair(2048, 8192, 2048, 8192);
+BENCHMARK(BM_Mean3DToScalarGPU)->Range(1 << 13, 1 << 20);
} // end namespace tensorflow
diff --git a/tensorflow/core/util/permutation_input_iterator.h b/tensorflow/core/util/permutation_input_iterator.h
deleted file mode 100644
index f6375b2515..0000000000
--- a/tensorflow/core/util/permutation_input_iterator.h
+++ /dev/null
@@ -1,134 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_UTIL_PERMUTATION_INPUT_ITERATOR_H_
-#define TENSORFLOW_UTIL_PERMUTATION_INPUT_ITERATOR_H_
-
-#include <iostream>
-#include <iterator>
-
-namespace tensorflow {
-
-template <typename ValueType, typename InputIteratorT, typename IndexIteratorT,
- typename OffsetT = ptrdiff_t>
-class PermutationInputIterator {
- public:
- // Required iterator traits
- typedef PermutationInputIterator self_type; ///< My own type
- typedef OffsetT difference_type; ///< Type to express the result of
- ///< subtracting one iterator from another
- typedef ValueType
- value_type; ///< The type of the element the iterator can point to
- typedef ValueType* pointer; ///< The type of a pointer to an element the
- ///< iterator can point to
- typedef ValueType reference; ///< The type of a reference to an element the
- ///< iterator can point to
-
- typedef std::random_access_iterator_tag
- iterator_category; ///< The iterator category
-
- private:
- InputIteratorT input_itr;
- IndexIteratorT index_itr;
-
- public:
- /// Constructor
- __host__ __device__ __forceinline__ PermutationInputIterator(
- InputIteratorT input_itr, ///< Input iterator to wrap
- IndexIteratorT index_itr) ///< Conversion functor to wrap
- : input_itr(input_itr), index_itr(index_itr) {}
-
- /// Postfix increment
- __host__ __device__ __forceinline__ self_type operator++(int) {
- self_type retval = *this;
- index_itr++;
- return retval;
- }
-
- /// Prefix increment
- __host__ __device__ __forceinline__ self_type operator++() {
- index_itr++;
- return *this;
- }
-
- /// Indirection
- __host__ __device__ __forceinline__ reference operator*() const {
- return input_itr[*index_itr];
- }
-
- /// Addition
- template <typename Distance>
- __host__ __device__ __forceinline__ self_type operator+(Distance n) const {
- self_type retval(input_itr, index_itr + n);
- return retval;
- }
-
- /// Addition assignment
- template <typename Distance>
- __host__ __device__ __forceinline__ self_type& operator+=(Distance n) {
- index_itr += n;
- return *this;
- }
-
- /// Subtraction
- template <typename Distance>
- __host__ __device__ __forceinline__ self_type operator-(Distance n) const {
- self_type retval(input_itr, index_itr - n);
- return retval;
- }
-
- /// Subtraction assignment
- template <typename Distance>
- __host__ __device__ __forceinline__ self_type& operator-=(Distance n) {
- index_itr -= n;
- return *this;
- }
-
- /// Distance
- __host__ __device__ __forceinline__ difference_type
- operator-(self_type other) const {
- return index_itr - other.index_itr;
- }
-
- /// Array subscript
- template <typename Distance>
- __host__ __device__ __forceinline__ reference operator[](Distance n) const {
- return input_itr[index_itr[n]];
- }
-
- /// Structure dereference
- __host__ __device__ __forceinline__ pointer operator->() {
- return input_itr + *index_itr;
- }
-
- /// Equal to
- __host__ __device__ __forceinline__ bool operator==(const self_type& rhs) {
- return (index_itr == rhs.index_itr && input_itr == rhs.input_itr);
- }
-
- /// Not equal to
- __host__ __device__ __forceinline__ bool operator!=(const self_type& rhs) {
- return !(*this == rhs);
- }
-
- /// ostream operator
- friend std::ostream& operator<<(std::ostream& os, const self_type& itr) {
- return os;
- }
-};
-
-} // end namespace tensorflow
-
-#endif // TENSORFLOW_UTIL_PERMUTATION_INPUT_ITERATOR_H_
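
For reference, dereferencing the deleted iterator yields input_itr[*index_itr], i.e. a gather view over the underlying buffer in the order given by the index sequence. The host-only sketch below strips the CUDA qualifiers and keeps just the subscript operator so it compiles as plain C++; the names are illustrative only, not part of the original class.

#include <cstddef>
#include <cstdio>
#include <vector>

// Simplified stand-in for PermutationInputIterator: reads are routed
// through an index sequence, producing a permuted (gathered) view.
template <typename ValueType, typename InputIt, typename IndexIt>
struct PermutationView {
  InputIt input;
  IndexIt index;
  ValueType operator[](std::ptrdiff_t n) const { return input[index[n]]; }
};

int main() {
  std::vector<float> data = {10.f, 20.f, 30.f, 40.f};
  std::vector<int> perm = {3, 1, 0, 2};  // gather order
  PermutationView<float, const float*, const int*> view{data.data(), perm.data()};
  for (int i = 0; i < 4; ++i) std::printf("%.0f ", view[i]);  // prints: 40 20 10 30
  std::printf("\n");
  return 0;
}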
diff --git a/tensorflow/core/util/transform_output_iterator.h b/tensorflow/core/util/transform_output_iterator.h
deleted file mode 100644
index 1640791ad1..0000000000
--- a/tensorflow/core/util/transform_output_iterator.h
+++ /dev/null
@@ -1,149 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_UTIL_TRANSFORM_OUTPUT_ITERATOR_H_
-#define TENSORFLOW_UTIL_TRANSFORM_OUTPUT_ITERATOR_H_
-
-#include <iostream>
-#include <iterator>
-
-namespace tensorflow {
-
-template <typename StoreType, typename InputType, typename ConversionOp,
- typename OffsetT = ptrdiff_t>
-class TransformOutputIterator {
- private:
- // Proxy object
- struct Reference {
- StoreType* ptr;
- ConversionOp conversion_op;
-
- /// Constructor
- __host__ __device__ __forceinline__ Reference(StoreType* ptr,
- ConversionOp conversion_op)
- : ptr(ptr), conversion_op(conversion_op) {}
-
- /// Assignment
- __host__ __device__ __forceinline__ InputType operator=(InputType val) {
- *ptr = conversion_op(val);
- return val;
- }
- };
-
- public:
- // Required iterator traits
- typedef TransformOutputIterator self_type; ///< My own type
- typedef OffsetT difference_type; ///< Type to express the result of
- ///< subtracting one iterator from another
- typedef void
- value_type; ///< The type of the element the iterator can point to
- typedef void pointer; ///< The type of a pointer to an element the iterator
- ///< can point to
- typedef Reference reference; ///< The type of a reference to an element the
- ///< iterator can point to
-
- typedef std::random_access_iterator_tag
- iterator_category; ///< The iterator category
-
- /*private:*/
-
- StoreType* ptr;
- ConversionOp conversion_op;
-
- public:
- /// Constructor
- template <typename QualifiedStoreType>
- __host__ __device__ __forceinline__ TransformOutputIterator(
- QualifiedStoreType* ptr, ///< Native pointer to wrap
- ConversionOp conversionOp) ///< Conversion functor to wrap
- : ptr(ptr), conversion_op(conversionOp) {}
-
- /// Postfix increment
- __host__ __device__ __forceinline__ self_type operator++(int) {
- self_type retval = *this;
- ptr++;
- return retval;
- }
-
- /// Prefix increment
- __host__ __device__ __forceinline__ self_type operator++() {
- ptr++;
- return *this;
- }
-
- /// Indirection
- __host__ __device__ __forceinline__ reference operator*() const {
- return Reference(ptr, conversion_op);
- }
-
- /// Addition
- template <typename Distance>
- __host__ __device__ __forceinline__ self_type operator+(Distance n) const {
- self_type retval(ptr + n, conversion_op);
- return retval;
- }
-
- /// Addition assignment
- template <typename Distance>
- __host__ __device__ __forceinline__ self_type& operator+=(Distance n) {
- ptr += n;
- return *this;
- }
-
- /// Subtraction
- template <typename Distance>
- __host__ __device__ __forceinline__ self_type operator-(Distance n) const {
- self_type retval(ptr - n, conversion_op);
- return retval;
- }
-
- /// Subtraction assignment
- template <typename Distance>
- __host__ __device__ __forceinline__ self_type& operator-=(Distance n) {
- ptr -= n;
- return *this;
- }
-
- /// Distance
- __host__ __device__ __forceinline__ difference_type
- operator-(self_type other) const {
- return ptr - other.ptr;
- }
-
- /// Array subscript
- template <typename Distance>
- __host__ __device__ __forceinline__ reference operator[](Distance n) const {
- return Reference(ptr + n, conversion_op);
- }
-
- /// Equal to
- __host__ __device__ __forceinline__ bool operator==(const self_type& rhs) {
- return (ptr == rhs.ptr);
- }
-
- /// Not equal to
- __host__ __device__ __forceinline__ bool operator!=(const self_type& rhs) {
- return (ptr != rhs.ptr);
- }
-
- /// ostream operator
- friend std::ostream& operator<<(std::ostream& os, const self_type& itr) {
- return os;
- }
-};
-
-} // end namespace tensorflow
-
-#endif // TENSORFLOW_UTIL_TRANSFORM_OUTPUT_ITERATOR_H_
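
The deleted class is the write-side counterpart of the permutation iterator: assigning through the proxy applies a conversion functor before storing, e.g. dividing a running sum by the element count to emit a mean. The host-only sketch below reduces it to that essential behavior; it is an illustration under those assumptions, not the original implementation.

#include <cstddef>
#include <cstdio>

// Functor applied on every store.
struct DivideBy {
  float divisor;
  float operator()(float v) const { return v / divisor; }
};

// Simplified stand-in for TransformOutputIterator: operator[] returns a
// proxy whose assignment converts the value before writing it to memory.
template <typename StoreT, typename ConversionOp>
struct TransformOutput {
  struct Ref {
    StoreT* p;
    ConversionOp op;
    StoreT operator=(StoreT v) { *p = op(v); return v; }
  };
  StoreT* ptr;
  ConversionOp op;
  Ref operator[](std::ptrdiff_t n) const { return Ref{ptr + n, op}; }
};

int main() {
  float out[2] = {0.f, 0.f};
  TransformOutput<float, DivideBy> it{out, DivideBy{4.f}};
  it[0] = 10.f;  // stores 10 / 4 = 2.5
  it[1] = 8.f;   // stores 8 / 4 = 2.0
  std::printf("%.2f %.2f\n", out[0], out[1]);  // prints: 2.50 2.00
  return 0;
}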
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index b6daad3ddf..797112b538 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -1680,26 +1680,6 @@ cuda_py_test(
)
cuda_py_test(
- name = "reduction_ops_test_big",
- size = "medium",
- srcs = ["reduction_ops_test_big.py"],
- additional_deps = [
- "//third_party/py/numpy",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:framework_for_generated_wrappers",
- "//tensorflow/python:math_ops",
- ],
- tags = [
- "manual",
- "no_gpu",
- "nogpu",
- "noguitar",
- "notap",
- ],
-)
-
-cuda_py_test(
name = "relu_op_test",
size = "small",
srcs = ["relu_op_test.py"],
diff --git a/tensorflow/python/kernel_tests/reduction_ops_test.py b/tensorflow/python/kernel_tests/reduction_ops_test.py
index b98a04d72c..04ce99a4a6 100644
--- a/tensorflow/python/kernel_tests/reduction_ops_test.py
+++ b/tensorflow/python/kernel_tests/reduction_ops_test.py
@@ -175,24 +175,6 @@ class SumReductionTest(BaseReductionTest):
np_arr = self._makeIncremental((2,) * rank, dtypes.int32)
self._compareAllAxes(np_arr)
- def testFloat16(self):
- for rank in range(1, _MAX_RANK + 1):
- np_arr = self._makeIncremental((2,) * rank, dtypes.float16)
- self._compareAllAxes(np_arr)
-
- # test that mean doesn't overflow
- # only on GPU, since it has the more accurate implementation
- if not test.is_gpu_available():
- return
-
- arr = np.ones([68000], dtype=np.float16)
-
- with self.test_session(graph=ops.Graph(), use_gpu=True) as sess:
- tf_arr = array_ops.constant(arr)
- tf_mean = math_ops.reduce_mean(tf_arr, 0, False)
- tf_out_mean = sess.run(tf_mean)
- self.assertAllClose(tf_out_mean, 1.)
-
def testFloat32(self):
for rank in range(1, _MAX_RANK + 1):
np_arr = self._makeIncremental((2,) * rank, dtypes.float32)
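
The testFloat16 case removed above guarded against a specific numerical hazard: summing 68000 float16 ones exceeds float16's largest finite value (65504), so a mean computed with a half-precision accumulator overflows, while accumulating in float32 keeps the mean at exactly 1. A small stand-alone illustration of that arithmetic (not the TensorFlow kernel):

#include <cstdio>

int main() {
  const int n = 68000;
  const float kFloat16Max = 65504.0f;  // largest finite float16 value
  float accum = 0.0f;                  // float32 accumulator
  bool half_would_overflow = false;
  for (int i = 0; i < n; ++i) {
    accum += 1.0f;
    // A float16 accumulator would overflow to +inf once the sum passes 65504.
    if (accum > kFloat16Max) half_would_overflow = true;
  }
  std::printf("mean = %f, float16 accumulator would overflow: %s\n",
              accum / n, half_would_overflow ? "yes" : "no");
  return 0;
}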
diff --git a/tensorflow/python/kernel_tests/reduction_ops_test_big.py b/tensorflow/python/kernel_tests/reduction_ops_test_big.py
deleted file mode 100644
index 99fea62f98..0000000000
--- a/tensorflow/python/kernel_tests/reduction_ops_test_big.py
+++ /dev/null
@@ -1,75 +0,0 @@
-# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Functional tests for reduction ops."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.platform import test
-
-
-class BaseReductionTest(test.TestCase):
-
- def _tf_reduce(self, x, reduction_axes, keep_dims):
- raise NotImplementedError()
-
-
-class SumReductionTest(BaseReductionTest):
-
- def _tf_reduce(self, x, reduction_axes, keep_dims):
- return math_ops.reduce_sum(x, reduction_axes, keep_dims)
-
- def testFloat32(self):
- # make sure we test all possible kernel invocations
- # logic is the same for all ops, test just float32 for brevity
- for size_x in range(1, 4105, 27):
- for size_y in range(1, 4105, 27):
- arr = np.ones([size_x, size_y], dtype=np.float32)
- col_sum = np.ones([size_y], dtype=np.float32) * size_x
- row_sum = np.ones([size_x], dtype=np.float32) * size_y
- full_sum = np.ones([], dtype=np.float32) * size_x * size_y
-
- with self.test_session(graph=ops.Graph(), use_gpu=True) as sess:
- tf_row_sum = self._tf_reduce(arr, 1, False)
- tf_col_sum = self._tf_reduce(arr, 0, False)
- tf_full_sum = self._tf_reduce(arr, [0, 1], False)
- tf_out_row, tf_out_col, tf_out_full = sess.run(
- [tf_row_sum, tf_col_sum, tf_full_sum])
- self.assertAllClose(col_sum, tf_out_col)
- self.assertAllClose(row_sum, tf_out_row)
- self.assertAllClose(full_sum, tf_out_full)
-
- for size_x in range(1, 130, 3):
- for size_y in range(1, 130, 3):
- for size_z in range(1, 130, 3):
- arr = np.ones([size_x, size_y, size_z], dtype=np.float32)
- sum_y = np.sum(arr, axis=1)
- sum_xz = np.sum(arr, axis=(0, 2))
-
- with self.test_session(graph=ops.Graph(), use_gpu=True) as sess:
- tf_sum_xz = self._tf_reduce(arr, [0, 2], False)
- tf_sum_y = self._tf_reduce(arr, 1, False)
- tf_out_sum_xz, tf_out_sum_y = sess.run([tf_sum_xz, tf_sum_y])
- self.assertAllClose(sum_y, tf_out_sum_y)
- self.assertAllClose(sum_xz, tf_out_sum_xz)
-
-
-if __name__ == "__main__":
- test.main()
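
The deleted file's size sweep exists to hit every code path a reduction implementation may select based on shape. The sketch below is a stripped-down, host-only version of the same idea, checking column and full sums of an all-ones matrix against their closed-form values; the step sizes and bounds are arbitrary choices for brevity.

#include <cassert>
#include <vector>

int main() {
  for (int rows = 1; rows < 130; rows += 27) {
    for (int cols = 1; cols < 130; cols += 27) {
      std::vector<float> m(static_cast<size_t>(rows) * cols, 1.0f);
      // Column sums: reduce over the row dimension.
      for (int c = 0; c < cols; ++c) {
        float s = 0.f;
        for (int r = 0; r < rows; ++r) s += m[static_cast<size_t>(r) * cols + c];
        assert(s == static_cast<float>(rows));
      }
      // Full sum: reduce over every element.
      float total = 0.f;
      for (float v : m) total += v;
      assert(total == static_cast<float>(rows) * cols);
    }
  }
  return 0;
}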