aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/cast_op_impl.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/cast_op_impl.h')
-rw-r--r--tensorflow/core/kernels/cast_op_impl.h145
1 files changed, 145 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/cast_op_impl.h b/tensorflow/core/kernels/cast_op_impl.h
new file mode 100644
index 0000000000..cb7cc81937
--- /dev/null
+++ b/tensorflow/core/kernels/cast_op_impl.h
@@ -0,0 +1,145 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CAST_OP_IMPL_H_
+#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CAST_OP_IMPL_H_
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/kernels/cast_op.h"
+
+namespace tensorflow {
+
+namespace functor {
+
+// Partial specialization of CastFunctor for the CPU (Eigen ThreadPoolDevice):
+// performs an element-wise conversion of the flat input tensor `i` (element
+// type I) into the flat output tensor `o` (element type O) using Eigen's
+// cast() expression, evaluated on the device's thread pool.
+// NOTE(review): assumes `o` and `i` have the same number of elements — the
+// caller (the Cast op) is responsible for shape agreement.
+template <typename O, typename I>
+struct CastFunctor<Eigen::ThreadPoolDevice, O, I> {
+  void operator()(const Eigen::ThreadPoolDevice& d, typename TTypes<O>::Flat o,
+                  typename TTypes<I>::ConstFlat i) {
+    // `template` keyword is required because cast<O>() is a dependent
+    // member template of the Eigen expression type.
+    o.device(d) = i.template cast<O>();
+  }
+};
+
+} // namespace functor
+
+// CURRY_TYPES3(FN, arg0, arg1) expands FN(arg0, arg1, T) once for each
+// supported destination element type T. Used by the cast_op_impl_*.cc files
+// to instantiate one CAST_CASE per (device, src, dst) combination.
+// bfloat16 is not in this list; it has dedicated Get*CastFromBfloat entry
+// points declared below.
+#define CURRY_TYPES3(FN, arg0, arg1)   \
+  FN(arg0, arg1, bool);                \
+  FN(arg0, arg1, uint8);               \
+  FN(arg0, arg1, int8);                \
+  FN(arg0, arg1, uint16);              \
+  FN(arg0, arg1, int16);               \
+  FN(arg0, arg1, int32);               \
+  FN(arg0, arg1, int64);               \
+  FN(arg0, arg1, Eigen::half);         \
+  FN(arg0, arg1, float);               \
+  FN(arg0, arg1, double);              \
+  FN(arg0, arg1, std::complex<float>); \
+  FN(arg0, arg1, std::complex<double>)
+
+// CAST_CASE(DEVICE, IN, OUT) emits one dispatch case for the Get*CastFrom*
+// factory functions: if the requested `dst_dtype` matches OUT, return a
+// lambda that runs CastFunctor<DEVICE, OUT, IN> on the given context's
+// device. Falls through (no return) when the dtype does not match, so
+// consecutive CAST_CASEs chain naturally and the caller can return a
+// nullptr/default after the last one.
+#define CAST_CASE(DEVICE, IN, OUT)                                         \
+  if (DataTypeToEnum<OUT>::value == dst_dtype) {                           \
+    return [](OpKernelContext* ctx, const Tensor& inp, Tensor* out) {      \
+      functor::CastFunctor<DEVICE, OUT, IN> func;                          \
+      func(ctx->eigen_device<DEVICE>(), out->flat<OUT>(), inp.flat<IN>()); \
+    };                                                                     \
+  }
+
+// The functions below are implemented in the cast_op_impl_*.cc files.
+// Each GetCpuCastFromX(dst_dtype) returns a closure that casts a tensor
+// whose element type is X into a pre-allocated output tensor of type
+// dst_dtype, running on the CPU. Splitting one factory per source type
+// keeps each translation unit (and its template instantiations) small.
+// NOTE(review): presumably returns a null std::function when dst_dtype is
+// unsupported — confirm against the .cc implementations.
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetCpuCastFromBool(DataType dst_dtype);
+
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetCpuCastFromUint8(DataType dst_dtype);
+
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetCpuCastFromInt8(DataType dst_dtype);
+
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetCpuCastFromUint16(DataType dst_dtype);
+
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetCpuCastFromInt16(DataType dst_dtype);
+
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetCpuCastFromInt32(DataType dst_dtype);
+
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetCpuCastFromInt64(DataType dst_dtype);
+
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetCpuCastFromHalf(DataType dst_dtype);
+
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetCpuCastFromFloat(DataType dst_dtype);
+
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetCpuCastFromDouble(DataType dst_dtype);
+
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetCpuCastFromComplex64(DataType dst_dtype);
+
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetCpuCastFromComplex128(DataType dst_dtype);
+
+// bfloat16 source type has its own entry point (it is not part of the
+// CURRY_TYPES3 list above).
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetCpuCastFromBfloat(DataType dst_dtype);
+
+#if GOOGLE_CUDA
+// Same, for GPU. These mirror the CPU factories above one-for-one; they are
+// only declared (and compiled) when building with CUDA support.
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetGpuCastFromBool(DataType dst_dtype);
+
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetGpuCastFromUint8(DataType dst_dtype);
+
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetGpuCastFromInt8(DataType dst_dtype);
+
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetGpuCastFromUint16(DataType dst_dtype);
+
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetGpuCastFromInt16(DataType dst_dtype);
+
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetGpuCastFromInt32(DataType dst_dtype);
+
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetGpuCastFromInt64(DataType dst_dtype);
+
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetGpuCastFromHalf(DataType dst_dtype);
+
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetGpuCastFromFloat(DataType dst_dtype);
+
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetGpuCastFromDouble(DataType dst_dtype);
+
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetGpuCastFromComplex64(DataType dst_dtype);
+
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetGpuCastFromComplex128(DataType dst_dtype);
+
+// bfloat16 source type has its own entry point (it is not part of the
+// CURRY_TYPES3 list above).
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetGpuCastFromBfloat(DataType dst_dtype);
+
+#endif  // GOOGLE_CUDA
+
+} // namespace tensorflow
+
+#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CAST_OP_IMPL_H_