Diffstat (limited to 'tensorflow/core/kernels')
-rw-r--r--  tensorflow/core/kernels/cast_op.cc                | 104
-rw-r--r--  tensorflow/core/kernels/cast_op_gpu.cu.cc         |  38
-rw-r--r--  tensorflow/core/kernels/cast_op_test.cc           |  44
-rw-r--r--  tensorflow/core/kernels/concat_op_gpu.cu.cc       |   8
-rw-r--r--  tensorflow/core/kernels/constant_op_gpu.cu.cc     |   4
-rw-r--r--  tensorflow/core/kernels/conv_grad_ops.cc          | 147
-rw-r--r--  tensorflow/core/kernels/conv_ops_gpu.cu.cc        |   8
-rw-r--r--  tensorflow/core/kernels/cwise_op_div.cc           |  18
-rw-r--r--  tensorflow/core/kernels/cwise_op_gpu_div.cu.cc    |   2
-rw-r--r--  tensorflow/core/kernels/cwise_op_gpu_mul.cu.cc    |   2
-rw-r--r--  tensorflow/core/kernels/cwise_op_mul.cc           |  18
-rw-r--r--  tensorflow/core/kernels/cwise_ops_common.h        |   5
-rw-r--r--  tensorflow/core/kernels/cwise_ops_gpu_common.cu.h |  28
-rw-r--r--  tensorflow/core/kernels/lrn_op.cc                 |  47
-rw-r--r--  tensorflow/core/kernels/reference_gemm.h          |  90
-rw-r--r--  tensorflow/core/kernels/reverse_sequence_op.cc    |  62
-rw-r--r--  tensorflow/core/kernels/reverse_sequence_op.h     |  20
-rw-r--r--  tensorflow/core/kernels/softsign_op.cc            | 112
-rw-r--r--  tensorflow/core/kernels/softsign_op.h             |  60
-rw-r--r--  tensorflow/core/kernels/softsign_op_gpu.cu.cc     |  40
-rw-r--r--  tensorflow/core/kernels/split_op_gpu.cu.cc        |   2
-rw-r--r--  tensorflow/core/kernels/stack_ops.cc              |  15
22 files changed, 538 insertions(+), 336 deletions(-)
diff --git a/tensorflow/core/kernels/cast_op.cc b/tensorflow/core/kernels/cast_op.cc
index 960c653593..8d5ed3c2fe 100644
--- a/tensorflow/core/kernels/cast_op.cc
+++ b/tensorflow/core/kernels/cast_op.cc
@@ -55,6 +55,24 @@ struct CastFunctor<CPUDevice, O, I> {
} // namespace functor
+#define CURRY_TYPES2(FN, arg0) \
+ FN(arg0, bool); \
+ FN(arg0, uint8); \
+ FN(arg0, int16); \
+ FN(arg0, int32); \
+ FN(arg0, int64); \
+ FN(arg0, float); \
+ FN(arg0, double)
+
+#define CURRY_TYPES3(FN, arg0, arg1) \
+ FN(arg0, arg1, bool); \
+ FN(arg0, arg1, uint8); \
+ FN(arg0, arg1, int16); \
+ FN(arg0, arg1, int32); \
+ FN(arg0, arg1, int64); \
+ FN(arg0, arg1, float); \
+ FN(arg0, arg1, double)
+
#define CAST_CASE(DEVICE, IN, OUT) \
if (DataTypeToEnum<IN>::value == src_dtype_ && \
DataTypeToEnum<OUT>::value == dst_dtype_) { \
@@ -110,27 +128,14 @@ class CpuCastOp : public CastOpBase {
work_ = nullptr; // Identity
return Status::OK();
}
- CAST_CASE(CPUDevice, bool, float);
- CAST_CASE(CPUDevice, bool, int32);
- CAST_CASE(CPUDevice, bool, double);
- CAST_CASE(CPUDevice, double, float);
- CAST_CASE(CPUDevice, double, int32);
- CAST_CASE(CPUDevice, double, int64);
- CAST_CASE(CPUDevice, float, double);
- CAST_CASE(CPUDevice, float, uint8);
- CAST_CASE(CPUDevice, float, int32);
- CAST_CASE(CPUDevice, float, int64);
- CAST_CASE(CPUDevice, int32, double);
- CAST_CASE(CPUDevice, int32, float);
- CAST_CASE(CPUDevice, int32, uint8);
- CAST_CASE(CPUDevice, int32, int64);
- CAST_CASE(CPUDevice, int64, double);
- CAST_CASE(CPUDevice, int64, float);
- CAST_CASE(CPUDevice, int64, int32);
- CAST_CASE(CPUDevice, uint8, float);
- CAST_CASE(CPUDevice, uint8, int32);
- CAST_CASE(CPUDevice, uint8, int64);
- CAST_CASE(CPUDevice, uint8, double);
+ CURRY_TYPES3(CAST_CASE, CPUDevice, bool);
+ CURRY_TYPES3(CAST_CASE, CPUDevice, uint8);
+ CURRY_TYPES3(CAST_CASE, CPUDevice, int16);
+ CURRY_TYPES3(CAST_CASE, CPUDevice, int32);
+ CURRY_TYPES3(CAST_CASE, CPUDevice, int64);
+ CURRY_TYPES3(CAST_CASE, CPUDevice, float);
+ CURRY_TYPES3(CAST_CASE, CPUDevice, double);
+
if (src_dtype_ == DT_BFLOAT16 && dst_dtype_ == DT_FLOAT) {
work_ = [](OpKernelContext* ctx, const Tensor& inp, Tensor* out) {
int64 N = out->NumElements();
@@ -185,24 +190,15 @@ class GpuCastOp : public CastOpBase {
work_ = nullptr; // Identity
return Status::OK();
}
- CAST_CASE(GPUDevice, bfloat16, float);
- CAST_CASE(GPUDevice, bool, float);
- CAST_CASE(GPUDevice, double, float);
- CAST_CASE(GPUDevice, double, int64);
+ CURRY_TYPES3(CAST_CASE, GPUDevice, bool);
+ CURRY_TYPES3(CAST_CASE, GPUDevice, uint8);
+ CURRY_TYPES3(CAST_CASE, GPUDevice, int16);
+ CURRY_TYPES3(CAST_CASE, GPUDevice, int32);
+ CURRY_TYPES3(CAST_CASE, GPUDevice, int64);
+ CURRY_TYPES3(CAST_CASE, GPUDevice, float);
+ CURRY_TYPES3(CAST_CASE, GPUDevice, double);
CAST_CASE(GPUDevice, float, bfloat16);
- CAST_CASE(GPUDevice, float, double);
- CAST_CASE(GPUDevice, float, int64);
- CAST_CASE(GPUDevice, int64, double);
- CAST_CASE(GPUDevice, int64, float);
- CAST_CASE(GPUDevice, uint8, float);
- CAST_CASE(GPUDevice, float, uint8);
- CAST_CASE(GPUDevice, bool, int32);
- CAST_CASE(GPUDevice, double, int32);
- CAST_CASE(GPUDevice, float, int32);
- CAST_CASE(GPUDevice, int32, double);
- CAST_CASE(GPUDevice, int32, float);
- CAST_CASE(GPUDevice, int32, int64);
- CAST_CASE(GPUDevice, int64, int32);
+ CAST_CASE(GPUDevice, bfloat16, float);
return Unimplemented();
}
};
@@ -217,28 +213,24 @@ REGISTER_KERNEL_BUILDER(Name("Cast").Device(DEVICE_CPU), CpuCastOp);
.TypeConstraint<srctype>("SrcT") \
.TypeConstraint<dsttype>("DstT") \
.Device(DEVICE_GPU), \
- GpuCastOp);
-REGISTER_CAST_GPU(bfloat16, float);
-REGISTER_CAST_GPU(bool, float);
-REGISTER_CAST_GPU(double, float);
-REGISTER_CAST_GPU(double, int64);
+ GpuCastOp)
+
+CURRY_TYPES2(REGISTER_CAST_GPU, bool);
+CURRY_TYPES2(REGISTER_CAST_GPU, uint8);
+CURRY_TYPES2(REGISTER_CAST_GPU, int16);
+CURRY_TYPES2(REGISTER_CAST_GPU, int32);
+CURRY_TYPES2(REGISTER_CAST_GPU, int64);
+CURRY_TYPES2(REGISTER_CAST_GPU, float);
+CURRY_TYPES2(REGISTER_CAST_GPU, double);
REGISTER_CAST_GPU(float, bfloat16);
-REGISTER_CAST_GPU(float, double);
-REGISTER_CAST_GPU(float, int64);
-REGISTER_CAST_GPU(int64, double);
-REGISTER_CAST_GPU(int64, float);
-REGISTER_CAST_GPU(uint8, float);
-REGISTER_CAST_GPU(float, uint8);
-REGISTER_CAST_GPU(bool, int32);
-REGISTER_CAST_GPU(double, int32);
-REGISTER_CAST_GPU(float, int32);
-REGISTER_CAST_GPU(int32, double);
-REGISTER_CAST_GPU(int32, float);
-REGISTER_CAST_GPU(int32, int64);
-REGISTER_CAST_GPU(int64, int32);
+REGISTER_CAST_GPU(bfloat16, float);
+
#undef REGISTER_CAST_GPU
#endif // GOOGLE_CUDA
+#undef CURRY_TYPES2
+#undef CURRY_TYPES3
+
// HostCast differs from Cast in that its input and output are in host memory.
REGISTER_KERNEL_BUILDER(Name("_HostCast").Device(DEVICE_CPU), CpuCastOp);
REGISTER_KERNEL_BUILDER(
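The CURRY_TYPES2/CURRY_TYPES3 macros above replace the long hand-written CAST_CASE and REGISTER_CAST_GPU lists with one invocation per source type. A minimal standalone sketch of the currying pattern, using hypothetical DEMO_* names rather than the real registration macros:

#include <cstdio>

// Plays the role of CAST_CASE: it just prints the (device, in, out) triple.
#define DEMO_CASE(DEV, IN, OUT) std::printf("%s: %s -> %s\n", #DEV, #IN, #OUT)

// Fixes the first two arguments and applies FN once per destination type,
// mirroring CURRY_TYPES3 in the patch.
#define DEMO_CURRY_TYPES3(FN, arg0, arg1) \
  FN(arg0, arg1, bool);                   \
  FN(arg0, arg1, uint8);                  \
  FN(arg0, arg1, int16);                  \
  FN(arg0, arg1, int32);                  \
  FN(arg0, arg1, int64);                  \
  FN(arg0, arg1, float);                  \
  FN(arg0, arg1, double)

int main() {
  // Expands to seven DEMO_CASE statements, one per destination type, just as
  // CURRY_TYPES3(CAST_CASE, CPUDevice, bool) emits seven cast cases.
  DEMO_CURRY_TYPES3(DEMO_CASE, CPUDevice, bool);
  return 0;
}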
diff --git a/tensorflow/core/kernels/cast_op_gpu.cu.cc b/tensorflow/core/kernels/cast_op_gpu.cu.cc
index 43f8cd90ed..57f0873621 100644
--- a/tensorflow/core/kernels/cast_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/cast_op_gpu.cu.cc
@@ -33,25 +33,27 @@ struct CastFunctor<GPUDevice, O, I> {
}
};
-#define DEFINE(O, I) template struct CastFunctor<GPUDevice, O, I>;
-DEFINE(float, double);
-DEFINE(float, int32);
-DEFINE(float, int64);
-DEFINE(double, float);
-DEFINE(double, int32);
-DEFINE(double, int64);
-DEFINE(int32, float);
-DEFINE(int32, double);
-DEFINE(int32, int64);
-DEFINE(int64, float);
-DEFINE(int64, double);
-DEFINE(int64, int32);
-DEFINE(int32, bool);
-DEFINE(float, bool);
-DEFINE(float, uint8);
-DEFINE(uint8, float);
-DEFINE(float, bfloat16);
+#define DEFINE(O, I) template struct CastFunctor<GPUDevice, O, I>
+#define DEFINE_ALL_FROM(in_type) \
+ DEFINE(in_type, bool); \
+ DEFINE(in_type, uint8); \
+ DEFINE(in_type, int16); \
+ DEFINE(in_type, int32); \
+ DEFINE(in_type, int64); \
+ DEFINE(in_type, float); \
+ DEFINE(in_type, double)
+
+DEFINE_ALL_FROM(bool);
+DEFINE_ALL_FROM(uint8);
+DEFINE_ALL_FROM(int16);
+DEFINE_ALL_FROM(int32);
+DEFINE_ALL_FROM(int64);
+DEFINE_ALL_FROM(float);
+DEFINE_ALL_FROM(double);
DEFINE(bfloat16, float);
+DEFINE(float, bfloat16);
+
+#undef DEFINE_ALL_FROM
#undef DEFINE
} // end namespace functor
diff --git a/tensorflow/core/kernels/cast_op_test.cc b/tensorflow/core/kernels/cast_op_test.cc
index b93c0857db..168914f553 100644
--- a/tensorflow/core/kernels/cast_op_test.cc
+++ b/tensorflow/core/kernels/cast_op_test.cc
@@ -41,22 +41,48 @@ class CastOpTest : public OpsTestBase {
void MakeOp(DataType src, DataType dst) {
RequireDefaultOps();
EXPECT_OK(NodeDefBuilder("cast_op", "Cast")
- .Input(FakeInput(DT_INT32))
+ .Input(FakeInput(src))
.Attr("SrcT", src)
.Attr("DstT", dst)
.Finalize(node_def()));
EXPECT_OK(InitOp());
}
+
+ template <typename IN, typename OUT>
+ void CheckCast() {
+ DataType in_type = DataTypeToEnum<IN>::v();
+ DataType out_type = DataTypeToEnum<OUT>::v();
+ MakeOp(in_type, out_type);
+ AddInputFromArray<IN>(TensorShape({1, 2, 2, 1}), {1, 2, 3, 4});
+ ASSERT_OK(RunOpKernel());
+ Tensor expected(allocator(), out_type, TensorShape({1, 2, 2, 1}));
+ test::FillValues<OUT>(&expected, {1, 2, 3, 4});
+ test::ExpectTensorEqual<OUT>(expected, *GetOutput(0));
+ }
};
-TEST_F(CastOpTest, Int32ToUint8) {
- MakeOp(DT_INT32, DT_UINT8);
- AddInputFromArray<int32>(TensorShape({1, 2, 2, 1}), {1, 2, 3, 4});
- ASSERT_OK(RunOpKernel());
- Tensor expected(allocator(), DT_UINT8, TensorShape({1, 2, 2, 1}));
- test::FillValues<uint8>(&expected, {1, 2, 3, 4});
- test::ExpectTensorEqual<uint8>(expected, *GetOutput(0));
-}
+#define TEST_CAST(in, out) \
+ TEST_F(CastOpTest, TestCast##_##in##_##out) { CheckCast<in, out>(); }
+
+#define TEST_ALL_CASTS_FROM(in) \
+ TEST_CAST(in, uint8); \
+ TEST_CAST(in, int16); \
+ TEST_CAST(in, int32); \
+ TEST_CAST(in, int64); \
+ TEST_CAST(in, float); \
+ TEST_CAST(in, double)
+
+TEST_ALL_CASTS_FROM(uint8)
+TEST_ALL_CASTS_FROM(int16)
+TEST_ALL_CASTS_FROM(int32)
+TEST_ALL_CASTS_FROM(int64)
+TEST_ALL_CASTS_FROM(float)
+TEST_ALL_CASTS_FROM(double)
+
+#undef TEST_ALL_CASTS_FROM
+#undef TEST_CAST
+
+// TODO(wicke): check conversions from/to bool, and bfloat16
static void BM_cpu_float_int64(int iters, int num) {
testing::ItemsProcessed(static_cast<int64>(iters) * num);
diff --git a/tensorflow/core/kernels/concat_op_gpu.cu.cc b/tensorflow/core/kernels/concat_op_gpu.cu.cc
index 581171c6ba..084ca9a764 100644
--- a/tensorflow/core/kernels/concat_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/concat_op_gpu.cu.cc
@@ -34,10 +34,12 @@ void ConcatGPU(const GPUDevice& d,
const std::vector<
std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>& inputs,
typename TTypes<T, 2>::Matrix* output) {
- Eigen::array<Eigen::DenseIndex, 2> offset(0, 0);
+ Eigen::array<int32, 2> offset{0, 0};
for (int i = 0; i < inputs.size(); ++i) {
- Eigen::array<Eigen::DenseIndex, 2> size = inputs[i]->dimensions();
- output->slice(offset, size).device(d) = *inputs[i];
+ Eigen::array<int32_t, 2> size;
+ size[0] = inputs[i]->dimension(0);
+ size[1] = inputs[i]->dimension(1);
+ To32Bit(*output).slice(offset, size).device(d) = To32Bit(*inputs[i]);
offset[1] += size[1];
}
}
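The new ConcatGPU body slices each input into the output at a running column offset, wrapping both sides in To32Bit() so Eigen uses 32-bit indexing on the GPU. A CPU-side sketch of the same slice-and-copy loop (assumes Eigen's unsupported Tensor module is on the include path; To32Bit is omitted since it only matters for the GPU path):

#include <unsupported/Eigen/CXX11/Tensor>
#include <vector>

// Concatenate 2-D tensors along dimension 1 by copying each one into the
// output at a growing column offset.
void ConcatColumns(const std::vector<Eigen::Tensor<float, 2>>& inputs,
                   Eigen::Tensor<float, 2>* output) {
  Eigen::array<Eigen::Index, 2> offset{0, 0};
  for (const auto& in : inputs) {
    Eigen::array<Eigen::Index, 2> size{in.dimension(0), in.dimension(1)};
    output->slice(offset, size) = in;
    offset[1] += size[1];
  }
}

int main() {
  Eigen::Tensor<float, 2> a(2, 2), b(2, 3), out(2, 5);
  a.setConstant(1.0f);
  b.setConstant(2.0f);
  ConcatColumns({a, b}, &out);  // out is [ones | twos] along columns
  return 0;
}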
diff --git a/tensorflow/core/kernels/constant_op_gpu.cu.cc b/tensorflow/core/kernels/constant_op_gpu.cu.cc
index 5991391850..bbb7a0ee28 100644
--- a/tensorflow/core/kernels/constant_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/constant_op_gpu.cu.cc
@@ -73,7 +73,7 @@ struct FillFunctor<GPUDevice, T> {
void operator()(const GPUDevice& d, typename TTypes<T>::Flat out,
typename TTypes<T>::ConstScalar in) {
Eigen::internal::scalar_const_op<T> f(in.data());
- out.device(d) = out.nullaryExpr(f);
+ To32Bit(out).device(d) = To32Bit(out).nullaryExpr(f);
}
};
@@ -91,7 +91,7 @@ DEFINE_FILL_GPU(int64);
template <typename T>
struct SetZeroFunctor<GPUDevice, T> {
void operator()(const GPUDevice& d, typename TTypes<T>::Flat out) {
- out.device(d) = out.constant(0);
+ To32Bit(out).device(d) = To32Bit(out).constant(0);
}
};
diff --git a/tensorflow/core/kernels/conv_grad_ops.cc b/tensorflow/core/kernels/conv_grad_ops.cc
index dae06f4bfc..8bd13b4be3 100644
--- a/tensorflow/core/kernels/conv_grad_ops.cc
+++ b/tensorflow/core/kernels/conv_grad_ops.cc
@@ -242,13 +242,13 @@ typedef Eigen::GpuDevice GPUDevice;
const auto expanded_out_cols = (output_cols - 1) * stride + 1; \
const auto padded_out_rows = input_rows + filter_rows - 1; \
const auto padded_out_cols = input_cols + filter_cols - 1; \
- const auto top_pad_rows = filter_rows - 1 - pad_rows; \
- const auto left_pad_cols = filter_cols - 1 - pad_cols; \
- const auto bottom_pad_rows = \
+ const int top_pad_rows = filter_rows - 1 - pad_rows; \
+ const int left_pad_cols = filter_cols - 1 - pad_cols; \
+ const int bottom_pad_rows = \
padded_out_rows - expanded_out_rows - top_pad_rows; \
- const auto right_pad_cols = \
+ const int right_pad_cols = \
padded_out_cols - expanded_out_cols - left_pad_cols; \
- Eigen::DSizes<Eigen::DenseIndex, 4> strides{1, stride, stride, 1}; \
+ Eigen::DSizes<int, 4> strides{1, stride, stride, 1}; \
VLOG(2) << "Conv2d: " << label \
<< ": expanded_out_rows = " << expanded_out_rows \
<< ", expanded_out_cols = " << expanded_out_cols \
@@ -809,9 +809,11 @@ class Conv2DSlowBackpropInputOp : public OpKernel {
context->allocate_output(0, input_shape, &in_backprop));
const int padding_rows =
- (output_rows - 1) * stride + filter_rows - input_rows;
+ (padding_ == VALID) ? 0 : (output_rows - 1) * stride + filter_rows -
+ input_rows;
const int padding_cols =
- (output_cols - 1) * stride + filter_cols - input_cols;
+ (padding_ == VALID) ? 0 : (output_cols - 1) * stride + filter_cols -
+ input_cols;
// TODO(keveman): cuDNN only supports equal padding on both sides, so only
// calling it when that is true. Remove this check when (if?) cuDNN starts
@@ -954,16 +956,17 @@ class Conv2DSlowBackpropInputOp : public OpKernel {
context->allocate_temp(DataTypeToEnum<T>::v(),
padded_out_shape, &padded_output));
- Eigen::DSizes<Eigen::DenseIndex, 4> trivial_order{0, 1, 2, 3};
- Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 4> pad_dims{
+ Eigen::DSizes<int, 4> trivial_order{0, 1, 2, 3};
+ Eigen::array<Eigen::IndexPair<int>, 4> pad_dims{
{{0, 0},
{top_pad_rows, bottom_pad_rows},
{left_pad_cols, right_pad_cols},
{0, 0}}};
- functor::InflatePadAndShuffle<Device, T, 4, Eigen::DenseIndex>()(
- context->eigen_device<Device>(), out_backprop.tensor<T, 4>(), strides,
- pad_dims, trivial_order, padded_output.tensor<T, 4>());
+ functor::InflatePadAndShuffle<Device, T, 4, int>()(
+ context->eigen_device<Device>(), To32Bit(out_backprop.tensor<T, 4>()),
+ strides, pad_dims, trivial_order,
+ To32Bit(padded_output.tensor<T, 4>()));
const Tensor& padded_output_cref = padded_output;
// We then need to fill a new "reverted" filter
@@ -976,11 +979,11 @@ class Conv2DSlowBackpropInputOp : public OpKernel {
context->allocate_temp(DataTypeToEnum<T>::v(),
r_filter_shape, &r_filter));
- Eigen::DSizes<Eigen::DenseIndex, 4> filter_order{0, 1, 3, 2};
+ Eigen::DSizes<int, 4> filter_order{0, 1, 3, 2};
Eigen::array<bool, 4> filter_rev_dims{true, true, false, false};
- functor::ShuffleAndReverse<Device, T, 4, Eigen::DenseIndex>()(
- context->eigen_device<Device>(), filter.tensor<T, 4>(), filter_order,
- filter_rev_dims, r_filter.tensor<T, 4>());
+ functor::ShuffleAndReverse<Device, T, 4, int>()(
+ context->eigen_device<Device>(), To32Bit(filter.tensor<T, 4>()),
+ filter_order, filter_rev_dims, To32Bit(r_filter.tensor<T, 4>()));
const Tensor& r_filter_cref = r_filter;
// Now we can call conv_2d directly.
@@ -1039,20 +1042,22 @@ class Conv2DSlowBackpropFilterOp : public OpKernel {
context->allocate_output(0, filter_shape, &filter_backprop));
const int padding_rows =
- (output_rows - 1) * stride + filter_rows - input_rows;
+ (padding_ == VALID) ? 0 : (output_rows - 1) * stride + filter_rows -
+ input_rows;
const int padding_cols =
- (output_cols - 1) * stride + filter_cols - input_cols;
+ (padding_ == VALID) ? 0 : (output_cols - 1) * stride + filter_cols -
+ input_cols;
// TODO(zhengxq): cuDNN only supports equal padding on both sides, so only
// calling it when that is true. Remove this check when (if?) cuDNN starts
// supporting different padding.
- bool padding_compatible =
- (padding_rows % 2 == 0) && (padding_cols % 2 == 0);
+ bool rows_odd = (padding_rows % 2 != 0);
+ bool cols_odd = (padding_cols % 2 != 0);
auto* stream = context->op_device_context<GPUDeviceContext>()->stream();
OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
- if (use_cudnn_ && padding_compatible) {
+ if (use_cudnn_) {
if (filter_rows == 1 && filter_cols == 1 && stride == 1) {
const uint64 m = in_depth;
const uint64 k = batch * input_rows * input_cols;
@@ -1089,10 +1094,31 @@ class Conv2DSlowBackpropFilterOp : public OpKernel {
return;
}
+ Tensor compatible_input;
+ if (rows_odd || cols_odd) {
+ // If a padding dimension is odd, we have one more element on the right
+ // side or the bottom side. This is unsupported in cudnn. Therefore,
+ // we pad that extra element and make it compatible.
+ OP_REQUIRES_OK(
+ context,
+ context->allocate_temp(
+ DataTypeToEnum<T>::value,
+ TensorShape({input.dim_size(0), input.dim_size(1) + rows_odd,
+ input.dim_size(2) + cols_odd, input.dim_size(3)}),
+ &compatible_input));
+
+ functor::PadInput<GPUDevice, T, int>()(
+ context->template eigen_device<GPUDevice>(),
+ To32Bit(input.tensor<T, 4>()), 0, rows_odd, 0, cols_odd,
+ To32Bit(compatible_input.tensor<T, 4>()));
+ } else {
+ compatible_input = input;
+ }
+
perftools::gputools::dnn::BatchDescriptor input_desc;
input_desc.set_count(batch)
- .set_height(input_rows)
- .set_width(input_cols)
+ .set_height(compatible_input.dim_size(1))
+ .set_width(compatible_input.dim_size(2))
.set_feature_map_count(in_depth)
.set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
perftools::gputools::dnn::BatchDescriptor output_desc;
@@ -1146,14 +1172,19 @@ class Conv2DSlowBackpropFilterOp : public OpKernel {
transformed_out_backprop.tensor<T, 4>());
Tensor transformed_input;
- OP_REQUIRES_OK(context,
- context->allocate_temp(
- DataTypeToEnum<T>::value,
- TensorShape({batch, in_depth, input_rows, input_cols}),
- &transformed_input));
- functor::NHWCToNCHW<Device, T>()(context->eigen_device<Device>(),
- input.tensor<T, 4>(),
- transformed_input.tensor<T, 4>());
+ OP_REQUIRES_OK(
+ context,
+ context->allocate_temp(
+ DataTypeToEnum<T>::value,
+ TensorShape({
+ compatible_input.dim_size(0), compatible_input.dim_size(3),
+ compatible_input.dim_size(1), compatible_input.dim_size(2),
+ }),
+ &transformed_input));
+ functor::NHWCToNCHW<Device, T>()(
+ context->eigen_device<Device>(),
+ const_cast<const Tensor&>(compatible_input).tensor<T, 4>(),
+ transformed_input.tensor<T, 4>());
auto out_backprop_ptr =
AsDeviceMemory(transformed_out_backprop.template flat<T>().data(),
@@ -1193,7 +1224,7 @@ class Conv2DSlowBackpropFilterOp : public OpKernel {
// [batch, out_rows, out_cols, out_depth]
// And we need to change it to
// [out_depth, out_rows, out_cols, batch]
- Eigen::DSizes<Eigen::DenseIndex, 4> out_order{3, 1, 2, 0};
+ Eigen::DSizes<int, 4> out_order{3, 1, 2, 0};
TensorShape padded_out_shape(
{out_depth, padded_out_rows, padded_out_cols, batch});
Tensor padded_output;
@@ -1201,14 +1232,14 @@ class Conv2DSlowBackpropFilterOp : public OpKernel {
context->allocate_temp(DataTypeToEnum<T>::v(),
padded_out_shape, &padded_output));
- Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 4> pad_dims{
+ Eigen::array<Eigen::IndexPair<int>, 4> pad_dims{
{{0, 0},
{top_pad_rows, bottom_pad_rows},
{left_pad_cols, right_pad_cols},
{0, 0}}};
- functor::InflatePadAndShuffle<Device, T, 4, Eigen::DenseIndex>()(
- context->eigen_device<Device>(), out_backprop.tensor<T, 4>(), strides,
- pad_dims, out_order, padded_output.tensor<T, 4>());
+ functor::InflatePadAndShuffle<Device, T, 4, int>()(
+ context->eigen_device<Device>(), To32Bit(out_backprop.tensor<T, 4>()),
+ strides, pad_dims, out_order, To32Bit(padded_output.tensor<T, 4>()));
const Tensor& padded_output_cref = padded_output;
// For the backprop of the filter, we need to transpose the input.
@@ -1216,7 +1247,7 @@ class Conv2DSlowBackpropFilterOp : public OpKernel {
// [batch, in_rows, in_cols, in_depth]
// And we need to change it to
// [in_rows, in_cols, batch, in_depth]
- Eigen::DSizes<Eigen::DenseIndex, 4> in_order{1, 2, 0, 3};
+ Eigen::DSizes<int, 4> in_order{1, 2, 0, 3};
TensorShape in_shuffle_shape({input_rows, input_cols, batch, in_depth});
Tensor in_shuffle;
OP_REQUIRES_OK(context,
@@ -1225,9 +1256,9 @@ class Conv2DSlowBackpropFilterOp : public OpKernel {
// No need for reversing this time.
Eigen::array<bool, 4> trivial_dims{false, false, false, false};
- functor::ShuffleAndReverse<Device, T, 4, Eigen::DenseIndex>()(
- context->eigen_device<Device>(), input.tensor<T, 4>(), in_order,
- trivial_dims, in_shuffle.tensor<T, 4>());
+ functor::ShuffleAndReverse<Device, T, 4, int>()(
+ context->eigen_device<Device>(), To32Bit(input.tensor<T, 4>()),
+ in_order, trivial_dims, To32Bit(in_shuffle.tensor<T, 4>()));
const Tensor& in_shuffle_cref = in_shuffle;
// The output of the conv_2d would be
@@ -1250,12 +1281,13 @@ class Conv2DSlowBackpropFilterOp : public OpKernel {
BrainPadding2EigenPadding(VALID));
// Now copy the filter_backprop back to the destination.
- Eigen::DSizes<Eigen::DenseIndex, 4> filter_order{1, 2, 3, 0};
+ Eigen::DSizes<int, 4> filter_order{1, 2, 3, 0};
Eigen::array<bool, 4> filter_rev_dims{true, true, false, false};
const Tensor& filter_shuffle_cref = filter_shuffle;
- functor::ShuffleAndReverse<Device, T, 4, Eigen::DenseIndex>()(
- context->eigen_device<Device>(), filter_shuffle_cref.tensor<T, 4>(),
- filter_order, filter_rev_dims, filter_backprop->tensor<T, 4>());
+ functor::ShuffleAndReverse<Device, T, 4, int>()(
+ context->eigen_device<Device>(),
+ To32Bit(filter_shuffle_cref.tensor<T, 4>()), filter_order,
+ filter_rev_dims, To32Bit(filter_backprop->tensor<T, 4>()));
}
}
@@ -1271,25 +1303,6 @@ class Conv2DSlowBackpropFilterOp : public OpKernel {
namespace functor {
#define DECLARE_GPU_SPEC(T) \
template <> \
- void ShuffleAndReverse<GPUDevice, T, 4, Eigen::DenseIndex>::operator()( \
- const GPUDevice& d, \
- typename TTypes<T, 4, Eigen::DenseIndex>::ConstTensor input, \
- const Eigen::DSizes<Eigen::DenseIndex, 4>& order, \
- const Eigen::array<bool, 4>& reverse_dims, \
- typename TTypes<T, 4, Eigen::DenseIndex>::Tensor output); \
- extern template struct ShuffleAndReverse<GPUDevice, T, 4, \
- Eigen::DenseIndex>; \
- template <> \
- void InflatePadAndShuffle<GPUDevice, T, 4, Eigen::DenseIndex>::operator()( \
- const GPUDevice& d, \
- typename TTypes<T, 4, Eigen::DenseIndex>::ConstTensor input, \
- const Eigen::DSizes<Eigen::DenseIndex, 4>& strides, \
- const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 4>& pad_dims, \
- const Eigen::DSizes<Eigen::DenseIndex, 4>& order, \
- typename TTypes<T, 4, Eigen::DenseIndex>::Tensor output); \
- extern template struct InflatePadAndShuffle<GPUDevice, T, 4, \
- Eigen::DenseIndex>; \
- template <> \
void ShuffleAndReverse<GPUDevice, T, 4, int>::operator()( \
const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor input, \
const Eigen::DSizes<int, 4>& order, \
@@ -1328,7 +1341,13 @@ namespace functor {
typename TTypes<T, 4>::ConstTensor filter, \
typename TTypes<T, 4>::ConstTensor output_backprop, int input_rows, \
int input_cols, int stride); \
- extern template struct SpatialConvolutionBackwardInput<GPUDevice, T>
+ extern template struct SpatialConvolutionBackwardInput<GPUDevice, T>; \
+ template <> \
+ void PadInput<GPUDevice, T, int>::operator()( \
+ const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in, \
+ int padding_rows_left, int padding_rows_right, int padding_cols_left, \
+ int padding_cols_right, typename TTypes<T, 4, int>::Tensor out); \
+ extern template struct PadInput<GPUDevice, T, int>;
DECLARE_GPU_SPEC(float);
#undef DECLARE_GPU_SPEC
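Both backprop kernels now compute zero padding for VALID and the full SAME padding otherwise, and the filter-backprop path checks whether the total padding is odd, since cuDNN only accepts equal padding on both sides. A standalone sketch of that arithmetic (illustrative numbers, not taken from the patch):

#include <algorithm>
#include <cstdio>

int main() {
  const int input_rows = 8, filter_rows = 3, stride = 2;
  // SAME padding: the output size is ceil(input / stride).
  const int output_rows = (input_rows + stride - 1) / stride;
  const int padding_rows =
      std::max(0, (output_rows - 1) * stride + filter_rows - input_rows);
  // cuDNN needs equal padding on both sides; if the total is odd, the patch
  // first pads the input by one extra row/column via functor::PadInput.
  const bool rows_odd = (padding_rows % 2 != 0);
  std::printf("padding_rows=%d rows_odd=%d\n", padding_rows, rows_odd);
  return 0;
}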
diff --git a/tensorflow/core/kernels/conv_ops_gpu.cu.cc b/tensorflow/core/kernels/conv_ops_gpu.cu.cc
index 60ff6b0024..e4ee058406 100644
--- a/tensorflow/core/kernels/conv_ops_gpu.cu.cc
+++ b/tensorflow/core/kernels/conv_ops_gpu.cu.cc
@@ -33,12 +33,8 @@ struct SpatialConvolution<GPUDevice, T> {
typename TTypes<T, 4>::ConstTensor input,
typename TTypes<T, 4>::ConstTensor filter, int stride,
const Eigen::PaddingType& padding) {
- // TODO(keveman): nvcc 6.5 crashes when 32 bit indexing is turned on. Enable
- // this when we move to cuda 7.0.
- // SpatialConvolutionFunc(d, To32Bit(output), To32Bit(input),
- // To32Bit(filter), stride, padding);
-
- SpatialConvolutionFunc(d, output, input, filter, stride, padding);
+ SpatialConvolutionFunc(d, To32Bit(output), To32Bit(input), To32Bit(filter),
+ stride, padding);
}
};
diff --git a/tensorflow/core/kernels/cwise_op_div.cc b/tensorflow/core/kernels/cwise_op_div.cc
index bc2b62375f..8fed594b25 100644
--- a/tensorflow/core/kernels/cwise_op_div.cc
+++ b/tensorflow/core/kernels/cwise_op_div.cc
@@ -16,21 +16,11 @@ limitations under the License.
#include "tensorflow/core/kernels/cwise_ops_common.h"
namespace tensorflow {
-REGISTER5(BinaryOp, CPU, "Div", functor::div, float, double, int32, int64,
- complex64);
+REGISTER7(BinaryOp, CPU, "Div", functor::div, float, double, uint8, int16,
+ int32, int64, complex64);
#if GOOGLE_CUDA
-REGISTER3(BinaryOp, GPU, "Div", functor::div, float, double, int64);
+REGISTER6(BinaryOp, GPU, "Div", functor::div, float, double, uint8, int16,
+ int32, int64);
#endif
-// A special GPU kernel for int32.
-// TODO(b/25387198): Also enable int32 in device memory. This kernel
-// registration requires all int32 inputs and outputs to be in host memory.
-REGISTER_KERNEL_BUILDER(Name("Div")
- .Device(DEVICE_GPU)
- .HostMemory("x")
- .HostMemory("y")
- .HostMemory("z")
- .TypeConstraint<int32>("T"),
- BinaryOp<CPUDevice, functor::div<int32>>);
-
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_gpu_div.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_div.cu.cc
index 80a02da651..a2809d5481 100644
--- a/tensorflow/core/kernels/cwise_op_gpu_div.cu.cc
+++ b/tensorflow/core/kernels/cwise_op_gpu_div.cu.cc
@@ -19,7 +19,7 @@ limitations under the License.
namespace tensorflow {
namespace functor {
-DEFINE_BINARY3(div, float, double, int64);
+DEFINE_BINARY6(div, float, double, uint8, int16, int32, int64);
} // namespace functor
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_gpu_mul.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_mul.cu.cc
index a4ecaf185a..068003b294 100644
--- a/tensorflow/core/kernels/cwise_op_gpu_mul.cu.cc
+++ b/tensorflow/core/kernels/cwise_op_gpu_mul.cu.cc
@@ -19,7 +19,7 @@ limitations under the License.
namespace tensorflow {
namespace functor {
-DEFINE_BINARY3(mul, float, double, int64);
+DEFINE_BINARY7(mul, float, double, uint8, int8, int16, int32, int64);
} // namespace functor
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_mul.cc b/tensorflow/core/kernels/cwise_op_mul.cc
index a7b9859b19..42d50358e6 100644
--- a/tensorflow/core/kernels/cwise_op_mul.cc
+++ b/tensorflow/core/kernels/cwise_op_mul.cc
@@ -16,21 +16,11 @@ limitations under the License.
#include "tensorflow/core/kernels/cwise_ops_common.h"
namespace tensorflow {
-REGISTER7(BinaryOp, CPU, "Mul", functor::mul, float, double, int32, int64, int8,
- int16, complex64);
+REGISTER8(BinaryOp, CPU, "Mul", functor::mul, float, double, uint8, int8, int16,
+ int32, int64, complex64);
#if GOOGLE_CUDA
-REGISTER3(BinaryOp, GPU, "Mul", functor::mul, float, double, int64);
+REGISTER7(BinaryOp, GPU, "Mul", functor::mul, float, double, uint8, int8, int16,
+ int32, int64);
#endif
-// A special GPU kernel for int32.
-// TODO(b/25387198): Also enable int32 in device memory. This kernel
-// registration requires all int32 inputs and outputs to be in host memory.
-REGISTER_KERNEL_BUILDER(Name("Mul")
- .Device(DEVICE_GPU)
- .HostMemory("x")
- .HostMemory("y")
- .HostMemory("z")
- .TypeConstraint<int32>("T"),
- BinaryOp<CPUDevice, functor::mul<int32>>);
-
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_ops_common.h b/tensorflow/core/kernels/cwise_ops_common.h
index 3296826d48..adf4203322 100644
--- a/tensorflow/core/kernels/cwise_ops_common.h
+++ b/tensorflow/core/kernels/cwise_ops_common.h
@@ -379,6 +379,8 @@ struct SelectFunctor<CPUDevice, T> {
#define REGISTER6(OP, D, N, F, T0, T1, T2, T3, T4, T5) REGISTER(OP, D, N, F, T0)
#define REGISTER7(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6) \
REGISTER(OP, D, N, F, T0)
+#define REGISTER8(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6, T7) \
+ REGISTER(OP, D, N, F, T0)
#else // !defined(__ANDROID__)
#define REGISTER2(OP, D, N, F, T0, T1) \
REGISTER(OP, D, N, F, T0) \
@@ -398,6 +400,9 @@ struct SelectFunctor<CPUDevice, T> {
#define REGISTER7(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6) \
REGISTER4(OP, D, N, F, T0, T1, T2, T3) \
REGISTER3(OP, D, N, F, T4, T5, T6)
+#define REGISTER8(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6, T7) \
+ REGISTER4(OP, D, N, F, T0, T1, T2, T3) \
+ REGISTER4(OP, D, N, F, T4, T5, T6, T7)
#endif // defined(__ANDROID__)
} // end namespace tensorflow
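The new REGISTER8 follows the existing convention of composing smaller REGISTERn macros (two REGISTER4s here), which is what lets cwise_op_mul.cc register eight CPU types with a single line. A toy standalone version with hypothetical DEMO_* names:

#include <cstdio>

#define DEMO_REGISTER(T) std::printf("register kernel for %s\n", #T);
#define DEMO_REGISTER4(T0, T1, T2, T3) \
  DEMO_REGISTER(T0) DEMO_REGISTER(T1) DEMO_REGISTER(T2) DEMO_REGISTER(T3)
#define DEMO_REGISTER8(T0, T1, T2, T3, T4, T5, T6, T7) \
  DEMO_REGISTER4(T0, T1, T2, T3) DEMO_REGISTER4(T4, T5, T6, T7)

int main() {
  // Eight expansions of DEMO_REGISTER, one per type token.
  DEMO_REGISTER8(float, double, uint8, int8, int16, int32, int64, complex64)
  return 0;
}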
diff --git a/tensorflow/core/kernels/cwise_ops_gpu_common.cu.h b/tensorflow/core/kernels/cwise_ops_gpu_common.cu.h
index 966d3393b6..091c6717dc 100644
--- a/tensorflow/core/kernels/cwise_ops_gpu_common.cu.h
+++ b/tensorflow/core/kernels/cwise_ops_gpu_common.cu.h
@@ -40,7 +40,7 @@ template <typename Functor>
struct UnaryFunctor<GPUDevice, Functor> {
void operator()(const GPUDevice& d, typename Functor::tout_type out,
typename Functor::tin_type in) {
- out.device(d) = in.unaryExpr(typename Functor::func());
+ To32Bit(out).device(d) = To32Bit(in).unaryExpr(typename Functor::func());
}
};
@@ -50,7 +50,8 @@ struct BinaryFunctor<GPUDevice, Functor, NDIMS> {
void operator()(const GPUDevice& d, typename Functor::tout_type out,
typename Functor::tin_type in0,
typename Functor::tin_type in1) {
- out.device(d) = in0.binaryExpr(in1, typename Functor::func());
+ To32Bit(out).device(d) =
+ To32Bit(in0).binaryExpr(in1, typename Functor::func());
}
void Left(const GPUDevice& d, typename Functor::tout_type out,
@@ -60,7 +61,7 @@ struct BinaryFunctor<GPUDevice, Functor, NDIMS> {
typedef typename Functor::in_type Tin;
typedef typename Functor::func Binary;
typedef typename Eigen::internal::scalar_left<Tout, Tin, Binary> Unary;
- out.device(d) = in.unaryExpr(Unary(scalar.data()));
+ To32Bit(out).device(d) = To32Bit(in).unaryExpr(Unary(scalar.data()));
}
void Right(const GPUDevice& d, typename Functor::tout_type out,
@@ -70,7 +71,7 @@ struct BinaryFunctor<GPUDevice, Functor, NDIMS> {
typedef typename Functor::in_type Tin;
typedef typename Functor::func Binary;
typedef typename Eigen::internal::scalar_right<Tout, Tin, Binary> Unary;
- out.device(d) = in.unaryExpr(Unary(scalar.data()));
+ To32Bit(out).device(d) = To32Bit(in).unaryExpr(Unary(scalar.data()));
}
void BCast(const GPUDevice& d,
@@ -86,16 +87,18 @@ struct BinaryFunctor<GPUDevice, Functor, NDIMS> {
const bool bcast0_all_one = AllOne<NDIMS>(bcast0);
const bool bcast1_all_one = AllOne<NDIMS>(bcast1);
if (bcast0_all_one && !bcast1_all_one) {
- out.device(d) = in0.binaryExpr(in1.broadcast(bcast1), func);
+ To32Bit(out).device(d) =
+ To32Bit(in0).binaryExpr(To32Bit(in1).broadcast(bcast1), func);
return;
}
if (!bcast0_all_one && bcast1_all_one) {
- out.device(d) = in0.broadcast(bcast0).binaryExpr(in1, func);
+ To32Bit(out).device(d) =
+ To32Bit(in0).broadcast(bcast0).binaryExpr(To32Bit(in1), func);
return;
}
}
- out.device(d) =
- in0.broadcast(bcast0).binaryExpr(in1.broadcast(bcast1), func);
+ To32Bit(out).device(d) = To32Bit(in0).broadcast(bcast0).binaryExpr(
+ To32Bit(in1).broadcast(bcast1), func);
}
};
@@ -105,7 +108,8 @@ struct SelectFunctor<GPUDevice, T> {
typename TTypes<bool>::ConstFlat cond_flat,
typename TTypes<T>::ConstFlat then_flat,
typename TTypes<T>::ConstFlat else_flat) {
- out.device(d) = cond_flat.select(then_flat, else_flat);
+ To32Bit(out).device(d) =
+ To32Bit(cond_flat).select(To32Bit(then_flat), To32Bit(else_flat));
}
};
@@ -143,6 +147,12 @@ struct SelectFunctor<GPUDevice, T> {
#define DEFINE_BINARY5(F, T0, T1, T2, T3, T4) \
DEFINE_BINARY2(F, T0, T1); \
DEFINE_BINARY3(F, T2, T3, T4)
+#define DEFINE_BINARY6(F, T0, T1, T2, T3, T4, T5) \
+ DEFINE_BINARY3(F, T0, T1, T2); \
+ DEFINE_BINARY3(F, T3, T4, T5)
+#define DEFINE_BINARY7(F, T0, T1, T2, T3, T4, T5, T6) \
+ DEFINE_BINARY3(F, T0, T1, T2); \
+ DEFINE_BINARY4(F, T3, T4, T5, T6)
} // end namespace functor
} // end namespace tensorflow
diff --git a/tensorflow/core/kernels/lrn_op.cc b/tensorflow/core/kernels/lrn_op.cc
index fb779f2466..9ae2eedb30 100644
--- a/tensorflow/core/kernels/lrn_op.cc
+++ b/tensorflow/core/kernels/lrn_op.cc
@@ -30,10 +30,17 @@ limitations under the License.
namespace tensorflow {
+namespace {
+
+// When the depth is large and beta_ is 0.5 or 1.0, MognetLRN is faster than the
+// main band matrix approach used below. Benchmarks suggest switching to
+// MognetLRN when depth > 384.
+const int kMognetLRNDepthCutoff = 384;
+
// Create a depth-by-depth band matrix with 1s along a swath of size (2 *
// depth_radius + 1) around the diagonal.
-static void GetBandMatrix(int depth, int64 depth_radius,
- Eigen::Tensor<float, 2, Eigen::RowMajor>* result) {
+void GetBandMatrix(int depth, int64 depth_radius,
+ Eigen::Tensor<float, 2, Eigen::RowMajor>* result) {
result->setZero();
for (int row = 0; row < depth; ++row) {
const int begin = std::max<int>(0, row - depth_radius);
@@ -44,6 +51,8 @@ static void GetBandMatrix(int depth, int64 depth_radius,
}
}
+} // namespace
+
class LRNOp : public OpKernel {
public:
explicit LRNOp(OpKernelConstruction* context) : OpKernel(context) {
@@ -69,6 +78,11 @@ class LRNOp : public OpKernel {
#if defined(__ANDROID__)
MognetLRN(in, batch, rows, cols, depth, output);
#else
+ if (depth > kMognetLRNDepthCutoff && (beta_ == 0.5f || beta_ == 1.0f)) {
+ MognetLRN(in, batch, rows, cols, depth, output);
+ return;
+ }
+
const int nodes = cols * rows;
auto in_shaped = in.shaped<float, 2>({nodes * batch, depth});
@@ -79,13 +93,16 @@ class LRNOp : public OpKernel {
auto out_shaped = output->shaped<float, 2>({nodes * batch, depth});
Eigen::array<DimPair, 1> dims = {{DimPair(1, 0)}};
- /// TODO(keveman): Optimize for beta in {0, 1, 0.5}
- out_shaped.device(context->eigen_cpu_device()) =
- in_shaped /
- in_shaped.square()
- .contract(multiplier, dims)
- .unaryExpr([this](float x) { return bias_ + alpha_ * x; })
- .pow(beta_);
+ auto tmp = in_shaped.square().contract(multiplier, dims) * alpha_ + bias_;
+ if (beta_ == 1.0f) {
+ out_shaped.device(context->eigen_cpu_device()) =
+ in_shaped * tmp.inverse();
+ } else if (beta_ == 0.5f) {
+ out_shaped.device(context->eigen_cpu_device()) = in_shaped * tmp.rsqrt();
+ } else {
+ out_shaped.device(context->eigen_cpu_device()) =
+ in_shaped * (tmp.log() * -beta_).exp();
+ }
#endif
}
@@ -104,11 +121,11 @@ class LRNOp : public OpKernel {
Eigen::VectorXf padded_square(data_in.rows() + double_depth_radius);
padded_square.setZero();
for (int r = 0; r < data_in.cols(); ++r) {
- // Do local response normalization for data_in(:, r)
- // first, compute the square and store them in buffer for repeated use
+ // Do local response normalization for data_in(:, r). First, compute the
+ // square and store them in buffer for repeated use.
padded_square.block(depth_radius_, 0, data_out.rows(), 1) =
data_in.col(r).cwiseProduct(data_in.col(r)) * alpha_;
- // Then, compute the scale and writes them to data_out
+ // Then, compute the scale and write it to data_out.
float accumulated_scale = 0;
for (int i = 0; i < double_depth_radius; ++i) {
accumulated_scale += padded_square(i);
@@ -120,13 +137,13 @@ class LRNOp : public OpKernel {
}
}
- // In a few cases, the pow computation could benefit from speedups.
if (beta_ == 1) {
data_out.array() = data_in.array() * data_out.array().inverse();
} else if (beta_ == 0.5) {
- data_out.array() = data_in.array() * data_out.array().sqrt().inverse();
+ data_out.array() = data_in.array() * data_out.array().rsqrt();
} else {
- data_out.array() = data_in.array() * data_out.array().pow(-beta_);
+ data_out.array() =
+ data_in.array() * (data_out.array().log() * -beta_).exp();
}
}
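The LRN rewrite avoids the general pow() on the hot path: beta == 1 becomes a reciprocal, beta == 0.5 becomes rsqrt, and everything else uses exp(log(x) * -beta), which is valid because the normalizer bias + alpha * sum_of_squares is strictly positive. A small scalar check of those identities (not part of the patch):

#include <cmath>
#include <cstdio>

int main() {
  const float x = 2.75f;   // stands in for bias + alpha * sum_of_squares > 0
  const float beta = 0.75f;
  std::printf("general:  pow=%f  exp(log*-beta)=%f\n",
              std::pow(x, -beta), std::exp(std::log(x) * -beta));
  std::printf("beta=1:   pow=%f  1/x=%f\n", std::pow(x, -1.0f), 1.0f / x);
  std::printf("beta=0.5: pow=%f  rsqrt=%f\n", std::pow(x, -0.5f),
              1.0f / std::sqrt(x));
  return 0;
}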
diff --git a/tensorflow/core/kernels/reference_gemm.h b/tensorflow/core/kernels/reference_gemm.h
deleted file mode 100644
index 16fa541238..0000000000
--- a/tensorflow/core/kernels/reference_gemm.h
+++ /dev/null
@@ -1,90 +0,0 @@
-/* Copyright 2015 Google Inc. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_KERNELS_REFERENCE_GEMM_H_
-#define TENSORFLOW_KERNELS_REFERENCE_GEMM_H_
-
-// This is an unoptimized but debuggable implementation of the GEMM matrix
-// multiply function, used to compare to faster but more opaque versions, or
-// for bit depths or argument combinations that aren't supported by optimized
-// code.
-// It assumes the row-major convention used by TensorFlow, and implements
-// C = A * B, like the standard BLAS GEMM interface. If the tranpose flags are
-// true, then the relevant matrix is treated as stored in column-major order.
-
-namespace tensorflow {
-template <class T1, class T2, class T3>
-void ReferenceGemm(bool transpose_a, bool transpose_b, bool transpose_c,
- size_t m, size_t n, size_t k, const T1* a, T1 offset_a,
- size_t lda, const T2* b, T2 offset_b, size_t ldb, T3* c,
- int32 shift_c, int32 offset_c, int32 mult_c, size_t ldc) {
- int a_i_stride;
- int a_l_stride;
- if (transpose_a) {
- a_i_stride = 1;
- a_l_stride = lda;
- } else {
- a_i_stride = lda;
- a_l_stride = 1;
- }
- int b_j_stride;
- int b_l_stride;
- if (transpose_b) {
- b_j_stride = ldb;
- b_l_stride = 1;
- } else {
- b_j_stride = 1;
- b_l_stride = ldb;
- }
- int c_i_stride;
- int c_j_stride;
- if (transpose_c) {
- c_i_stride = 1;
- c_j_stride = ldc;
- } else {
- c_i_stride = ldc;
- c_j_stride = 1;
- }
-
- const int32 highest = static_cast<int32>(Eigen::NumTraits<T3>::highest());
- const int32 lowest = static_cast<int32>(Eigen::NumTraits<T3>::lowest());
- const int32 rounding = (shift_c < 1) ? 0 : (1 << (shift_c - 1));
-
- int i, j, l;
- for (j = 0; j < n; j++) {
- for (i = 0; i < m; i++) {
- int32 total = 0;
- for (l = 0; l < k; l++) {
- const size_t a_index = ((i * a_i_stride) + (l * a_l_stride));
- const int32 a_value = a[a_index] - offset_a;
- const size_t b_index = ((j * b_j_stride) + (l * b_l_stride));
- const int32 b_value = b[b_index] - offset_b;
- total += (a_value * b_value);
- }
- const size_t c_index = ((i * c_i_stride) + (j * c_j_stride));
- int32_t output = ((((total + offset_c) * mult_c) + rounding) >> shift_c);
- if (output > highest) {
- output = highest;
- }
- if (output < lowest) {
- output = lowest;
- }
- c[c_index] = static_cast<T3>(output);
- }
- }
-}
-} // namespace tensorflow
-
-#endif // TENSORFLOW_KERNELS_REFERENCE_GEMM_H_
diff --git a/tensorflow/core/kernels/reverse_sequence_op.cc b/tensorflow/core/kernels/reverse_sequence_op.cc
index 0671414c51..a25c68a15a 100644
--- a/tensorflow/core/kernels/reverse_sequence_op.cc
+++ b/tensorflow/core/kernels/reverse_sequence_op.cc
@@ -39,7 +39,7 @@ typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
template <typename Device>
-void CheckErrors(OpKernelContext* context, int seq_dim) {
+void CheckErrors(OpKernelContext* context, int batch_dim, int seq_dim) {
const Tensor& input = context->input(0);
const Tensor& seq_lens = context->input(1);
@@ -52,15 +52,18 @@ void CheckErrors(OpKernelContext* context, int seq_dim) {
seq_lens_vec.data(), seq_lens_t.data(),
sizeof(int64) * seq_lens_t.size());
- OP_REQUIRES(context, 0 != seq_dim, errors::InvalidArgument("0 == seq_dim"));
+ OP_REQUIRES(context, batch_dim != seq_dim,
+ errors::InvalidArgument("batch_dim == seq_dim == ", seq_dim));
OP_REQUIRES(context, seq_dim < input.dims(),
errors::InvalidArgument("seq_dim must be < input.dims()", "( ",
seq_dim, " vs. ", input.dims(), ")"));
-
- OP_REQUIRES(context, seq_lens.NumElements() == input.dim_size(0),
- errors::InvalidArgument("len(seq_lens) != input.dims(", 0, "), ",
- "(", seq_lens.NumElements(), " vs. ",
- input.dim_size(seq_dim)));
+ OP_REQUIRES(context, batch_dim < input.dims(),
+ errors::InvalidArgument("batch_dim must be < input.dims()", "( ",
+ batch_dim, " vs. ", input.dims(), ")"));
+ OP_REQUIRES(context, seq_lens.NumElements() == input.dim_size(batch_dim),
+ errors::InvalidArgument("len(seq_lens) != input.dims(", batch_dim,
+ "), ", "(", seq_lens.NumElements(),
+ " vs. ", input.dim_size(batch_dim)));
for (int d = 0; d < seq_lens_vec.size(); ++d) {
OP_REQUIRES(context, seq_lens_vec[d] >= 0,
@@ -72,19 +75,24 @@ void CheckErrors(OpKernelContext* context, int seq_dim) {
}
template <>
-void CheckErrors<GPUDevice>(OpKernelContext* context, int seq_dim) {
+void CheckErrors<GPUDevice>(OpKernelContext* context, int batch_dim,
+ int seq_dim) {
const Tensor& input = context->input(0);
const Tensor& seq_lens = context->input(1);
- OP_REQUIRES(context, 0 != seq_dim, errors::InvalidArgument("0 == seq_dim"));
+ OP_REQUIRES(context, batch_dim != seq_dim,
+ errors::InvalidArgument("batch_dim == seq_dim == ", seq_dim));
OP_REQUIRES(context, seq_dim < input.dims(),
errors::InvalidArgument("seq_dim must be < input.dims()", "( ",
seq_dim, " vs. ", input.dims(), ")"));
-
- OP_REQUIRES(context, seq_lens.NumElements() == input.dim_size(0),
- errors::InvalidArgument("len(seq_lens) != input.dims(", 0, "), ",
- "(", seq_lens.NumElements(), " vs. ",
- input.dim_size(seq_dim)));
+ OP_REQUIRES(context, batch_dim < input.dims(),
+ errors::InvalidArgument("batch_dim must be < input.dims()", "( ",
+ batch_dim, " vs. ", input.dims(), ")"));
+
+ OP_REQUIRES(context, seq_lens.NumElements() == input.dim_size(batch_dim),
+ errors::InvalidArgument("len(seq_lens) != input.dims(", batch_dim,
+ "), ", "(", seq_lens.NumElements(),
+ " vs. ", input.dim_size(batch_dim)));
}
template <typename Device, typename T>
@@ -92,6 +100,7 @@ class ReverseSequenceOp : public OpKernel {
public:
explicit ReverseSequenceOp(OpKernelConstruction* context)
: OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("batch_dim", &batch_dim_));
OP_REQUIRES_OK(context, context->GetAttr("seq_dim", &seq_dim_));
}
@@ -106,7 +115,7 @@ class ReverseSequenceOp : public OpKernel {
auto seq_lens_t = seq_lens.vec<int64>();
- CheckErrors<Device>(context, seq_dim_);
+ CheckErrors<Device>(context, batch_dim_, seq_dim_);
const int input_dims = input.dims();
@@ -114,11 +123,11 @@ class ReverseSequenceOp : public OpKernel {
OP_REQUIRES_OK(context,
context->allocate_output(0, input.shape(), &output));
-#define HANDLE_DIM(NDIM) \
- case NDIM: \
- functor::ReverseSequence<Device, T, NDIM>::Compute( \
- context->eigen_device<Device>(), input.tensor<T, NDIM>(), seq_dim_, \
- seq_lens_t, output->tensor<T, NDIM>()); \
+#define HANDLE_DIM(NDIM) \
+ case NDIM: \
+ functor::ReverseSequence<Device, T, NDIM>::Compute( \
+ context->eigen_device<Device>(), input.tensor<T, NDIM>(), batch_dim_, \
+ seq_dim_, seq_lens_t, output->tensor<T, NDIM>()); \
break;
switch (input_dims) {
@@ -136,6 +145,7 @@ class ReverseSequenceOp : public OpKernel {
}
private:
+ int32 batch_dim_;
int32 seq_dim_;
TF_DISALLOW_COPY_AND_ASSIGN(ReverseSequenceOp);
@@ -152,12 +162,12 @@ TF_CALL_NUMBER_TYPES(REGISTER_REVERSE_SEQUENCE);
// Forward declarations of the functor specializations for GPU.
namespace functor {
-#define DECLARE_GPU_SPEC(T, Dims) \
- template <> \
- void ReverseSequence<GPUDevice, T, Dims>::Compute( \
- const GPUDevice& d, typename TTypes<T, Dims>::ConstTensor input, \
- int32 seq_dim, TTypes<int64>::ConstVec seq_lens, \
- typename TTypes<T, Dims>::Tensor output); \
+#define DECLARE_GPU_SPEC(T, Dims) \
+ template <> \
+ void ReverseSequence<GPUDevice, T, Dims>::Compute( \
+ const GPUDevice& d, typename TTypes<T, Dims>::ConstTensor input, \
+ int32 batch_dim, int32 seq_dim, TTypes<int64>::ConstVec seq_lens, \
+ typename TTypes<T, Dims>::Tensor output); \
extern template struct ReverseSequence<GPUDevice, T, Dims>;
#define DECLARE_GPU_SPECS(T) \
diff --git a/tensorflow/core/kernels/reverse_sequence_op.h b/tensorflow/core/kernels/reverse_sequence_op.h
index ceb1b0b880..9dd1e4d01d 100644
--- a/tensorflow/core/kernels/reverse_sequence_op.h
+++ b/tensorflow/core/kernels/reverse_sequence_op.h
@@ -29,15 +29,19 @@ template <typename T, size_t Dims>
class ReverseGenerator {
public:
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
- ReverseGenerator(typename TTypes<T, Dims>::ConstTensor input, int32 seq_dim,
- TTypes<int64>::ConstVec seq_lengths)
- : input_(input), seq_dim_(seq_dim), seq_lengths_(seq_lengths) {}
+ ReverseGenerator(typename TTypes<T, Dims>::ConstTensor input, int32 batch_dim,
+ int32 seq_dim, TTypes<int64>::ConstVec seq_lengths)
+ : input_(input),
+ batch_dim_(batch_dim),
+ seq_dim_(seq_dim),
+ seq_lengths_(seq_lengths) {}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T
operator()(const Eigen::array<Eigen::DenseIndex, Dims>& coords) const {
Eigen::array<Eigen::DenseIndex, Dims> new_coords = coords;
- if (coords[seq_dim_] < seq_lengths_(coords[0])) {
- new_coords[seq_dim_] = seq_lengths_(coords[0]) - coords[seq_dim_] - 1;
+ if (coords[seq_dim_] < seq_lengths_(coords[batch_dim_])) {
+ new_coords[seq_dim_] =
+ seq_lengths_(coords[batch_dim_]) - coords[seq_dim_] - 1;
}
return input_(new_coords);
@@ -45,6 +49,7 @@ class ReverseGenerator {
private:
typename TTypes<T, Dims>::ConstTensor input_;
+ int32 batch_dim_;
int32 seq_dim_;
TTypes<int64>::ConstVec seq_lengths_;
};
@@ -57,9 +62,10 @@ template <typename Device, typename T, size_t Dims>
struct ReverseSequence {
EIGEN_ALWAYS_INLINE static void Compute(
const Device& d, typename TTypes<T, Dims>::ConstTensor input,
- int32 seq_dim, TTypes<int64>::ConstVec seq_lengths,
+ int32 batch_dim, int32 seq_dim, TTypes<int64>::ConstVec seq_lengths,
typename TTypes<T, Dims>::Tensor output) {
- generator::ReverseGenerator<T, Dims> generator(input, seq_dim, seq_lengths);
+ generator::ReverseGenerator<T, Dims> generator(input, batch_dim, seq_dim,
+ seq_lengths);
output.device(d) = input.generate(generator);
}
};
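With the new batch_dim attribute, ReverseGenerator looks up the sequence length at coords[batch_dim_] instead of coords[0], and only flips coordinates along seq_dim that fall inside that length. A plain-array sketch of the semantics for a [batch, time] input with batch_dim = 0 and seq_dim = 1 (hypothetical helper, not the kernel itself):

#include <cstdio>
#include <vector>

std::vector<std::vector<int>> ReverseSequence2D(
    const std::vector<std::vector<int>>& input,
    const std::vector<int>& seq_lengths) {
  std::vector<std::vector<int>> output = input;  // tail beyond len is kept
  for (size_t b = 0; b < input.size(); ++b) {
    const int len = seq_lengths[b];
    for (int t = 0; t < len; ++t) {
      output[b][t] = input[b][len - t - 1];  // reverse the first len elements
    }
  }
  return output;
}

int main() {
  const auto out = ReverseSequence2D({{1, 2, 3, 4}, {5, 6, 7, 8}}, {3, 2});
  for (const auto& row : out) {
    for (int v : row) std::printf("%d ", v);
    std::printf("\n");  // prints: 3 2 1 4  then  6 5 7 8
  }
  return 0;
}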
diff --git a/tensorflow/core/kernels/softsign_op.cc b/tensorflow/core/kernels/softsign_op.cc
new file mode 100644
index 0000000000..e3480e3594
--- /dev/null
+++ b/tensorflow/core/kernels/softsign_op.cc
@@ -0,0 +1,112 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// See docs in ../ops/nn_ops.cc.
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/framework/numeric_op.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/kernels/softsign_op.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/public/tensor.h"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+template <typename Device, typename T>
+class SoftsignOp : public UnaryElementWiseOp<T, SoftsignOp<Device, T>> {
+ public:
+ using UnaryElementWiseOp<T, SoftsignOp<Device, T>>::UnaryElementWiseOp;
+
+ void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) {
+ functor::Softsign<Device, T> functor;
+ functor(context->eigen_device<Device>(), input.flat<T>(),
+ output->flat<T>());
+ }
+};
+
+template <typename Device, typename T>
+class SoftsignGradOp
+ : public BinaryElementWiseOp<T, SoftsignGradOp<Device, T>> {
+ public:
+ using BinaryElementWiseOp<T, SoftsignGradOp<Device, T>>::BinaryElementWiseOp;
+
+ // INPUTS:
+ // g (gradients): backpropagated gradients
+ // a (inputs): inputs that were passed to SoftsignOp()
+ // OUTPUT:
+ // gradients to backprop
+ template <int NDIMS>
+ void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a,
+ Tensor* output) {
+ OP_REQUIRES(context, a.IsSameSize(g),
+ errors::InvalidArgument("g and a must be the same size"));
+ functor::SoftsignGrad<Device, T> functor;
+ functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(),
+ output->flat<T>());
+ }
+};
+
+#define REGISTER_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Softsign").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ SoftsignOp<CPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("SoftsignGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ SoftsignGradOp<CPUDevice, type>);
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
+#undef REGISTER_KERNELS
+
+#if GOOGLE_CUDA
+// Forward declarations of the functor specializations for GPU.
+namespace functor {
+#define DECLARE_GPU_SPEC(T) \
+ template <> \
+ void Softsign<GPUDevice, T>::operator()( \
+ const GPUDevice& d, typename TTypes<T>::ConstTensor features, \
+ typename TTypes<T>::Tensor activations); \
+ extern template struct Softsign<GPUDevice, T>; \
+ \
+ template <> \
+ void SoftsignGrad<GPUDevice, T>::operator()( \
+ const GPUDevice& d, typename TTypes<T>::ConstTensor gradients, \
+ typename TTypes<T>::ConstTensor features, \
+ typename TTypes<T>::Tensor backprops); \
+ extern template struct SoftsignGrad<GPUDevice, T>;
+
+TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
+} // namespace functor
+
+// Registration of the GPU implementations.
+#define REGISTER_GPU_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Softsign").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+ SoftsignOp<GPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("SoftsignGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+ SoftsignGradOp<GPUDevice, type>);
+
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
+#undef REGISTER_GPU_KERNELS
+
+#endif // GOOGLE_CUDA
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/softsign_op.h b/tensorflow/core/kernels/softsign_op.h
new file mode 100644
index 0000000000..36790a5874
--- /dev/null
+++ b/tensorflow/core/kernels/softsign_op.h
@@ -0,0 +1,60 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_KERNELS_SOFTSIGN_OP_H_
+#define TENSORFLOW_KERNELS_SOFTSIGN_OP_H_
+// Functor definition for SoftsignOp and SoftsignGradOp, must be compilable by
+// nvcc.
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/tensor_types.h"
+
+namespace tensorflow {
+namespace functor {
+
+// Functor used by SoftsignOp to do the computations.
+template <typename Device, typename T>
+struct Softsign {
+ // Computes Softsign activation.
+ //
+ // features: any shape.
+ // activations: same shape as "features".
+ void operator()(const Device& d, typename TTypes<T>::ConstTensor features,
+ typename TTypes<T>::Tensor activations) {
+ activations.device(d) =
+ features / (features.abs() + features.constant(1.0f));
+ }
+};
+
+// Functor used by SoftsignGradOp to do the computations.
+template <typename Device, typename T>
+struct SoftsignGrad {
+ // Computes SoftsignGrad backprops.
+ //
+ // gradients: gradients backpropagated to the Softsign op.
+ // features: inputs that were passed to the Softsign op.
+ // backprops: gradients to backpropagate to the Softsign inputs.
+ void operator()(const Device& d, typename TTypes<T>::ConstTensor gradients,
+ typename TTypes<T>::ConstTensor features,
+ typename TTypes<T>::Tensor backprops) {
+ backprops.device(d) =
+ gradients / (features.abs() + features.constant(1.0f)).square();
+ }
+};
+
+} // namespace functor
+} // namespace tensorflow
+
+#endif // TENSORFLOW_KERNELS_SOFTSIGN_OP_H_
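The functors above implement softsign(x) = x / (1 + |x|); its derivative is 1 / (1 + |x|)^2, which is why SoftsignGrad divides the incoming gradients by the squared denominator. A scalar reference version (a sketch, assuming T = float):

#include <cmath>
#include <cstdio>

float Softsign(float x) { return x / (1.0f + std::fabs(x)); }

// d/dx softsign(x) = 1 / (1 + |x|)^2, so the chain rule gives
// grad / (1 + |x|)^2.
float SoftsignGrad(float grad, float x) {
  const float denom = 1.0f + std::fabs(x);
  return grad / (denom * denom);
}

int main() {
  const float xs[] = {-2.0f, -0.5f, 0.0f, 0.5f, 2.0f};
  for (float x : xs) {
    std::printf("x=%+.1f  softsign=%+.4f  grad(1, x)=%.4f\n", x, Softsign(x),
                SoftsignGrad(1.0f, x));
  }
  return 0;
}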
diff --git a/tensorflow/core/kernels/softsign_op_gpu.cu.cc b/tensorflow/core/kernels/softsign_op_gpu.cu.cc
new file mode 100644
index 0000000000..4ae941c9f0
--- /dev/null
+++ b/tensorflow/core/kernels/softsign_op_gpu.cu.cc
@@ -0,0 +1,40 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include <stdio.h>
+
+#include "tensorflow/core/kernels/softsign_op.h"
+
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor_types.h"
+
+namespace tensorflow {
+
+typedef Eigen::GpuDevice GPUDevice;
+
+// Definition of the GPU implementations declared in softsign_op.cc.
+#define DEFINE_GPU_KERNELS(T) \
+ template struct functor::Softsign<GPUDevice, T>; \
+ template struct functor::SoftsignGrad<GPUDevice, T>;
+
+TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS);
+
+} // end namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/split_op_gpu.cu.cc b/tensorflow/core/kernels/split_op_gpu.cu.cc
index c79410b68c..13463b705b 100644
--- a/tensorflow/core/kernels/split_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/split_op_gpu.cu.cc
@@ -33,7 +33,7 @@ void Split<Device, T>::operator()(
typename TTypes<T, 3>::ConstTensor input,
const Eigen::DSizes<Eigen::DenseIndex, 3>& slice_indices,
const Eigen::DSizes<Eigen::DenseIndex, 3>& slice_sizes) {
- output.device(d) = input.slice(slice_indices, slice_sizes);
+ To32Bit(output).device(d) = To32Bit(input).slice(slice_indices, slice_sizes);
}
#define DEFINE_GPU_KERNELS(T) template struct Split<Eigen::GpuDevice, T>;
diff --git a/tensorflow/core/kernels/stack_ops.cc b/tensorflow/core/kernels/stack_ops.cc
index 055050cd34..2c146b3d6c 100644
--- a/tensorflow/core/kernels/stack_ops.cc
+++ b/tensorflow/core/kernels/stack_ops.cc
@@ -1,3 +1,18 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
// See docs in ../ops/data_flow_ops.cc.
#include <limits.h>