aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/resource_variable_ops.cc
diff options
context:
space:
mode:
authorGravatar Eugene Brevdo <ebrevdo@google.com>2017-07-12 15:24:40 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-07-12 15:30:16 -0700
commitc65f691196f4532fd117eacc8e3dd225c4844d3f (patch)
treebfc25735ac333237374533009a4e127516d2871c /tensorflow/core/kernels/resource_variable_ops.cc
parent9a172989e0b3fa8de69220caf5279643cd8366d2 (diff)
Factor out DenseUpdate ops into dense_update_functor build dep.
Also add support for complex types. PiperOrigin-RevId: 161726749
Diffstat (limited to 'tensorflow/core/kernels/resource_variable_ops.cc')
-rw-r--r--tensorflow/core/kernels/resource_variable_ops.cc27
1 files changed, 5 insertions, 22 deletions
diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc
index f0fef256d2..0616bb5a08 100644
--- a/tensorflow/core/kernels/resource_variable_ops.cc
+++ b/tensorflow/core/kernels/resource_variable_ops.cc
@@ -15,12 +15,16 @@ limitations under the License.
#define EIGEN_USE_THREADS
+#if GOOGLE_CUDA
+#define EIGEN_USE_GPU
+#endif
+
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/bounds_check.h"
-#include "tensorflow/core/kernels/dense_update_ops.h"
+#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/kernels/gather_functor.h"
#include "tensorflow/core/kernels/scatter_functor.h"
#include "tensorflow/core/kernels/variable_ops.h"
@@ -217,13 +221,6 @@ TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS);
#if GOOGLE_CUDA
#define REGISTER_GPU_KERNELS(type) \
- namespace functor { \
- template <> \
- void DenseUpdate<GPUDevice, type, ASSIGN>::operator()( \
- const GPUDevice& d, typename TTypes<type>::Flat lhs, \
- typename TTypes<type>::ConstFlat rhs); \
- extern template struct DenseUpdate<GPUDevice, type, ASSIGN>; \
- } \
REGISTER_KERNEL_BUILDER(Name("AssignVariableOp") \
.Device(DEVICE_GPU) \
.TypeConstraint<type>("dtype") \
@@ -275,20 +272,6 @@ TF_CALL_NUMBER_TYPES(REGISTER_KERNELS);
#if GOOGLE_CUDA
#define REGISTER_GPU_KERNELS(type) \
- namespace functor { \
- template <> \
- void DenseUpdate<GPUDevice, type, ADD>::operator()( \
- const GPUDevice& d, typename TTypes<type>::Flat lhs, \
- typename TTypes<type>::ConstFlat rhs); \
- extern template struct DenseUpdate<GPUDevice, type, ADD>; \
- } \
- namespace functor { \
- template <> \
- void DenseUpdate<GPUDevice, type, SUB>::operator()( \
- const GPUDevice& d, typename TTypes<type>::Flat lhs, \
- typename TTypes<type>::ConstFlat rhs); \
- extern template struct DenseUpdate<GPUDevice, type, SUB>; \
- } \
REGISTER_KERNEL_BUILDER(Name("AssignAddVariableOp") \
.Device(DEVICE_GPU) \
.HostMemory("resource") \