diff options
author | 2017-07-12 15:24:40 -0700 | |
---|---|---|
committer | 2017-07-12 15:30:16 -0700 | |
commit | c65f691196f4532fd117eacc8e3dd225c4844d3f (patch) | |
tree | bfc25735ac333237374533009a4e127516d2871c /tensorflow/core/kernels/resource_variable_ops.cc | |
parent | 9a172989e0b3fa8de69220caf5279643cd8366d2 (diff) |
Factor out DenseUpdate ops into dense_update_functor build dep.
Also add support for complex types.
PiperOrigin-RevId: 161726749
Diffstat (limited to 'tensorflow/core/kernels/resource_variable_ops.cc')
-rw-r--r-- | tensorflow/core/kernels/resource_variable_ops.cc | 27 |
1 file changed, 5 insertions, 22 deletions
```diff
diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc
index f0fef256d2..0616bb5a08 100644
--- a/tensorflow/core/kernels/resource_variable_ops.cc
+++ b/tensorflow/core/kernels/resource_variable_ops.cc
@@ -15,12 +15,16 @@ limitations under the License.
 
 #define EIGEN_USE_THREADS
 
+#if GOOGLE_CUDA
+#define EIGEN_USE_GPU
+#endif
+
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/resource_mgr.h"
 #include "tensorflow/core/framework/tensor_types.h"
 #include "tensorflow/core/kernels/bounds_check.h"
-#include "tensorflow/core/kernels/dense_update_ops.h"
+#include "tensorflow/core/kernels/dense_update_functor.h"
 #include "tensorflow/core/kernels/gather_functor.h"
 #include "tensorflow/core/kernels/scatter_functor.h"
 #include "tensorflow/core/kernels/variable_ops.h"
@@ -217,13 +221,6 @@ TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS);
 
 #if GOOGLE_CUDA
 #define REGISTER_GPU_KERNELS(type) \
-  namespace functor { \
-  template <> \
-  void DenseUpdate<GPUDevice, type, ASSIGN>::operator()( \
-      const GPUDevice& d, typename TTypes<type>::Flat lhs, \
-      typename TTypes<type>::ConstFlat rhs); \
-  extern template struct DenseUpdate<GPUDevice, type, ASSIGN>; \
-  } \
   REGISTER_KERNEL_BUILDER(Name("AssignVariableOp") \
                               .Device(DEVICE_GPU) \
                               .TypeConstraint<type>("dtype") \
@@ -275,20 +272,6 @@ TF_CALL_NUMBER_TYPES(REGISTER_KERNELS);
 
 #if GOOGLE_CUDA
 #define REGISTER_GPU_KERNELS(type) \
-  namespace functor { \
-  template <> \
-  void DenseUpdate<GPUDevice, type, ADD>::operator()( \
-      const GPUDevice& d, typename TTypes<type>::Flat lhs, \
-      typename TTypes<type>::ConstFlat rhs); \
-  extern template struct DenseUpdate<GPUDevice, type, ADD>; \
-  } \
-  namespace functor { \
-  template <> \
-  void DenseUpdate<GPUDevice, type, SUB>::operator()( \
-      const GPUDevice& d, typename TTypes<type>::Flat lhs, \
-      typename TTypes<type>::ConstFlat rhs); \
-  extern template struct DenseUpdate<GPUDevice, type, SUB>; \
-  } \
   REGISTER_KERNEL_BUILDER(Name("AssignAddVariableOp") \
                               .Device(DEVICE_GPU) \
                               .HostMemory("resource") \
```