diff options
Diffstat (limited to 'tensorflow/core/kernels/variable_ops.cc')
-rw-r--r-- | tensorflow/core/kernels/variable_ops.cc | 37 |
1 files changed, 37 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/variable_ops.cc b/tensorflow/core/kernels/variable_ops.cc new file mode 100644 index 0000000000..2f1dbc68c0 --- /dev/null +++ b/tensorflow/core/kernels/variable_ops.cc @@ -0,0 +1,37 @@ +#define EIGEN_USE_THREADS +#include "tensorflow/core/kernels/variable_ops.h" + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { + +REGISTER_KERNEL_BUILDER(Name("Variable").Device(DEVICE_CPU), VariableOp); +REGISTER_KERNEL_BUILDER(Name("TemporaryVariable").Device(DEVICE_CPU), + TemporaryVariableOp); +REGISTER_KERNEL_BUILDER(Name("DestroyTemporaryVariable").Device(DEVICE_CPU), + DestroyTemporaryVariableOp); + +#if GOOGLE_CUDA +// Only register 'Variable' on GPU for the subset of types also supported by +// 'Assign' (see dense_update_ops.cc.) +#define REGISTER_GPU_KERNELS(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Variable").Device(DEVICE_GPU).TypeConstraint<type>("dtype"), \ + VariableOp); \ + REGISTER_KERNEL_BUILDER(Name("TemporaryVariable") \ + .Device(DEVICE_GPU) \ + .TypeConstraint<type>("dtype"), \ + TemporaryVariableOp); \ + REGISTER_KERNEL_BUILDER(Name("DestroyTemporaryVariable") \ + .Device(DEVICE_GPU) \ + .TypeConstraint<type>("T"), \ + DestroyTemporaryVariableOp); + +TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS); +#undef REGISTER_GPU_KERNELS +#endif // GOOGLE_CUDA + +} // namespace tensorflow |