aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/variable_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/variable_ops.cc')
-rw-r--r--tensorflow/core/kernels/variable_ops.cc37
1 files changed, 37 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/variable_ops.cc b/tensorflow/core/kernels/variable_ops.cc
new file mode 100644
index 0000000000..2f1dbc68c0
--- /dev/null
+++ b/tensorflow/core/kernels/variable_ops.cc
@@ -0,0 +1,37 @@
+#define EIGEN_USE_THREADS
+#include "tensorflow/core/kernels/variable_ops.h"
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+
+REGISTER_KERNEL_BUILDER(Name("Variable").Device(DEVICE_CPU), VariableOp);
+REGISTER_KERNEL_BUILDER(Name("TemporaryVariable").Device(DEVICE_CPU),
+ TemporaryVariableOp);
+REGISTER_KERNEL_BUILDER(Name("DestroyTemporaryVariable").Device(DEVICE_CPU),
+ DestroyTemporaryVariableOp);
+
+#if GOOGLE_CUDA
+// Only register 'Variable' on GPU for the subset of types also supported by
+// 'Assign' (see dense_update_ops.cc.)
+#define REGISTER_GPU_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Variable").Device(DEVICE_GPU).TypeConstraint<type>("dtype"), \
+ VariableOp); \
+ REGISTER_KERNEL_BUILDER(Name("TemporaryVariable") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<type>("dtype"), \
+ TemporaryVariableOp); \
+ REGISTER_KERNEL_BUILDER(Name("DestroyTemporaryVariable") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<type>("T"), \
+ DestroyTemporaryVariableOp);
+
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
+#undef REGISTER_GPU_KERNELS
+#endif // GOOGLE_CUDA
+
+} // namespace tensorflow