aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/core/framework/register_types.h4
-rw-r--r--tensorflow/core/kernels/BUILD34
-rw-r--r--tensorflow/core/kernels/dense_update_functor.cc73
-rw-r--r--tensorflow/core/kernels/dense_update_functor.h (renamed from tensorflow/core/kernels/dense_update_ops.h)33
-rw-r--r--tensorflow/core/kernels/dense_update_functor_gpu.cu.cc (renamed from tensorflow/core/kernels/dense_update_ops_gpu.cu.cc)30
-rw-r--r--tensorflow/core/kernels/dense_update_ops.cc74
-rw-r--r--tensorflow/core/kernels/resource_variable_ops.cc27
-rw-r--r--tensorflow/core/kernels/scatter_nd_op.cc12
-rw-r--r--tensorflow/core/kernels/strided_slice_op.cc2
-rw-r--r--tensorflow/core/kernels/strided_slice_op_impl.h2
10 files changed, 164 insertions, 127 deletions
diff --git a/tensorflow/core/framework/register_types.h b/tensorflow/core/framework/register_types.h
index b62fe647e2..3f91642064 100644
--- a/tensorflow/core/framework/register_types.h
+++ b/tensorflow/core/framework/register_types.h
@@ -172,7 +172,9 @@ limitations under the License.
TF_CALL_half(m) TF_CALL_float(m) TF_CALL_double(m)
// Call "m" on all types supported on GPU.
-#define TF_CALL_GPU_ALL_TYPES(m) TF_CALL_GPU_NUMBER_TYPES(m) TF_CALL_bool(m)
+#define TF_CALL_GPU_ALL_TYPES(m) \
+ TF_CALL_GPU_NUMBER_TYPES(m) \
+ TF_CALL_bool(m) TF_CALL_complex64(m) TF_CALL_complex128(m)
#define TF_CALL_GPU_NUMBER_TYPES_NO_HALF(m) TF_CALL_float(m) TF_CALL_double(m)
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 6eee2e16eb..93d1df0f7b 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -19,6 +19,7 @@ package_group(
name = "friends",
packages = [
"//learning/brain/contrib/...",
+ "//learning/brain/research/sparse_matrix/...",
"//tensorflow/...",
],
)
@@ -94,13 +95,11 @@ tf_kernel_library(
"strided_slice_op_inst_7.cc",
],
hdrs = [
- "dense_update_ops.h",
"slice_op.h",
"strided_slice_op.h",
"strided_slice_op_impl.h",
],
gpu_srcs = [
- "dense_update_ops.h",
"slice_op.h",
"strided_slice_op.h",
"strided_slice_op_impl.h",
@@ -109,6 +108,7 @@ tf_kernel_library(
],
deps = [
":bounds_check",
+ ":dense_update_functor",
":ops_util",
":variable_ops",
"//tensorflow/core:framework",
@@ -1011,6 +1011,23 @@ tf_kernel_library(
],
)
+tf_kernel_library(
+ name = "dense_update_functor",
+ srcs = ["dense_update_functor.cc"],
+ hdrs = ["dense_update_functor.h"],
+ gpu_srcs = [
+ "dense_update_functor.h",
+ "dense_update_functor_gpu.cu.cc",
+ ],
+ visibility = [":friends"],
+ deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//third_party/eigen3",
+ ],
+ alwayslink = 0,
+)
+
tf_cuda_cc_test(
name = "gather_op_test",
size = "small",
@@ -1606,7 +1623,7 @@ tf_kernel_library(
srcs = ["resource_variable_ops.cc"],
deps = [
":bounds_check",
- ":dense_update_ops",
+ ":dense_update_functor",
":gather_functor",
":scatter_functor",
":state",
@@ -2079,7 +2096,7 @@ tf_kernel_library(
"//tensorflow:darwin": [],
"//conditions:default": ["-Wl,-z,muldefs"],
}),
- visibility = ["//visibility:private"],
+ visibility = [":friends"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
@@ -3482,7 +3499,7 @@ tf_kernel_library(
tf_kernel_library(
name = "dense_update_ops",
prefix = "dense_update_ops",
- deps = STATE_DEPS,
+ deps = STATE_DEPS + [":dense_update_functor"],
)
tf_kernel_library(
@@ -3503,16 +3520,14 @@ tf_kernel_library(
"scatter_nd_op_cpu_impl_5.cc",
],
hdrs = [
- "dense_update_ops.h",
"scatter_nd_op.h",
"scatter_nd_op_cpu_impl.h",
],
gpu_srcs = [
- "dense_update_ops.h",
"scatter_nd_op.h",
"scatter_nd_op_gpu.cu.cc",
],
- deps = STATE_DEPS + [":dense_update_ops"],
+ deps = STATE_DEPS + [":dense_update_functor"],
)
tf_kernel_library(
@@ -4054,8 +4069,9 @@ filegroup(
"cwise_ops_common.cc",
"cwise_ops_common.h",
"cwise_ops_gradients.h",
+ "dense_update_functor.cc",
+ "dense_update_functor.h",
"dense_update_ops.cc",
- "dense_update_ops.h",
"example_parsing_ops.cc",
"fill_functor.cc",
"fill_functor.h",
diff --git a/tensorflow/core/kernels/dense_update_functor.cc b/tensorflow/core/kernels/dense_update_functor.cc
new file mode 100644
index 0000000000..a878fe9a97
--- /dev/null
+++ b/tensorflow/core/kernels/dense_update_functor.cc
@@ -0,0 +1,73 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/kernels/dense_update_functor.h"
+
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+namespace functor {
+
+template <>
+struct DenseUpdate<CPUDevice, string, ASSIGN> {
+ void operator()(const CPUDevice& d, typename TTypes<string>::Flat params,
+ typename TTypes<string>::ConstFlat update) {
+ if (params.dimension(0) == 1) {
+ params.data()->resize(update.data()->size());
+ auto work = [&params, &update](int64 start, int64 end) {
+ memmove(const_cast<char*>(params.data()->data()) + start,
+ update.data()->data() + start, end - start);
+ };
+ d.parallelFor(update.data()->size(),
+ Eigen::TensorOpCost(.1, // chosen to force large chunks
+ .1, 0),
+ work);
+ } else {
+ auto work = [&params, &update](int64 start, int64 end) {
+ for (int i = start; i < end; ++i) {
+ params.data()[i].resize(update.data()[i].size());
+ memmove(const_cast<char*>(params.data()[i].data()),
+ update.data()[i].data(), update.data()[i].size());
+ }
+ };
+ int64 estimated_string_size;
+ if (update.size() > 0) {
+ // first element of the tensor seems as good a guess as any of the sizes
+ // of the strings contained within...
+ estimated_string_size =
+ std::max(update.data()[0].size(), sizeof(string));
+ } else {
+ estimated_string_size = sizeof(string);
+ }
+ d.parallelFor(
+ params.dimension(0),
+ Eigen::TensorOpCost(estimated_string_size, estimated_string_size, 0),
+ work);
+ }
+ }
+};
+
+} // namespace functor
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/dense_update_ops.h b/tensorflow/core/kernels/dense_update_functor.h
index ec7e9dff11..54b080c83b 100644
--- a/tensorflow/core/kernels/dense_update_ops.h
+++ b/tensorflow/core/kernels/dense_update_functor.h
@@ -13,40 +13,47 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_DENSE_UPDATE_OPS_H_
-#define TENSORFLOW_KERNELS_DENSE_UPDATE_OPS_H_
+#ifndef TENSORFLOW_KERNELS_DENSE_UPDATE_FUNCTOR_H_
+#define TENSORFLOW_KERNELS_DENSE_UPDATE_FUNCTOR_H_
+
+#define EIGEN_USE_THREADS
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_types.h"
namespace tensorflow {
+typedef Eigen::ThreadPoolDevice CPUDevice;
+
enum DenseUpdateType { ADD, SUB, ASSIGN };
namespace functor {
template <typename Device, typename T, DenseUpdateType OP>
-struct DenseUpdate;
-
-template <typename Device, typename T>
-struct DenseUpdate<Device, T, ADD> {
+struct DenseUpdate {
void operator()(const Device& d, typename TTypes<T>::Flat params,
+ typename TTypes<T>::ConstFlat update);
+};
+
+template <typename T>
+struct DenseUpdate<CPUDevice, T, ADD> {
+ void operator()(const CPUDevice& d, typename TTypes<T>::Flat params,
typename TTypes<T>::ConstFlat update) {
params.device(d) += update;
}
};
-template <typename Device, typename T>
-struct DenseUpdate<Device, T, SUB> {
- void operator()(const Device& d, typename TTypes<T>::Flat params,
+template <typename T>
+struct DenseUpdate<CPUDevice, T, SUB> {
+ void operator()(const CPUDevice& d, typename TTypes<T>::Flat params,
typename TTypes<T>::ConstFlat update) {
params.device(d) -= update;
}
};
-template <typename Device, typename T>
-struct DenseUpdate<Device, T, ASSIGN> {
- void operator()(const Device& d, typename TTypes<T>::Flat params,
+template <typename T>
+struct DenseUpdate<CPUDevice, T, ASSIGN> {
+ void operator()(const CPUDevice& d, typename TTypes<T>::Flat params,
typename TTypes<T>::ConstFlat update) {
params.device(d) = update;
}
@@ -55,4 +62,4 @@ struct DenseUpdate<Device, T, ASSIGN> {
} // end namespace functor
} // end namespace tensorflow
-#endif // TENSORFLOW_KERNELS_DENSE_UPDATE_OPS_H_
+#endif // TENSORFLOW_KERNELS_DENSE_UPDATE_FUNCTOR_H_
diff --git a/tensorflow/core/kernels/dense_update_ops_gpu.cu.cc b/tensorflow/core/kernels/dense_update_functor_gpu.cu.cc
index cf2d493061..208401cb24 100644
--- a/tensorflow/core/kernels/dense_update_ops_gpu.cu.cc
+++ b/tensorflow/core/kernels/dense_update_functor_gpu.cu.cc
@@ -17,7 +17,7 @@ limitations under the License.
#define EIGEN_USE_GPU
-#include "tensorflow/core/kernels/dense_update_ops.h"
+#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/framework/register_types.h"
@@ -25,6 +25,34 @@ namespace tensorflow {
typedef Eigen::GpuDevice GPUDevice;
+namespace functor {
+
+template <typename T>
+struct DenseUpdate<GPUDevice, T, ASSIGN> {
+ void operator()(const GPUDevice& d, typename TTypes<T>::Flat params,
+ typename TTypes<T>::ConstFlat update) {
+ params.device(d) = update;
+ }
+};
+
+template <typename T>
+struct DenseUpdate<GPUDevice, T, ADD> {
+ void operator()(const GPUDevice& d, typename TTypes<T>::Flat params,
+ typename TTypes<T>::ConstFlat update) {
+ params.device(d) += update;
+ }
+};
+
+template <typename T>
+struct DenseUpdate<GPUDevice, T, SUB> {
+ void operator()(const GPUDevice& d, typename TTypes<T>::Flat params,
+ typename TTypes<T>::ConstFlat update) {
+ params.device(d) -= update;
+ }
+};
+
+} // namespace functor
+
#define DEFINE_GPU_KERNELS(T) \
template struct functor::DenseUpdate<GPUDevice, T, ADD>; \
template struct functor::DenseUpdate<GPUDevice, T, SUB>;
diff --git a/tensorflow/core/kernels/dense_update_ops.cc b/tensorflow/core/kernels/dense_update_ops.cc
index af16c55a5d..6d44a92fa3 100644
--- a/tensorflow/core/kernels/dense_update_ops.cc
+++ b/tensorflow/core/kernels/dense_update_ops.cc
@@ -15,59 +15,20 @@ limitations under the License.
#define EIGEN_USE_THREADS
-#include "tensorflow/core/kernels/dense_update_ops.h"
+#if GOOGLE_CUDA
+#define EIGEN_USE_GPU
+#endif
+
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/assign_op.h"
+#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
-namespace functor {
-
-template <>
-struct DenseUpdate<Eigen::ThreadPoolDevice, string, ASSIGN> {
- void operator()(const Eigen::ThreadPoolDevice& d,
- typename TTypes<string>::Flat params,
- typename TTypes<string>::ConstFlat update) {
- if (params.dimension(0) == 1) {
- params.data()->resize(update.data()->size());
- auto work = [&params, &update](int64 start, int64 end) {
- memmove(const_cast<char*>(params.data()->data()) + start,
- update.data()->data() + start, end - start);
- };
- d.parallelFor(update.data()->size(),
- Eigen::TensorOpCost(.1, // chosen to force large chunks
- .1, 0),
- work);
- } else {
- auto work = [&params, &update](int64 start, int64 end) {
- for (int i = start; i < end; ++i) {
- params.data()[i].resize(update.data()[i].size());
- memmove(const_cast<char*>(params.data()[i].data()),
- update.data()[i].data(), update.data()[i].size());
- }
- };
- int64 estimated_string_size;
- if (update.size() > 0) {
- // first element of the tensor seems as good a guess as any of the sizes
- // of the strings contained within...
- estimated_string_size =
- std::max(update.data()[0].size(), sizeof(string));
- } else {
- estimated_string_size = sizeof(string);
- }
- d.parallelFor(
- params.dimension(0),
- Eigen::TensorOpCost(estimated_string_size, estimated_string_size, 0),
- work);
- }
- }
-};
-
-} // namespace functor
template <typename Device, typename T>
class AssignOpT : public AssignOp {
@@ -117,7 +78,7 @@ class DenseUpdateOp : public OpKernel {
errors::InvalidArgument("Parameters and update must be the same size"));
functor::DenseUpdate<Device, T, OP> update_functor;
- update_functor(context->eigen_device<Device>(), Tparams.flat<T>(),
+ update_functor(context->template eigen_device<Device>(), Tparams.flat<T>(),
Tupdate.flat<T>());
}
@@ -143,13 +104,6 @@ TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS);
// Only register 'Assign' on GPU for the subset of types also supported by
// 'Variable' (see variable_ops.cc.)
#define REGISTER_GPU_KERNELS(type) \
- namespace functor { \
- template <> \
- void DenseUpdate<GPUDevice, type, ASSIGN>::operator()( \
- const GPUDevice& d, typename TTypes<type>::Flat lhs, \
- typename TTypes<type>::ConstFlat rhs); \
- extern template struct DenseUpdate<GPUDevice, type, ASSIGN>; \
- } \
REGISTER_KERNEL_BUILDER( \
Name("Assign").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
AssignOpT<GPUDevice, type>);
@@ -180,22 +134,6 @@ TF_CALL_NUMBER_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS
#if GOOGLE_CUDA
-// Forward declarations of the functor specializations for GPU.
-namespace functor {
-#define DECLARE_GPU_SPEC_FOR_OP(T, OP) \
- template <> \
- void DenseUpdate<GPUDevice, T, OP>::operator()( \
- const GPUDevice& d, typename TTypes<T>::Flat params, \
- typename TTypes<T>::ConstFlat update); \
- extern template struct DenseUpdate<GPUDevice, T, OP>;
-#define DECLARE_GPU_SPEC(T) \
- DECLARE_GPU_SPEC_FOR_OP(T, DenseUpdateType::ADD); \
- DECLARE_GPU_SPEC_FOR_OP(T, DenseUpdateType::SUB)
-TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
-#undef DECLARE_GPU_SPEC
-#undef DECLARE_GPU_SPEC_FOR_OP
-} // namespace functor
-
#define REGISTER_GPU_KERNELS(type) \
REGISTER_KERNEL_BUILDER( \
Name("AssignAdd").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc
index f0fef256d2..0616bb5a08 100644
--- a/tensorflow/core/kernels/resource_variable_ops.cc
+++ b/tensorflow/core/kernels/resource_variable_ops.cc
@@ -15,12 +15,16 @@ limitations under the License.
#define EIGEN_USE_THREADS
+#if GOOGLE_CUDA
+#define EIGEN_USE_GPU
+#endif
+
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/bounds_check.h"
-#include "tensorflow/core/kernels/dense_update_ops.h"
+#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/kernels/gather_functor.h"
#include "tensorflow/core/kernels/scatter_functor.h"
#include "tensorflow/core/kernels/variable_ops.h"
@@ -217,13 +221,6 @@ TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS);
#if GOOGLE_CUDA
#define REGISTER_GPU_KERNELS(type) \
- namespace functor { \
- template <> \
- void DenseUpdate<GPUDevice, type, ASSIGN>::operator()( \
- const GPUDevice& d, typename TTypes<type>::Flat lhs, \
- typename TTypes<type>::ConstFlat rhs); \
- extern template struct DenseUpdate<GPUDevice, type, ASSIGN>; \
- } \
REGISTER_KERNEL_BUILDER(Name("AssignVariableOp") \
.Device(DEVICE_GPU) \
.TypeConstraint<type>("dtype") \
@@ -275,20 +272,6 @@ TF_CALL_NUMBER_TYPES(REGISTER_KERNELS);
#if GOOGLE_CUDA
#define REGISTER_GPU_KERNELS(type) \
- namespace functor { \
- template <> \
- void DenseUpdate<GPUDevice, type, ADD>::operator()( \
- const GPUDevice& d, typename TTypes<type>::Flat lhs, \
- typename TTypes<type>::ConstFlat rhs); \
- extern template struct DenseUpdate<GPUDevice, type, ADD>; \
- } \
- namespace functor { \
- template <> \
- void DenseUpdate<GPUDevice, type, SUB>::operator()( \
- const GPUDevice& d, typename TTypes<type>::Flat lhs, \
- typename TTypes<type>::ConstFlat rhs); \
- extern template struct DenseUpdate<GPUDevice, type, SUB>; \
- } \
REGISTER_KERNEL_BUILDER(Name("AssignAddVariableOp") \
.Device(DEVICE_GPU) \
.HostMemory("resource") \
diff --git a/tensorflow/core/kernels/scatter_nd_op.cc b/tensorflow/core/kernels/scatter_nd_op.cc
index 1428546d52..59f690e7aa 100644
--- a/tensorflow/core/kernels/scatter_nd_op.cc
+++ b/tensorflow/core/kernels/scatter_nd_op.cc
@@ -26,7 +26,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/bounds_check.h"
-#include "tensorflow/core/kernels/dense_update_ops.h"
+#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/mutex.h"
@@ -523,16 +523,6 @@ TF_CALL_GPU_NUMBER_TYPES_NO_HALF(DECLARE_GPU_SPECS);
#undef DECLARE_GPU_SPECS_INDEX
#undef DECLARE_GPU_SPECS_INDEX_OP
-#define REGISTER_GPU_KERNELS(type) \
- template <> \
- void DenseUpdate<GPUDevice, type, ASSIGN>::operator()( \
- const GPUDevice& d, typename TTypes<type>::Flat lhs, \
- typename TTypes<type>::ConstFlat rhs); \
- extern template struct DenseUpdate<GPUDevice, type, ASSIGN>;
-
-TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
-#undef REGISTER_GPU_KERNELS
-
} // namespace functor
#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/strided_slice_op.cc b/tensorflow/core/kernels/strided_slice_op.cc
index fa518712ce..4655503e26 100644
--- a/tensorflow/core/kernels/strided_slice_op.cc
+++ b/tensorflow/core/kernels/strided_slice_op.cc
@@ -22,7 +22,7 @@ limitations under the License.
#endif // GOOGLE_CUDA
#include "tensorflow/core/kernels/strided_slice_op.h"
-#include "tensorflow/core/kernels/dense_update_ops.h"
+#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/kernels/slice_op.h"
#include "tensorflow/core/kernels/strided_slice_op_impl.h"
diff --git a/tensorflow/core/kernels/strided_slice_op_impl.h b/tensorflow/core/kernels/strided_slice_op_impl.h
index d0ccd5c652..de65147572 100644
--- a/tensorflow/core/kernels/strided_slice_op_impl.h
+++ b/tensorflow/core/kernels/strided_slice_op_impl.h
@@ -27,7 +27,7 @@ limitations under the License.
#include "tensorflow/core/framework/register_types_traits.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/bounds_check.h"
-#include "tensorflow/core/kernels/dense_update_ops.h"
+#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"