Diffstat (limited to 'tensorflow/core/kernels')
22 files changed, 538 insertions, 336 deletions
diff --git a/tensorflow/core/kernels/cast_op.cc b/tensorflow/core/kernels/cast_op.cc index 960c653593..8d5ed3c2fe 100644 --- a/tensorflow/core/kernels/cast_op.cc +++ b/tensorflow/core/kernels/cast_op.cc @@ -55,6 +55,24 @@ struct CastFunctor<CPUDevice, O, I> { } // namespace functor +#define CURRY_TYPES2(FN, arg0) \ + FN(arg0, bool); \ + FN(arg0, uint8); \ + FN(arg0, int16); \ + FN(arg0, int32); \ + FN(arg0, int64); \ + FN(arg0, float); \ + FN(arg0, double) + +#define CURRY_TYPES3(FN, arg0, arg1) \ + FN(arg0, arg1, bool); \ + FN(arg0, arg1, uint8); \ + FN(arg0, arg1, int16); \ + FN(arg0, arg1, int32); \ + FN(arg0, arg1, int64); \ + FN(arg0, arg1, float); \ + FN(arg0, arg1, double) + #define CAST_CASE(DEVICE, IN, OUT) \ if (DataTypeToEnum<IN>::value == src_dtype_ && \ DataTypeToEnum<OUT>::value == dst_dtype_) { \ @@ -110,27 +128,14 @@ class CpuCastOp : public CastOpBase { work_ = nullptr; // Identity return Status::OK(); } - CAST_CASE(CPUDevice, bool, float); - CAST_CASE(CPUDevice, bool, int32); - CAST_CASE(CPUDevice, bool, double); - CAST_CASE(CPUDevice, double, float); - CAST_CASE(CPUDevice, double, int32); - CAST_CASE(CPUDevice, double, int64); - CAST_CASE(CPUDevice, float, double); - CAST_CASE(CPUDevice, float, uint8); - CAST_CASE(CPUDevice, float, int32); - CAST_CASE(CPUDevice, float, int64); - CAST_CASE(CPUDevice, int32, double); - CAST_CASE(CPUDevice, int32, float); - CAST_CASE(CPUDevice, int32, uint8); - CAST_CASE(CPUDevice, int32, int64); - CAST_CASE(CPUDevice, int64, double); - CAST_CASE(CPUDevice, int64, float); - CAST_CASE(CPUDevice, int64, int32); - CAST_CASE(CPUDevice, uint8, float); - CAST_CASE(CPUDevice, uint8, int32); - CAST_CASE(CPUDevice, uint8, int64); - CAST_CASE(CPUDevice, uint8, double); + CURRY_TYPES3(CAST_CASE, CPUDevice, bool); + CURRY_TYPES3(CAST_CASE, CPUDevice, uint8); + CURRY_TYPES3(CAST_CASE, CPUDevice, int16); + CURRY_TYPES3(CAST_CASE, CPUDevice, int32); + CURRY_TYPES3(CAST_CASE, CPUDevice, int64); + CURRY_TYPES3(CAST_CASE, CPUDevice, float); + CURRY_TYPES3(CAST_CASE, CPUDevice, double); + if (src_dtype_ == DT_BFLOAT16 && dst_dtype_ == DT_FLOAT) { work_ = [](OpKernelContext* ctx, const Tensor& inp, Tensor* out) { int64 N = out->NumElements(); @@ -185,24 +190,15 @@ class GpuCastOp : public CastOpBase { work_ = nullptr; // Identity return Status::OK(); } - CAST_CASE(GPUDevice, bfloat16, float); - CAST_CASE(GPUDevice, bool, float); - CAST_CASE(GPUDevice, double, float); - CAST_CASE(GPUDevice, double, int64); + CURRY_TYPES3(CAST_CASE, GPUDevice, bool); + CURRY_TYPES3(CAST_CASE, GPUDevice, uint8); + CURRY_TYPES3(CAST_CASE, GPUDevice, int16); + CURRY_TYPES3(CAST_CASE, GPUDevice, int32); + CURRY_TYPES3(CAST_CASE, GPUDevice, int64); + CURRY_TYPES3(CAST_CASE, GPUDevice, float); + CURRY_TYPES3(CAST_CASE, GPUDevice, double); CAST_CASE(GPUDevice, float, bfloat16); - CAST_CASE(GPUDevice, float, double); - CAST_CASE(GPUDevice, float, int64); - CAST_CASE(GPUDevice, int64, double); - CAST_CASE(GPUDevice, int64, float); - CAST_CASE(GPUDevice, uint8, float); - CAST_CASE(GPUDevice, float, uint8); - CAST_CASE(GPUDevice, bool, int32); - CAST_CASE(GPUDevice, double, int32); - CAST_CASE(GPUDevice, float, int32); - CAST_CASE(GPUDevice, int32, double); - CAST_CASE(GPUDevice, int32, float); - CAST_CASE(GPUDevice, int32, int64); - CAST_CASE(GPUDevice, int64, int32); + CAST_CASE(GPUDevice, bfloat16, float); return Unimplemented(); } }; @@ -217,28 +213,24 @@ REGISTER_KERNEL_BUILDER(Name("Cast").Device(DEVICE_CPU), CpuCastOp); .TypeConstraint<srctype>("SrcT") \ 
.TypeConstraint<dsttype>("DstT") \ .Device(DEVICE_GPU), \ - GpuCastOp); -REGISTER_CAST_GPU(bfloat16, float); -REGISTER_CAST_GPU(bool, float); -REGISTER_CAST_GPU(double, float); -REGISTER_CAST_GPU(double, int64); + GpuCastOp) + +CURRY_TYPES2(REGISTER_CAST_GPU, bool); +CURRY_TYPES2(REGISTER_CAST_GPU, uint8); +CURRY_TYPES2(REGISTER_CAST_GPU, int16); +CURRY_TYPES2(REGISTER_CAST_GPU, int32); +CURRY_TYPES2(REGISTER_CAST_GPU, int64); +CURRY_TYPES2(REGISTER_CAST_GPU, float); +CURRY_TYPES2(REGISTER_CAST_GPU, double); REGISTER_CAST_GPU(float, bfloat16); -REGISTER_CAST_GPU(float, double); -REGISTER_CAST_GPU(float, int64); -REGISTER_CAST_GPU(int64, double); -REGISTER_CAST_GPU(int64, float); -REGISTER_CAST_GPU(uint8, float); -REGISTER_CAST_GPU(float, uint8); -REGISTER_CAST_GPU(bool, int32); -REGISTER_CAST_GPU(double, int32); -REGISTER_CAST_GPU(float, int32); -REGISTER_CAST_GPU(int32, double); -REGISTER_CAST_GPU(int32, float); -REGISTER_CAST_GPU(int32, int64); -REGISTER_CAST_GPU(int64, int32); +REGISTER_CAST_GPU(bfloat16, float); + #undef REGISTER_CAST_GPU #endif // GOOGLE_CUDA +#undef CURRY_TYPES2 +#undef CURRY_TYPES3 + // HostCast differs from Cast in that its input and output are in host memory. REGISTER_KERNEL_BUILDER(Name("_HostCast").Device(DEVICE_CPU), CpuCastOp); REGISTER_KERNEL_BUILDER( diff --git a/tensorflow/core/kernels/cast_op_gpu.cu.cc b/tensorflow/core/kernels/cast_op_gpu.cu.cc index 43f8cd90ed..57f0873621 100644 --- a/tensorflow/core/kernels/cast_op_gpu.cu.cc +++ b/tensorflow/core/kernels/cast_op_gpu.cu.cc @@ -33,25 +33,27 @@ struct CastFunctor<GPUDevice, O, I> { } }; -#define DEFINE(O, I) template struct CastFunctor<GPUDevice, O, I>; -DEFINE(float, double); -DEFINE(float, int32); -DEFINE(float, int64); -DEFINE(double, float); -DEFINE(double, int32); -DEFINE(double, int64); -DEFINE(int32, float); -DEFINE(int32, double); -DEFINE(int32, int64); -DEFINE(int64, float); -DEFINE(int64, double); -DEFINE(int64, int32); -DEFINE(int32, bool); -DEFINE(float, bool); -DEFINE(float, uint8); -DEFINE(uint8, float); -DEFINE(float, bfloat16); +#define DEFINE(O, I) template struct CastFunctor<GPUDevice, O, I> +#define DEFINE_ALL_FROM(in_type) \ + DEFINE(in_type, bool); \ + DEFINE(in_type, uint8); \ + DEFINE(in_type, int16); \ + DEFINE(in_type, int32); \ + DEFINE(in_type, int64); \ + DEFINE(in_type, float); \ + DEFINE(in_type, double) + +DEFINE_ALL_FROM(bool); +DEFINE_ALL_FROM(uint8); +DEFINE_ALL_FROM(int16); +DEFINE_ALL_FROM(int32); +DEFINE_ALL_FROM(int64); +DEFINE_ALL_FROM(float); +DEFINE_ALL_FROM(double); DEFINE(bfloat16, float); +DEFINE(float, bfloat16); + +#undef DEFINE_ALL_FROM #undef DEFINE } // end namespace functor diff --git a/tensorflow/core/kernels/cast_op_test.cc b/tensorflow/core/kernels/cast_op_test.cc index b93c0857db..168914f553 100644 --- a/tensorflow/core/kernels/cast_op_test.cc +++ b/tensorflow/core/kernels/cast_op_test.cc @@ -41,22 +41,48 @@ class CastOpTest : public OpsTestBase { void MakeOp(DataType src, DataType dst) { RequireDefaultOps(); EXPECT_OK(NodeDefBuilder("cast_op", "Cast") - .Input(FakeInput(DT_INT32)) + .Input(FakeInput(src)) .Attr("SrcT", src) .Attr("DstT", dst) .Finalize(node_def())); EXPECT_OK(InitOp()); } + + template <typename IN, typename OUT> + void CheckCast() { + DataType in_type = DataTypeToEnum<IN>::v(); + DataType out_type = DataTypeToEnum<OUT>::v(); + MakeOp(in_type, out_type); + AddInputFromArray<IN>(TensorShape({1, 2, 2, 1}), {1, 2, 3, 4}); + ASSERT_OK(RunOpKernel()); + Tensor expected(allocator(), out_type, TensorShape({1, 2, 2, 1})); + 
test::FillValues<OUT>(&expected, {1, 2, 3, 4}); + test::ExpectTensorEqual<OUT>(expected, *GetOutput(0)); + } }; -TEST_F(CastOpTest, Int32ToUint8) { - MakeOp(DT_INT32, DT_UINT8); - AddInputFromArray<int32>(TensorShape({1, 2, 2, 1}), {1, 2, 3, 4}); - ASSERT_OK(RunOpKernel()); - Tensor expected(allocator(), DT_UINT8, TensorShape({1, 2, 2, 1})); - test::FillValues<uint8>(&expected, {1, 2, 3, 4}); - test::ExpectTensorEqual<uint8>(expected, *GetOutput(0)); -} +#define TEST_CAST(in, out) \ + TEST_F(CastOpTest, TestCast##_##in##_##out) { CheckCast<in, out>(); } + +#define TEST_ALL_CASTS_FROM(in) \ + TEST_CAST(in, uint8); \ + TEST_CAST(in, int16); \ + TEST_CAST(in, int32); \ + TEST_CAST(in, int64); \ + TEST_CAST(in, float); \ + TEST_CAST(in, double) + +TEST_ALL_CASTS_FROM(uint8) +TEST_ALL_CASTS_FROM(int16) +TEST_ALL_CASTS_FROM(int32) +TEST_ALL_CASTS_FROM(int64) +TEST_ALL_CASTS_FROM(float) +TEST_ALL_CASTS_FROM(double) + +#undef TEST_ALL_CASTS_FROM +#undef TEST_CAST + +// TODO(wicke): check conversions from/to bool, and bfloat16 static void BM_cpu_float_int64(int iters, int num) { testing::ItemsProcessed(static_cast<int64>(iters) * num); diff --git a/tensorflow/core/kernels/concat_op_gpu.cu.cc b/tensorflow/core/kernels/concat_op_gpu.cu.cc index 581171c6ba..084ca9a764 100644 --- a/tensorflow/core/kernels/concat_op_gpu.cu.cc +++ b/tensorflow/core/kernels/concat_op_gpu.cu.cc @@ -34,10 +34,12 @@ void ConcatGPU(const GPUDevice& d, const std::vector< std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>& inputs, typename TTypes<T, 2>::Matrix* output) { - Eigen::array<Eigen::DenseIndex, 2> offset(0, 0); + Eigen::array<int32, 2> offset{0, 0}; for (int i = 0; i < inputs.size(); ++i) { - Eigen::array<Eigen::DenseIndex, 2> size = inputs[i]->dimensions(); - output->slice(offset, size).device(d) = *inputs[i]; + Eigen::array<int32_t, 2> size; + size[0] = inputs[i]->dimension(0); + size[1] = inputs[i]->dimension(1); + To32Bit(*output).slice(offset, size).device(d) = To32Bit(*inputs[i]); offset[1] += size[1]; } } diff --git a/tensorflow/core/kernels/constant_op_gpu.cu.cc b/tensorflow/core/kernels/constant_op_gpu.cu.cc index 5991391850..bbb7a0ee28 100644 --- a/tensorflow/core/kernels/constant_op_gpu.cu.cc +++ b/tensorflow/core/kernels/constant_op_gpu.cu.cc @@ -73,7 +73,7 @@ struct FillFunctor<GPUDevice, T> { void operator()(const GPUDevice& d, typename TTypes<T>::Flat out, typename TTypes<T>::ConstScalar in) { Eigen::internal::scalar_const_op<T> f(in.data()); - out.device(d) = out.nullaryExpr(f); + To32Bit(out).device(d) = To32Bit(out).nullaryExpr(f); } }; @@ -91,7 +91,7 @@ DEFINE_FILL_GPU(int64); template <typename T> struct SetZeroFunctor<GPUDevice, T> { void operator()(const GPUDevice& d, typename TTypes<T>::Flat out) { - out.device(d) = out.constant(0); + To32Bit(out).device(d) = To32Bit(out).constant(0); } }; diff --git a/tensorflow/core/kernels/conv_grad_ops.cc b/tensorflow/core/kernels/conv_grad_ops.cc index dae06f4bfc..8bd13b4be3 100644 --- a/tensorflow/core/kernels/conv_grad_ops.cc +++ b/tensorflow/core/kernels/conv_grad_ops.cc @@ -242,13 +242,13 @@ typedef Eigen::GpuDevice GPUDevice; const auto expanded_out_cols = (output_cols - 1) * stride + 1; \ const auto padded_out_rows = input_rows + filter_rows - 1; \ const auto padded_out_cols = input_cols + filter_cols - 1; \ - const auto top_pad_rows = filter_rows - 1 - pad_rows; \ - const auto left_pad_cols = filter_cols - 1 - pad_cols; \ - const auto bottom_pad_rows = \ + const int top_pad_rows = filter_rows - 1 - pad_rows; \ + const int left_pad_cols = 
filter_cols - 1 - pad_cols; \ + const int bottom_pad_rows = \ padded_out_rows - expanded_out_rows - top_pad_rows; \ - const auto right_pad_cols = \ + const int right_pad_cols = \ padded_out_cols - expanded_out_cols - left_pad_cols; \ - Eigen::DSizes<Eigen::DenseIndex, 4> strides{1, stride, stride, 1}; \ + Eigen::DSizes<int, 4> strides{1, stride, stride, 1}; \ VLOG(2) << "Conv2d: " << label \ << ": expanded_out_rows = " << expanded_out_rows \ << ", expanded_out_cols = " << expanded_out_cols \ @@ -809,9 +809,11 @@ class Conv2DSlowBackpropInputOp : public OpKernel { context->allocate_output(0, input_shape, &in_backprop)); const int padding_rows = - (output_rows - 1) * stride + filter_rows - input_rows; + (padding_ == VALID) ? 0 : (output_rows - 1) * stride + filter_rows - + input_rows; const int padding_cols = - (output_cols - 1) * stride + filter_cols - input_cols; + (padding_ == VALID) ? 0 : (output_cols - 1) * stride + filter_cols - + input_cols; // TODO(keveman): cuDNN only supports equal padding on both sides, so only // calling it when that is true. Remove this check when (if?) cuDNN starts @@ -954,16 +956,17 @@ class Conv2DSlowBackpropInputOp : public OpKernel { context->allocate_temp(DataTypeToEnum<T>::v(), padded_out_shape, &padded_output)); - Eigen::DSizes<Eigen::DenseIndex, 4> trivial_order{0, 1, 2, 3}; - Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 4> pad_dims{ + Eigen::DSizes<int, 4> trivial_order{0, 1, 2, 3}; + Eigen::array<Eigen::IndexPair<int>, 4> pad_dims{ {{0, 0}, {top_pad_rows, bottom_pad_rows}, {left_pad_cols, right_pad_cols}, {0, 0}}}; - functor::InflatePadAndShuffle<Device, T, 4, Eigen::DenseIndex>()( - context->eigen_device<Device>(), out_backprop.tensor<T, 4>(), strides, - pad_dims, trivial_order, padded_output.tensor<T, 4>()); + functor::InflatePadAndShuffle<Device, T, 4, int>()( + context->eigen_device<Device>(), To32Bit(out_backprop.tensor<T, 4>()), + strides, pad_dims, trivial_order, + To32Bit(padded_output.tensor<T, 4>())); const Tensor& padded_output_cref = padded_output; // We then need to fill a new "reverted" filter @@ -976,11 +979,11 @@ class Conv2DSlowBackpropInputOp : public OpKernel { context->allocate_temp(DataTypeToEnum<T>::v(), r_filter_shape, &r_filter)); - Eigen::DSizes<Eigen::DenseIndex, 4> filter_order{0, 1, 3, 2}; + Eigen::DSizes<int, 4> filter_order{0, 1, 3, 2}; Eigen::array<bool, 4> filter_rev_dims{true, true, false, false}; - functor::ShuffleAndReverse<Device, T, 4, Eigen::DenseIndex>()( - context->eigen_device<Device>(), filter.tensor<T, 4>(), filter_order, - filter_rev_dims, r_filter.tensor<T, 4>()); + functor::ShuffleAndReverse<Device, T, 4, int>()( + context->eigen_device<Device>(), To32Bit(filter.tensor<T, 4>()), + filter_order, filter_rev_dims, To32Bit(r_filter.tensor<T, 4>())); const Tensor& r_filter_cref = r_filter; // Now we can call conv_2d directly. @@ -1039,20 +1042,22 @@ class Conv2DSlowBackpropFilterOp : public OpKernel { context->allocate_output(0, filter_shape, &filter_backprop)); const int padding_rows = - (output_rows - 1) * stride + filter_rows - input_rows; + (padding_ == VALID) ? 0 : (output_rows - 1) * stride + filter_rows - + input_rows; const int padding_cols = - (output_cols - 1) * stride + filter_cols - input_cols; + (padding_ == VALID) ? 0 : (output_cols - 1) * stride + filter_cols - + input_cols; // TODO(zhengxq): cuDNN only supports equal padding on both sides, so only // calling it when that is true. Remove this check when (if?) cuDNN starts // supporting different padding. 
- bool padding_compatible = - (padding_rows % 2 == 0) && (padding_cols % 2 == 0); + bool rows_odd = (padding_rows % 2 != 0); + bool cols_odd = (padding_cols % 2 != 0); auto* stream = context->op_device_context<GPUDeviceContext>()->stream(); OP_REQUIRES(context, stream, errors::Internal("No GPU stream available.")); - if (use_cudnn_ && padding_compatible) { + if (use_cudnn_) { if (filter_rows == 1 && filter_cols == 1 && stride == 1) { const uint64 m = in_depth; const uint64 k = batch * input_rows * input_cols; @@ -1089,10 +1094,31 @@ class Conv2DSlowBackpropFilterOp : public OpKernel { return; } + Tensor compatible_input; + if (rows_odd || cols_odd) { + // If a padding dimension is odd, we have one more element on the right + // side or the bottom side. This is unsupported in cudnn. Therefore, + // we pad that extra element and make it compatible. + OP_REQUIRES_OK( + context, + context->allocate_temp( + DataTypeToEnum<T>::value, + TensorShape({input.dim_size(0), input.dim_size(1) + rows_odd, + input.dim_size(2) + cols_odd, input.dim_size(3)}), + &compatible_input)); + + functor::PadInput<GPUDevice, T, int>()( + context->template eigen_device<GPUDevice>(), + To32Bit(input.tensor<T, 4>()), 0, rows_odd, 0, cols_odd, + To32Bit(compatible_input.tensor<T, 4>())); + } else { + compatible_input = input; + } + perftools::gputools::dnn::BatchDescriptor input_desc; input_desc.set_count(batch) - .set_height(input_rows) - .set_width(input_cols) + .set_height(compatible_input.dim_size(1)) + .set_width(compatible_input.dim_size(2)) .set_feature_map_count(in_depth) .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX); perftools::gputools::dnn::BatchDescriptor output_desc; @@ -1146,14 +1172,19 @@ class Conv2DSlowBackpropFilterOp : public OpKernel { transformed_out_backprop.tensor<T, 4>()); Tensor transformed_input; - OP_REQUIRES_OK(context, - context->allocate_temp( - DataTypeToEnum<T>::value, - TensorShape({batch, in_depth, input_rows, input_cols}), - &transformed_input)); - functor::NHWCToNCHW<Device, T>()(context->eigen_device<Device>(), - input.tensor<T, 4>(), - transformed_input.tensor<T, 4>()); + OP_REQUIRES_OK( + context, + context->allocate_temp( + DataTypeToEnum<T>::value, + TensorShape({ + compatible_input.dim_size(0), compatible_input.dim_size(3), + compatible_input.dim_size(1), compatible_input.dim_size(2), + }), + &transformed_input)); + functor::NHWCToNCHW<Device, T>()( + context->eigen_device<Device>(), + const_cast<const Tensor&>(compatible_input).tensor<T, 4>(), + transformed_input.tensor<T, 4>()); auto out_backprop_ptr = AsDeviceMemory(transformed_out_backprop.template flat<T>().data(), @@ -1193,7 +1224,7 @@ class Conv2DSlowBackpropFilterOp : public OpKernel { // [batch, out_rows, out_cols, out_depth] // And we need to change it to // [out_depth, out_rows, out_cols, batch] - Eigen::DSizes<Eigen::DenseIndex, 4> out_order{3, 1, 2, 0}; + Eigen::DSizes<int, 4> out_order{3, 1, 2, 0}; TensorShape padded_out_shape( {out_depth, padded_out_rows, padded_out_cols, batch}); Tensor padded_output; @@ -1201,14 +1232,14 @@ class Conv2DSlowBackpropFilterOp : public OpKernel { context->allocate_temp(DataTypeToEnum<T>::v(), padded_out_shape, &padded_output)); - Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 4> pad_dims{ + Eigen::array<Eigen::IndexPair<int>, 4> pad_dims{ {{0, 0}, {top_pad_rows, bottom_pad_rows}, {left_pad_cols, right_pad_cols}, {0, 0}}}; - functor::InflatePadAndShuffle<Device, T, 4, Eigen::DenseIndex>()( - context->eigen_device<Device>(), out_backprop.tensor<T, 4>(), strides, - 
pad_dims, out_order, padded_output.tensor<T, 4>()); + functor::InflatePadAndShuffle<Device, T, 4, int>()( + context->eigen_device<Device>(), To32Bit(out_backprop.tensor<T, 4>()), + strides, pad_dims, out_order, To32Bit(padded_output.tensor<T, 4>())); const Tensor& padded_output_cref = padded_output; // For the backprop of the filter, we need to transpose the input. @@ -1216,7 +1247,7 @@ class Conv2DSlowBackpropFilterOp : public OpKernel { // [batch, in_rows, in_cols, in_depth] // And we need to change it to // [in_rows, in_cols, batch, in_depth] - Eigen::DSizes<Eigen::DenseIndex, 4> in_order{1, 2, 0, 3}; + Eigen::DSizes<int, 4> in_order{1, 2, 0, 3}; TensorShape in_shuffle_shape({input_rows, input_cols, batch, in_depth}); Tensor in_shuffle; OP_REQUIRES_OK(context, @@ -1225,9 +1256,9 @@ class Conv2DSlowBackpropFilterOp : public OpKernel { // No need for reversing this time. Eigen::array<bool, 4> trivial_dims{false, false, false, false}; - functor::ShuffleAndReverse<Device, T, 4, Eigen::DenseIndex>()( - context->eigen_device<Device>(), input.tensor<T, 4>(), in_order, - trivial_dims, in_shuffle.tensor<T, 4>()); + functor::ShuffleAndReverse<Device, T, 4, int>()( + context->eigen_device<Device>(), To32Bit(input.tensor<T, 4>()), + in_order, trivial_dims, To32Bit(in_shuffle.tensor<T, 4>())); const Tensor& in_shuffle_cref = in_shuffle; // The output of the conv_2d would be @@ -1250,12 +1281,13 @@ class Conv2DSlowBackpropFilterOp : public OpKernel { BrainPadding2EigenPadding(VALID)); // Now copy the filter_backprop back to the destination. - Eigen::DSizes<Eigen::DenseIndex, 4> filter_order{1, 2, 3, 0}; + Eigen::DSizes<int, 4> filter_order{1, 2, 3, 0}; Eigen::array<bool, 4> filter_rev_dims{true, true, false, false}; const Tensor& filter_shuffle_cref = filter_shuffle; - functor::ShuffleAndReverse<Device, T, 4, Eigen::DenseIndex>()( - context->eigen_device<Device>(), filter_shuffle_cref.tensor<T, 4>(), - filter_order, filter_rev_dims, filter_backprop->tensor<T, 4>()); + functor::ShuffleAndReverse<Device, T, 4, int>()( + context->eigen_device<Device>(), + To32Bit(filter_shuffle_cref.tensor<T, 4>()), filter_order, + filter_rev_dims, To32Bit(filter_backprop->tensor<T, 4>())); } } @@ -1271,25 +1303,6 @@ class Conv2DSlowBackpropFilterOp : public OpKernel { namespace functor { #define DECLARE_GPU_SPEC(T) \ template <> \ - void ShuffleAndReverse<GPUDevice, T, 4, Eigen::DenseIndex>::operator()( \ - const GPUDevice& d, \ - typename TTypes<T, 4, Eigen::DenseIndex>::ConstTensor input, \ - const Eigen::DSizes<Eigen::DenseIndex, 4>& order, \ - const Eigen::array<bool, 4>& reverse_dims, \ - typename TTypes<T, 4, Eigen::DenseIndex>::Tensor output); \ - extern template struct ShuffleAndReverse<GPUDevice, T, 4, \ - Eigen::DenseIndex>; \ - template <> \ - void InflatePadAndShuffle<GPUDevice, T, 4, Eigen::DenseIndex>::operator()( \ - const GPUDevice& d, \ - typename TTypes<T, 4, Eigen::DenseIndex>::ConstTensor input, \ - const Eigen::DSizes<Eigen::DenseIndex, 4>& strides, \ - const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 4>& pad_dims, \ - const Eigen::DSizes<Eigen::DenseIndex, 4>& order, \ - typename TTypes<T, 4, Eigen::DenseIndex>::Tensor output); \ - extern template struct InflatePadAndShuffle<GPUDevice, T, 4, \ - Eigen::DenseIndex>; \ - template <> \ void ShuffleAndReverse<GPUDevice, T, 4, int>::operator()( \ const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor input, \ const Eigen::DSizes<int, 4>& order, \ @@ -1328,7 +1341,13 @@ namespace functor { typename TTypes<T, 4>::ConstTensor filter, \ 
typename TTypes<T, 4>::ConstTensor output_backprop, int input_rows, \ int input_cols, int stride); \ - extern template struct SpatialConvolutionBackwardInput<GPUDevice, T> + extern template struct SpatialConvolutionBackwardInput<GPUDevice, T>; \ + template <> \ + void PadInput<GPUDevice, T, int>::operator()( \ + const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in, \ + int padding_rows_left, int padding_rows_right, int padding_cols_left, \ + int padding_cols_right, typename TTypes<T, 4, int>::Tensor out); \ + extern template struct PadInput<GPUDevice, T, int>; DECLARE_GPU_SPEC(float); #undef DECLARE_GPU_SPEC diff --git a/tensorflow/core/kernels/conv_ops_gpu.cu.cc b/tensorflow/core/kernels/conv_ops_gpu.cu.cc index 60ff6b0024..e4ee058406 100644 --- a/tensorflow/core/kernels/conv_ops_gpu.cu.cc +++ b/tensorflow/core/kernels/conv_ops_gpu.cu.cc @@ -33,12 +33,8 @@ struct SpatialConvolution<GPUDevice, T> { typename TTypes<T, 4>::ConstTensor input, typename TTypes<T, 4>::ConstTensor filter, int stride, const Eigen::PaddingType& padding) { - // TODO(keveman): nvcc 6.5 crashes when 32 bit indexing is turned on. Enable - // this when we move to cuda 7.0. - // SpatialConvolutionFunc(d, To32Bit(output), To32Bit(input), - // To32Bit(filter), stride, padding); - - SpatialConvolutionFunc(d, output, input, filter, stride, padding); + SpatialConvolutionFunc(d, To32Bit(output), To32Bit(input), To32Bit(filter), + stride, padding); } }; diff --git a/tensorflow/core/kernels/cwise_op_div.cc b/tensorflow/core/kernels/cwise_op_div.cc index bc2b62375f..8fed594b25 100644 --- a/tensorflow/core/kernels/cwise_op_div.cc +++ b/tensorflow/core/kernels/cwise_op_div.cc @@ -16,21 +16,11 @@ limitations under the License. #include "tensorflow/core/kernels/cwise_ops_common.h" namespace tensorflow { -REGISTER5(BinaryOp, CPU, "Div", functor::div, float, double, int32, int64, - complex64); +REGISTER7(BinaryOp, CPU, "Div", functor::div, float, double, uint8, int16, + int32, int64, complex64); #if GOOGLE_CUDA -REGISTER3(BinaryOp, GPU, "Div", functor::div, float, double, int64); +REGISTER6(BinaryOp, GPU, "Div", functor::div, float, double, uint8, int16, + int32, int64); #endif -// A special GPU kernel for int32. -// TODO(b/25387198): Also enable int32 in device memory. This kernel -// registration requires all int32 inputs and outputs to be in host memory. -REGISTER_KERNEL_BUILDER(Name("Div") - .Device(DEVICE_GPU) - .HostMemory("x") - .HostMemory("y") - .HostMemory("z") - .TypeConstraint<int32>("T"), - BinaryOp<CPUDevice, functor::div<int32>>); - } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_gpu_div.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_div.cu.cc index 80a02da651..a2809d5481 100644 --- a/tensorflow/core/kernels/cwise_op_gpu_div.cu.cc +++ b/tensorflow/core/kernels/cwise_op_gpu_div.cu.cc @@ -19,7 +19,7 @@ limitations under the License. namespace tensorflow { namespace functor { -DEFINE_BINARY3(div, float, double, int64); +DEFINE_BINARY6(div, float, double, uint8, int16, int32, int64); } // namespace functor } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_gpu_mul.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_mul.cu.cc index a4ecaf185a..068003b294 100644 --- a/tensorflow/core/kernels/cwise_op_gpu_mul.cu.cc +++ b/tensorflow/core/kernels/cwise_op_gpu_mul.cu.cc @@ -19,7 +19,7 @@ limitations under the License. 
namespace tensorflow { namespace functor { -DEFINE_BINARY3(mul, float, double, int64); +DEFINE_BINARY7(mul, float, double, uint8, int8, int16, int32, int64); } // namespace functor } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_mul.cc b/tensorflow/core/kernels/cwise_op_mul.cc index a7b9859b19..42d50358e6 100644 --- a/tensorflow/core/kernels/cwise_op_mul.cc +++ b/tensorflow/core/kernels/cwise_op_mul.cc @@ -16,21 +16,11 @@ limitations under the License. #include "tensorflow/core/kernels/cwise_ops_common.h" namespace tensorflow { -REGISTER7(BinaryOp, CPU, "Mul", functor::mul, float, double, int32, int64, int8, - int16, complex64); +REGISTER8(BinaryOp, CPU, "Mul", functor::mul, float, double, uint8, int8, int16, + int32, int64, complex64); #if GOOGLE_CUDA -REGISTER3(BinaryOp, GPU, "Mul", functor::mul, float, double, int64); +REGISTER7(BinaryOp, GPU, "Mul", functor::mul, float, double, uint8, int8, int16, + int32, int64); #endif -// A special GPU kernel for int32. -// TODO(b/25387198): Also enable int32 in device memory. This kernel -// registration requires all int32 inputs and outputs to be in host memory. -REGISTER_KERNEL_BUILDER(Name("Mul") - .Device(DEVICE_GPU) - .HostMemory("x") - .HostMemory("y") - .HostMemory("z") - .TypeConstraint<int32>("T"), - BinaryOp<CPUDevice, functor::mul<int32>>); - } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_ops_common.h b/tensorflow/core/kernels/cwise_ops_common.h index 3296826d48..adf4203322 100644 --- a/tensorflow/core/kernels/cwise_ops_common.h +++ b/tensorflow/core/kernels/cwise_ops_common.h @@ -379,6 +379,8 @@ struct SelectFunctor<CPUDevice, T> { #define REGISTER6(OP, D, N, F, T0, T1, T2, T3, T4, T5) REGISTER(OP, D, N, F, T0) #define REGISTER7(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6) \ REGISTER(OP, D, N, F, T0) +#define REGISTER8(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6, T7) \ + REGISTER(OP, D, N, F, T0) #else // !defined(__ANDROID__) #define REGISTER2(OP, D, N, F, T0, T1) \ REGISTER(OP, D, N, F, T0) \ @@ -398,6 +400,9 @@ struct SelectFunctor<CPUDevice, T> { #define REGISTER7(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6) \ REGISTER4(OP, D, N, F, T0, T1, T2, T3) \ REGISTER3(OP, D, N, F, T4, T5, T6) +#define REGISTER8(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6, T7) \ + REGISTER4(OP, D, N, F, T0, T1, T2, T3) \ + REGISTER4(OP, D, N, F, T4, T5, T6, T7) #endif // defined(__ANDROID__) } // end namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_ops_gpu_common.cu.h b/tensorflow/core/kernels/cwise_ops_gpu_common.cu.h index 966d3393b6..091c6717dc 100644 --- a/tensorflow/core/kernels/cwise_ops_gpu_common.cu.h +++ b/tensorflow/core/kernels/cwise_ops_gpu_common.cu.h @@ -40,7 +40,7 @@ template <typename Functor> struct UnaryFunctor<GPUDevice, Functor> { void operator()(const GPUDevice& d, typename Functor::tout_type out, typename Functor::tin_type in) { - out.device(d) = in.unaryExpr(typename Functor::func()); + To32Bit(out).device(d) = To32Bit(in).unaryExpr(typename Functor::func()); } }; @@ -50,7 +50,8 @@ struct BinaryFunctor<GPUDevice, Functor, NDIMS> { void operator()(const GPUDevice& d, typename Functor::tout_type out, typename Functor::tin_type in0, typename Functor::tin_type in1) { - out.device(d) = in0.binaryExpr(in1, typename Functor::func()); + To32Bit(out).device(d) = + To32Bit(in0).binaryExpr(in1, typename Functor::func()); } void Left(const GPUDevice& d, typename Functor::tout_type out, @@ -60,7 +61,7 @@ struct BinaryFunctor<GPUDevice, Functor, NDIMS> { typedef typename Functor::in_type Tin; typedef 
typename Functor::func Binary; typedef typename Eigen::internal::scalar_left<Tout, Tin, Binary> Unary; - out.device(d) = in.unaryExpr(Unary(scalar.data())); + To32Bit(out).device(d) = To32Bit(in).unaryExpr(Unary(scalar.data())); } void Right(const GPUDevice& d, typename Functor::tout_type out, @@ -70,7 +71,7 @@ struct BinaryFunctor<GPUDevice, Functor, NDIMS> { typedef typename Functor::in_type Tin; typedef typename Functor::func Binary; typedef typename Eigen::internal::scalar_right<Tout, Tin, Binary> Unary; - out.device(d) = in.unaryExpr(Unary(scalar.data())); + To32Bit(out).device(d) = To32Bit(in).unaryExpr(Unary(scalar.data())); } void BCast(const GPUDevice& d, @@ -86,16 +87,18 @@ struct BinaryFunctor<GPUDevice, Functor, NDIMS> { const bool bcast0_all_one = AllOne<NDIMS>(bcast0); const bool bcast1_all_one = AllOne<NDIMS>(bcast1); if (bcast0_all_one && !bcast1_all_one) { - out.device(d) = in0.binaryExpr(in1.broadcast(bcast1), func); + To32Bit(out).device(d) = + To32Bit(in0).binaryExpr(To32Bit(in1).broadcast(bcast1), func); return; } if (!bcast0_all_one && bcast1_all_one) { - out.device(d) = in0.broadcast(bcast0).binaryExpr(in1, func); + To32Bit(out).device(d) = + To32Bit(in0).broadcast(bcast0).binaryExpr(To32Bit(in1), func); return; } } - out.device(d) = - in0.broadcast(bcast0).binaryExpr(in1.broadcast(bcast1), func); + To32Bit(out).device(d) = To32Bit(in0).broadcast(bcast0).binaryExpr( + To32Bit(in1).broadcast(bcast1), func); } }; @@ -105,7 +108,8 @@ struct SelectFunctor<GPUDevice, T> { typename TTypes<bool>::ConstFlat cond_flat, typename TTypes<T>::ConstFlat then_flat, typename TTypes<T>::ConstFlat else_flat) { - out.device(d) = cond_flat.select(then_flat, else_flat); + To32Bit(out).device(d) = + To32Bit(cond_flat).select(To32Bit(then_flat), To32Bit(else_flat)); } }; @@ -143,6 +147,12 @@ struct SelectFunctor<GPUDevice, T> { #define DEFINE_BINARY5(F, T0, T1, T2, T3, T4) \ DEFINE_BINARY2(F, T0, T1); \ DEFINE_BINARY3(F, T2, T3, T4) +#define DEFINE_BINARY6(F, T0, T1, T2, T3, T4, T5) \ + DEFINE_BINARY3(F, T0, T1, T2); \ + DEFINE_BINARY3(F, T3, T4, T5) +#define DEFINE_BINARY7(F, T0, T1, T2, T3, T4, T5, T6) \ + DEFINE_BINARY3(F, T0, T1, T2); \ + DEFINE_BINARY4(F, T3, T4, T5, T6) } // end namespace functor } // end namespace tensorflow diff --git a/tensorflow/core/kernels/lrn_op.cc b/tensorflow/core/kernels/lrn_op.cc index fb779f2466..9ae2eedb30 100644 --- a/tensorflow/core/kernels/lrn_op.cc +++ b/tensorflow/core/kernels/lrn_op.cc @@ -30,10 +30,17 @@ limitations under the License. namespace tensorflow { +namespace { + +// When the depth is large and beta_ is 0.5 or 1.0, MognetLRN is faster than the +// main band matrix approach used below. Benchmarks suggest switching to +// MognetLRN when depth > 384. +const int kMognetLRNDepthCutoff = 384; + // Create a depth-by-depth band matrix with 1s along a swath of size (2 * // depth_radius + 1) around the diagonal. 
-static void GetBandMatrix(int depth, int64 depth_radius, - Eigen::Tensor<float, 2, Eigen::RowMajor>* result) { +void GetBandMatrix(int depth, int64 depth_radius, + Eigen::Tensor<float, 2, Eigen::RowMajor>* result) { result->setZero(); for (int row = 0; row < depth; ++row) { const int begin = std::max<int>(0, row - depth_radius); @@ -44,6 +51,8 @@ static void GetBandMatrix(int depth, int64 depth_radius, } } +} // namespace + class LRNOp : public OpKernel { public: explicit LRNOp(OpKernelConstruction* context) : OpKernel(context) { @@ -69,6 +78,11 @@ class LRNOp : public OpKernel { #if defined(__ANDROID__) MognetLRN(in, batch, rows, cols, depth, output); #else + if (depth > kMognetLRNDepthCutoff && (beta_ == 0.5f || beta_ == 1.0f)) { + MognetLRN(in, batch, rows, cols, depth, output); + return; + } + const int nodes = cols * rows; auto in_shaped = in.shaped<float, 2>({nodes * batch, depth}); @@ -79,13 +93,16 @@ class LRNOp : public OpKernel { auto out_shaped = output->shaped<float, 2>({nodes * batch, depth}); Eigen::array<DimPair, 1> dims = {{DimPair(1, 0)}}; - /// TODO(keveman): Optimize for beta in {0, 1, 0.5} - out_shaped.device(context->eigen_cpu_device()) = - in_shaped / - in_shaped.square() - .contract(multiplier, dims) - .unaryExpr([this](float x) { return bias_ + alpha_ * x; }) - .pow(beta_); + auto tmp = in_shaped.square().contract(multiplier, dims) * alpha_ + bias_; + if (beta_ == 1.0f) { + out_shaped.device(context->eigen_cpu_device()) = + in_shaped * tmp.inverse(); + } else if (beta_ == 0.5f) { + out_shaped.device(context->eigen_cpu_device()) = in_shaped * tmp.rsqrt(); + } else { + out_shaped.device(context->eigen_cpu_device()) = + in_shaped * (tmp.log() * -beta_).exp(); + } #endif } @@ -104,11 +121,11 @@ class LRNOp : public OpKernel { Eigen::VectorXf padded_square(data_in.rows() + double_depth_radius); padded_square.setZero(); for (int r = 0; r < data_in.cols(); ++r) { - // Do local response normalization for data_in(:, r) - // first, compute the square and store them in buffer for repeated use + // Do local response normalization for data_in(:, r). First, compute the + // square and store them in buffer for repeated use. padded_square.block(depth_radius_, 0, data_out.rows(), 1) = data_in.col(r).cwiseProduct(data_in.col(r)) * alpha_; - // Then, compute the scale and writes them to data_out + // Then, compute the scale and write it to data_out. float accumulated_scale = 0; for (int i = 0; i < double_depth_radius; ++i) { accumulated_scale += padded_square(i); @@ -120,13 +137,13 @@ class LRNOp : public OpKernel { } } - // In a few cases, the pow computation could benefit from speedups. if (beta_ == 1) { data_out.array() = data_in.array() * data_out.array().inverse(); } else if (beta_ == 0.5) { - data_out.array() = data_in.array() * data_out.array().sqrt().inverse(); + data_out.array() = data_in.array() * data_out.array().rsqrt(); } else { - data_out.array() = data_in.array() * data_out.array().pow(-beta_); + data_out.array() = + data_in.array() * (data_out.array().log() * -beta_).exp(); } } diff --git a/tensorflow/core/kernels/reference_gemm.h b/tensorflow/core/kernels/reference_gemm.h deleted file mode 100644 index 16fa541238..0000000000 --- a/tensorflow/core/kernels/reference_gemm.h +++ /dev/null @@ -1,90 +0,0 @@ -/* Copyright 2015 Google Inc. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_KERNELS_REFERENCE_GEMM_H_ -#define TENSORFLOW_KERNELS_REFERENCE_GEMM_H_ - -// This is an unoptimized but debuggable implementation of the GEMM matrix -// multiply function, used to compare to faster but more opaque versions, or -// for bit depths or argument combinations that aren't supported by optimized -// code. -// It assumes the row-major convention used by TensorFlow, and implements -// C = A * B, like the standard BLAS GEMM interface. If the tranpose flags are -// true, then the relevant matrix is treated as stored in column-major order. - -namespace tensorflow { -template <class T1, class T2, class T3> -void ReferenceGemm(bool transpose_a, bool transpose_b, bool transpose_c, - size_t m, size_t n, size_t k, const T1* a, T1 offset_a, - size_t lda, const T2* b, T2 offset_b, size_t ldb, T3* c, - int32 shift_c, int32 offset_c, int32 mult_c, size_t ldc) { - int a_i_stride; - int a_l_stride; - if (transpose_a) { - a_i_stride = 1; - a_l_stride = lda; - } else { - a_i_stride = lda; - a_l_stride = 1; - } - int b_j_stride; - int b_l_stride; - if (transpose_b) { - b_j_stride = ldb; - b_l_stride = 1; - } else { - b_j_stride = 1; - b_l_stride = ldb; - } - int c_i_stride; - int c_j_stride; - if (transpose_c) { - c_i_stride = 1; - c_j_stride = ldc; - } else { - c_i_stride = ldc; - c_j_stride = 1; - } - - const int32 highest = static_cast<int32>(Eigen::NumTraits<T3>::highest()); - const int32 lowest = static_cast<int32>(Eigen::NumTraits<T3>::lowest()); - const int32 rounding = (shift_c < 1) ? 
0 : (1 << (shift_c - 1)); - - int i, j, l; - for (j = 0; j < n; j++) { - for (i = 0; i < m; i++) { - int32 total = 0; - for (l = 0; l < k; l++) { - const size_t a_index = ((i * a_i_stride) + (l * a_l_stride)); - const int32 a_value = a[a_index] - offset_a; - const size_t b_index = ((j * b_j_stride) + (l * b_l_stride)); - const int32 b_value = b[b_index] - offset_b; - total += (a_value * b_value); - } - const size_t c_index = ((i * c_i_stride) + (j * c_j_stride)); - int32_t output = ((((total + offset_c) * mult_c) + rounding) >> shift_c); - if (output > highest) { - output = highest; - } - if (output < lowest) { - output = lowest; - } - c[c_index] = static_cast<T3>(output); - } - } -} -} // namespace tensorflow - -#endif // TENSORFLOW_KERNELS_REFERENCE_GEMM_H_ diff --git a/tensorflow/core/kernels/reverse_sequence_op.cc b/tensorflow/core/kernels/reverse_sequence_op.cc index 0671414c51..a25c68a15a 100644 --- a/tensorflow/core/kernels/reverse_sequence_op.cc +++ b/tensorflow/core/kernels/reverse_sequence_op.cc @@ -39,7 +39,7 @@ typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; template <typename Device> -void CheckErrors(OpKernelContext* context, int seq_dim) { +void CheckErrors(OpKernelContext* context, int batch_dim, int seq_dim) { const Tensor& input = context->input(0); const Tensor& seq_lens = context->input(1); @@ -52,15 +52,18 @@ void CheckErrors(OpKernelContext* context, int seq_dim) { seq_lens_vec.data(), seq_lens_t.data(), sizeof(int64) * seq_lens_t.size()); - OP_REQUIRES(context, 0 != seq_dim, errors::InvalidArgument("0 == seq_dim")); + OP_REQUIRES(context, batch_dim != seq_dim, + errors::InvalidArgument("batch_dim == seq_dim == ", seq_dim)); OP_REQUIRES(context, seq_dim < input.dims(), errors::InvalidArgument("seq_dim must be < input.dims()", "( ", seq_dim, " vs. ", input.dims(), ")")); - - OP_REQUIRES(context, seq_lens.NumElements() == input.dim_size(0), - errors::InvalidArgument("len(seq_lens) != input.dims(", 0, "), ", - "(", seq_lens.NumElements(), " vs. ", - input.dim_size(seq_dim))); + OP_REQUIRES(context, batch_dim < input.dims(), + errors::InvalidArgument("batch_dim must be < input.dims()", "( ", + batch_dim, " vs. ", input.dims(), ")")); + OP_REQUIRES(context, seq_lens.NumElements() == input.dim_size(batch_dim), + errors::InvalidArgument("len(seq_lens) != input.dims(", batch_dim, + "), ", "(", seq_lens.NumElements(), + " vs. ", input.dim_size(batch_dim))); for (int d = 0; d < seq_lens_vec.size(); ++d) { OP_REQUIRES(context, seq_lens_vec[d] >= 0, @@ -72,19 +75,24 @@ void CheckErrors(OpKernelContext* context, int seq_dim) { } template <> -void CheckErrors<GPUDevice>(OpKernelContext* context, int seq_dim) { +void CheckErrors<GPUDevice>(OpKernelContext* context, int batch_dim, + int seq_dim) { const Tensor& input = context->input(0); const Tensor& seq_lens = context->input(1); - OP_REQUIRES(context, 0 != seq_dim, errors::InvalidArgument("0 == seq_dim")); + OP_REQUIRES(context, batch_dim != seq_dim, + errors::InvalidArgument("batch_dim == seq_dim == ", seq_dim)); OP_REQUIRES(context, seq_dim < input.dims(), errors::InvalidArgument("seq_dim must be < input.dims()", "( ", seq_dim, " vs. ", input.dims(), ")")); - - OP_REQUIRES(context, seq_lens.NumElements() == input.dim_size(0), - errors::InvalidArgument("len(seq_lens) != input.dims(", 0, "), ", - "(", seq_lens.NumElements(), " vs. ", - input.dim_size(seq_dim))); + OP_REQUIRES(context, batch_dim < input.dims(), + errors::InvalidArgument("batch_dim must be < input.dims()", "( ", + batch_dim, " vs. 
", input.dims(), ")")); + + OP_REQUIRES(context, seq_lens.NumElements() == input.dim_size(batch_dim), + errors::InvalidArgument("len(seq_lens) != input.dims(", batch_dim, + "), ", "(", seq_lens.NumElements(), + " vs. ", input.dim_size(batch_dim))); } template <typename Device, typename T> @@ -92,6 +100,7 @@ class ReverseSequenceOp : public OpKernel { public: explicit ReverseSequenceOp(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("batch_dim", &batch_dim_)); OP_REQUIRES_OK(context, context->GetAttr("seq_dim", &seq_dim_)); } @@ -106,7 +115,7 @@ class ReverseSequenceOp : public OpKernel { auto seq_lens_t = seq_lens.vec<int64>(); - CheckErrors<Device>(context, seq_dim_); + CheckErrors<Device>(context, batch_dim_, seq_dim_); const int input_dims = input.dims(); @@ -114,11 +123,11 @@ class ReverseSequenceOp : public OpKernel { OP_REQUIRES_OK(context, context->allocate_output(0, input.shape(), &output)); -#define HANDLE_DIM(NDIM) \ - case NDIM: \ - functor::ReverseSequence<Device, T, NDIM>::Compute( \ - context->eigen_device<Device>(), input.tensor<T, NDIM>(), seq_dim_, \ - seq_lens_t, output->tensor<T, NDIM>()); \ +#define HANDLE_DIM(NDIM) \ + case NDIM: \ + functor::ReverseSequence<Device, T, NDIM>::Compute( \ + context->eigen_device<Device>(), input.tensor<T, NDIM>(), batch_dim_, \ + seq_dim_, seq_lens_t, output->tensor<T, NDIM>()); \ break; switch (input_dims) { @@ -136,6 +145,7 @@ class ReverseSequenceOp : public OpKernel { } private: + int32 batch_dim_; int32 seq_dim_; TF_DISALLOW_COPY_AND_ASSIGN(ReverseSequenceOp); @@ -152,12 +162,12 @@ TF_CALL_NUMBER_TYPES(REGISTER_REVERSE_SEQUENCE); // Forward declarations of the functor specializations for GPU. namespace functor { -#define DECLARE_GPU_SPEC(T, Dims) \ - template <> \ - void ReverseSequence<GPUDevice, T, Dims>::Compute( \ - const GPUDevice& d, typename TTypes<T, Dims>::ConstTensor input, \ - int32 seq_dim, TTypes<int64>::ConstVec seq_lens, \ - typename TTypes<T, Dims>::Tensor output); \ +#define DECLARE_GPU_SPEC(T, Dims) \ + template <> \ + void ReverseSequence<GPUDevice, T, Dims>::Compute( \ + const GPUDevice& d, typename TTypes<T, Dims>::ConstTensor input, \ + int32 batch_dim, int32 seq_dim, TTypes<int64>::ConstVec seq_lens, \ + typename TTypes<T, Dims>::Tensor output); \ extern template struct ReverseSequence<GPUDevice, T, Dims>; #define DECLARE_GPU_SPECS(T) \ diff --git a/tensorflow/core/kernels/reverse_sequence_op.h b/tensorflow/core/kernels/reverse_sequence_op.h index ceb1b0b880..9dd1e4d01d 100644 --- a/tensorflow/core/kernels/reverse_sequence_op.h +++ b/tensorflow/core/kernels/reverse_sequence_op.h @@ -29,15 +29,19 @@ template <typename T, size_t Dims> class ReverseGenerator { public: EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE - ReverseGenerator(typename TTypes<T, Dims>::ConstTensor input, int32 seq_dim, - TTypes<int64>::ConstVec seq_lengths) - : input_(input), seq_dim_(seq_dim), seq_lengths_(seq_lengths) {} + ReverseGenerator(typename TTypes<T, Dims>::ConstTensor input, int32 batch_dim, + int32 seq_dim, TTypes<int64>::ConstVec seq_lengths) + : input_(input), + batch_dim_(batch_dim), + seq_dim_(seq_dim), + seq_lengths_(seq_lengths) {} EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T operator()(const Eigen::array<Eigen::DenseIndex, Dims>& coords) const { Eigen::array<Eigen::DenseIndex, Dims> new_coords = coords; - if (coords[seq_dim_] < seq_lengths_(coords[0])) { - new_coords[seq_dim_] = seq_lengths_(coords[0]) - coords[seq_dim_] - 1; + if (coords[seq_dim_] < seq_lengths_(coords[batch_dim_])) { + 
new_coords[seq_dim_] = + seq_lengths_(coords[batch_dim_]) - coords[seq_dim_] - 1; } return input_(new_coords); @@ -45,6 +49,7 @@ class ReverseGenerator { private: typename TTypes<T, Dims>::ConstTensor input_; + int32 batch_dim_; int32 seq_dim_; TTypes<int64>::ConstVec seq_lengths_; }; @@ -57,9 +62,10 @@ template <typename Device, typename T, size_t Dims> struct ReverseSequence { EIGEN_ALWAYS_INLINE static void Compute( const Device& d, typename TTypes<T, Dims>::ConstTensor input, - int32 seq_dim, TTypes<int64>::ConstVec seq_lengths, + int32 batch_dim, int32 seq_dim, TTypes<int64>::ConstVec seq_lengths, typename TTypes<T, Dims>::Tensor output) { - generator::ReverseGenerator<T, Dims> generator(input, seq_dim, seq_lengths); + generator::ReverseGenerator<T, Dims> generator(input, batch_dim, seq_dim, + seq_lengths); output.device(d) = input.generate(generator); } }; diff --git a/tensorflow/core/kernels/softsign_op.cc b/tensorflow/core/kernels/softsign_op.cc new file mode 100644 index 0000000000..e3480e3594 --- /dev/null +++ b/tensorflow/core/kernels/softsign_op.cc @@ -0,0 +1,112 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// See docs in ../ops/nn_ops.cc. 
+ +#define EIGEN_USE_THREADS + +#include "tensorflow/core/framework/numeric_op.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/kernels/softsign_op.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/public/tensor.h" + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +template <typename Device, typename T> +class SoftsignOp : public UnaryElementWiseOp<T, SoftsignOp<Device, T>> { + public: + using UnaryElementWiseOp<T, SoftsignOp<Device, T>>::UnaryElementWiseOp; + + void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) { + functor::Softsign<Device, T> functor; + functor(context->eigen_device<Device>(), input.flat<T>(), + output->flat<T>()); + } +}; + +template <typename Device, typename T> +class SoftsignGradOp + : public BinaryElementWiseOp<T, SoftsignGradOp<Device, T>> { + public: + using BinaryElementWiseOp<T, SoftsignGradOp<Device, T>>::BinaryElementWiseOp; + + // INPUTS: + // g (gradients): backpropagated gradients + // a (inputs): inputs that were passed to SoftsignOp() + // OUTPUT: + // gradients to backprop + template <int NDIMS> + void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a, + Tensor* output) { + OP_REQUIRES(context, a.IsSameSize(g), + errors::InvalidArgument("g and a must be the same size")); + functor::SoftsignGrad<Device, T> functor; + functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(), + output->flat<T>()); + } +}; + +#define REGISTER_KERNELS(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Softsign").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ + SoftsignOp<CPUDevice, type>); \ + REGISTER_KERNEL_BUILDER( \ + Name("SoftsignGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ + SoftsignGradOp<CPUDevice, type>); + +TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS); +#undef REGISTER_KERNELS + +#if GOOGLE_CUDA +// Forward declarations of the functor specializations for GPU. +namespace functor { +#define DECLARE_GPU_SPEC(T) \ + template <> \ + void Softsign<GPUDevice, T>::operator()( \ + const GPUDevice& d, typename TTypes<T>::ConstTensor features, \ + typename TTypes<T>::Tensor activations); \ + extern template struct Softsign<GPUDevice, T>; \ + \ + template <> \ + void SoftsignGrad<GPUDevice, T>::operator()( \ + const GPUDevice& d, typename TTypes<T>::ConstTensor gradients, \ + typename TTypes<T>::ConstTensor features, \ + typename TTypes<T>::Tensor backprops); \ + extern template struct SoftsignGrad<GPUDevice, T>; + +TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC); +} // namespace functor + +// Registration of the GPU implementations. +#define REGISTER_GPU_KERNELS(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Softsign").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ + SoftsignOp<GPUDevice, type>); \ + REGISTER_KERNEL_BUILDER( \ + Name("SoftsignGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ + SoftsignGradOp<GPUDevice, type>); + +TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS); +#undef REGISTER_GPU_KERNELS + +#endif // GOOGLE_CUDA + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/softsign_op.h b/tensorflow/core/kernels/softsign_op.h new file mode 100644 index 0000000000..36790a5874 --- /dev/null +++ b/tensorflow/core/kernels/softsign_op.h @@ -0,0 +1,60 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_KERNELS_SOFTSIGN_OP_H_ +#define TENSORFLOW_KERNELS_SOFTSIGN_OP_H_ +// Functor definition for SoftsignOp and SoftsignGradOp, must be compilable by +// nvcc. + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/framework/tensor_types.h" + +namespace tensorflow { +namespace functor { + +// Functor used by SoftsignOp to do the computations. +template <typename Device, typename T> +struct Softsign { + // Computes Softsign activation. + // + // features: any shape. + // activations: same shape as "features". + void operator()(const Device& d, typename TTypes<T>::ConstTensor features, + typename TTypes<T>::Tensor activations) { + activations.device(d) = + features / (features.abs() + features.constant(1.0f)); + } +}; + +// Functor used by SoftsignGradOp to do the computations. +template <typename Device, typename T> +struct SoftsignGrad { + // Computes SoftsignGrad backprops. + // + // gradients: gradients backpropagated to the Softsign op. + // features: inputs that were passed to the Softsign op. + // backprops: gradients to backpropagate to the Softsign inputs. + void operator()(const Device& d, typename TTypes<T>::ConstTensor gradients, + typename TTypes<T>::ConstTensor features, + typename TTypes<T>::Tensor backprops) { + backprops.device(d) = + gradients / (features.abs() + features.constant(1.0f)).square(); + } +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_KERNELS_SOFTSIGN_OP_H_ diff --git a/tensorflow/core/kernels/softsign_op_gpu.cu.cc b/tensorflow/core/kernels/softsign_op_gpu.cu.cc new file mode 100644 index 0000000000..4ae941c9f0 --- /dev/null +++ b/tensorflow/core/kernels/softsign_op_gpu.cu.cc @@ -0,0 +1,40 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#if GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#include <stdio.h> + +#include "tensorflow/core/kernels/softsign_op.h" + +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor_types.h" + +namespace tensorflow { + +typedef Eigen::GpuDevice GPUDevice; + +// Definition of the GPU implementations declared in softsign_op.cc. 
+#define DEFINE_GPU_KERNELS(T) \ + template struct functor::Softsign<GPUDevice, T>; \ + template struct functor::SoftsignGrad<GPUDevice, T>; + +TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS); + +} // end namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/split_op_gpu.cu.cc b/tensorflow/core/kernels/split_op_gpu.cu.cc index c79410b68c..13463b705b 100644 --- a/tensorflow/core/kernels/split_op_gpu.cu.cc +++ b/tensorflow/core/kernels/split_op_gpu.cu.cc @@ -33,7 +33,7 @@ void Split<Device, T>::operator()( typename TTypes<T, 3>::ConstTensor input, const Eigen::DSizes<Eigen::DenseIndex, 3>& slice_indices, const Eigen::DSizes<Eigen::DenseIndex, 3>& slice_sizes) { - output.device(d) = input.slice(slice_indices, slice_sizes); + To32Bit(output).device(d) = To32Bit(input).slice(slice_indices, slice_sizes); } #define DEFINE_GPU_KERNELS(T) template struct Split<Eigen::GpuDevice, T>; diff --git a/tensorflow/core/kernels/stack_ops.cc b/tensorflow/core/kernels/stack_ops.cc index 055050cd34..2c146b3d6c 100644 --- a/tensorflow/core/kernels/stack_ops.cc +++ b/tensorflow/core/kernels/stack_ops.cc @@ -1,3 +1,18 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + // See docs in ../ops/data_flow_ops.cc. #include <limits.h> |
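The macro currying introduced in cast_op.cc above is easier to see expanded: CURRY_TYPES3(FN, arg0, arg1) applies FN once per destination type, so a single CURRY_TYPES3(CAST_CASE, CPUDevice, float) line stands in for seven hand-written CAST_CASE lines. A minimal standalone sketch of that expansion (PRINT_CAST is a hypothetical stand-in for CAST_CASE, used here only so the program prints the pairs that would be registered):

#include <cstdio>

// Hypothetical stand-in for CAST_CASE: just report the (device, in, out) triple.
#define PRINT_CAST(DEVICE, IN, OUT) \
  std::printf("cast on %s: %s -> %s\n", #DEVICE, #IN, #OUT)

// Same shape as the CURRY_TYPES3 macro added in cast_op.cc: one application of
// FN per destination type (bool, uint8, int16, int32, int64, float, double).
#define CURRY_TYPES3(FN, arg0, arg1) \
  FN(arg0, arg1, bool);              \
  FN(arg0, arg1, uint8);             \
  FN(arg0, arg1, int16);             \
  FN(arg0, arg1, int32);             \
  FN(arg0, arg1, int64);             \
  FN(arg0, arg1, float);             \
  FN(arg0, arg1, double)

int main() {
  // Expands to seven PRINT_CAST calls, one per destination type.
  CURRY_TYPES3(PRINT_CAST, CPUDevice, float);
  return 0;
}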
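The rows_odd/cols_odd handling in conv_grad_ops.cc exists because cuDNN only accepts equal padding on both sides of a dimension; when the total padding works out odd, the patch pads the input by one extra row or column and hands the padded size to cuDNN instead of falling back off the cuDNN path. A small worked example of when that triggers (assuming the usual SAME-padding output size out = ceil(in / stride); the concrete numbers are illustrative, not taken from the patch):

#include <cstdio>

int main() {
  const int stride = 2, filter = 3;
  const int inputs[] = {7, 8};
  for (int in : inputs) {
    const int out = (in + stride - 1) / stride;            // SAME output size
    const int padding = (out - 1) * stride + filter - in;  // total padding
    const bool odd = (padding % 2 != 0);
    std::printf("in=%d out=%d total padding=%d -> %s\n", in, out, padding,
                odd ? "odd: pad input by one extra element"
                    : "even: split symmetrically");
  }
  return 0;
}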
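The LRN rewrite replaces the generic pow(..., beta_) with the identity x^(-beta) = exp(-beta * log(x)), plus cheaper specializations: tmp.inverse() when beta == 1 and tmp.rsqrt() when beta == 0.5. A scalar sketch of the per-element scale and the equivalence (illustrative only; the kernel applies this to Eigen expressions, not scalars):

#include <cassert>
#include <cmath>
#include <cstdio>

// Scalar version of the scale LRNOp applies per element:
//   out = in * (bias + alpha * sum_of_squares)^(-beta)
float LrnScale(float tmp, float beta) {
  if (beta == 1.0f) return 1.0f / tmp;             // tmp.inverse()
  if (beta == 0.5f) return 1.0f / std::sqrt(tmp);  // tmp.rsqrt()
  return std::exp(std::log(tmp) * -beta);          // (tmp.log() * -beta).exp()
}

int main() {
  const float tmp = 2.75f;  // bias_ + alpha_ * sum of squares, always > 0
  const float betas[] = {0.5f, 0.75f, 1.0f};
  for (float beta : betas) {
    const float specialized = LrnScale(tmp, beta);
    const float reference = std::pow(tmp, -beta);
    std::printf("beta=%.2f specialized=%.6f pow=%.6f\n", beta, specialized,
                reference);
    assert(std::fabs(specialized - reference) < 1e-5f);
  }
  return 0;
}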
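The new Softsign kernel computes f(x) = x / (1 + |x|); its gradient functor multiplies the incoming gradient by 1 / (1 + |x|)^2, which is exactly df/dx. A scalar sketch of the same math (illustrative; the kernel itself uses the Eigen expressions shown in softsign_op.h):

#include <cmath>
#include <cstdio>

// Forward pass, matching the Softsign functor: features / (|features| + 1).
float Softsign(float x) { return x / (std::fabs(x) + 1.0f); }

// Backward pass, matching SoftsignGrad: gradients / (|features| + 1)^2.
float SoftsignGrad(float dy, float x) {
  const float denom = std::fabs(x) + 1.0f;
  return dy / (denom * denom);
}

int main() {
  const float xs[] = {-2.0f, -0.5f, 0.0f, 0.5f, 2.0f};
  for (float x : xs) {
    std::printf("x=%+.2f softsign=%+.4f d/dx=%.4f\n", x, Softsign(x),
                SoftsignGrad(1.0f, x));
  }
  return 0;
}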
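Several GPU kernels above (concat, constant, split, the cwise functors, and the conv-grad shuffles) now wrap their Eigen expressions in To32Bit(...). That helper from tensorflow/core/framework/tensor_types.h re-expresses the tensor with 32-bit indices, which lets Eigen emit cheaper index arithmetic in the generated CUDA kernels; the old nvcc 6.5 workaround that avoided 32-bit indexing is dropped in conv_ops_gpu.cu.cc. A sketch of the pattern against a hypothetical functor, meant for a .cu.cc translation unit (ScaleBy2 is not part of the patch; this assumes the existing To32Bit helper and TTypes aliases):

#define EIGEN_USE_GPU

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_types.h"

namespace tensorflow {
namespace functor {

typedef Eigen::GpuDevice GPUDevice;

// Hypothetical element-wise functor, shown only to illustrate the To32Bit
// pattern: wrap both sides of the expression so Eigen indexes with int32
// instead of 64-bit Eigen::DenseIndex when generating the CUDA kernel.
template <typename T>
struct ScaleBy2 {
  void operator()(const GPUDevice& d, typename TTypes<T>::Flat out,
                  typename TTypes<T>::ConstFlat in) {
    To32Bit(out).device(d) = To32Bit(in) * static_cast<T>(2);
  }
};

}  // namespace functor
}  // namespace tensorflow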