aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/one_hot_op_gpu.cu.cc
diff options
context:
space:
mode:
authorGravatar Vijay Vasudevan <vrv@google.com>2016-05-24 17:02:19 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-05-24 18:17:20 -0700
commit1135dac30a92a022aced959d44fbca4372bf9a80 (patch)
tree1b0e3ea49350efde18b6528fe6259db482f841c5 /tensorflow/core/kernels/one_hot_op_gpu.cu.cc
parent9a7ccc7607a645b31f2bd213a23b001d779b7da7 (diff)
Merge changes from github.
Change: 123167405
Diffstat (limited to 'tensorflow/core/kernels/one_hot_op_gpu.cu.cc')
-rw-r--r--tensorflow/core/kernels/one_hot_op_gpu.cu.cc13
1 files changed, 10 insertions, 3 deletions
diff --git a/tensorflow/core/kernels/one_hot_op_gpu.cu.cc b/tensorflow/core/kernels/one_hot_op_gpu.cu.cc
index 804f7ef406..70433b22ae 100644
--- a/tensorflow/core/kernels/one_hot_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/one_hot_op_gpu.cu.cc
@@ -21,17 +21,24 @@ limitations under the License.
#include "tensorflow/core/kernels/one_hot_op.h"
#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/types.h"
namespace tensorflow {
typedef Eigen::GpuDevice GPUDevice;
-#define DEFINE_GPU_SPEC(T) \
- template class generator::OneGenerator<T>; \
- template struct functor::OneHot<GPUDevice, T>;
+#define DEFINE_GPU_SPEC_INDEX(T, TI) \
+ template class generator::OneGenerator<T, TI>; \
+ template struct functor::OneHot<GPUDevice, T, TI>;
+
+#define DEFINE_GPU_SPEC(T) \
+ DEFINE_GPU_SPEC_INDEX(T, int32); \
+ DEFINE_GPU_SPEC_INDEX(T, int64)
+
TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPEC);
+#undef DEFINE_GPU_SPEC_INDEX
#undef DEFINE_GPU_SPEC
} // end namespace tensorflow