aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-08-15 13:05:12 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-08-15 14:17:19 -0700
commit6da951a4e0b8744b331a91ac2c398efcc4af6bed (patch)
tree70c76a1c2ea6085ef3137fd7866f3749a869b8ae /tensorflow/core
parent25236c29a246d7eebaf0be37fa38ad2782699b21 (diff)
Shard cast_op implementations to improve build times.
Change: 130321753
Diffstat (limited to 'tensorflow/core')
-rw-r--r--tensorflow/core/kernels/BUILD14
-rw-r--r--tensorflow/core/kernels/cast_op.cc160
-rw-r--r--tensorflow/core/kernels/cast_op_impl.h145
-rw-r--r--tensorflow/core/kernels/cast_op_impl_bfloat.cc55
-rw-r--r--tensorflow/core/kernels/cast_op_impl_bool.cc37
-rw-r--r--tensorflow/core/kernels/cast_op_impl_complex128.cc37
-rw-r--r--tensorflow/core/kernels/cast_op_impl_complex64.cc37
-rw-r--r--tensorflow/core/kernels/cast_op_impl_double.cc37
-rw-r--r--tensorflow/core/kernels/cast_op_impl_float.cc52
-rw-r--r--tensorflow/core/kernels/cast_op_impl_half.cc37
-rw-r--r--tensorflow/core/kernels/cast_op_impl_int16.cc37
-rw-r--r--tensorflow/core/kernels/cast_op_impl_int32.cc37
-rw-r--r--tensorflow/core/kernels/cast_op_impl_int64.cc37
-rw-r--r--tensorflow/core/kernels/cast_op_impl_int8.cc37
-rw-r--r--tensorflow/core/kernels/cast_op_impl_uint16.cc37
-rw-r--r--tensorflow/core/kernels/cast_op_impl_uint8.cc37
16 files changed, 738 insertions, 95 deletions
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index a61c2d8783..7f8b2b439d 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -1830,6 +1830,20 @@ filegroup(
"bounds_check.h",
"cast_op.cc",
"cast_op.h",
+ "cast_op_impl.h",
+ "cast_op_impl_bfloat.cc",
+ "cast_op_impl_bool.cc",
+ "cast_op_impl_complex128.cc",
+ "cast_op_impl_complex64.cc",
+ "cast_op_impl_double.cc",
+ "cast_op_impl_float.cc",
+ "cast_op_impl_half.cc",
+ "cast_op_impl_int16.cc",
+ "cast_op_impl_int32.cc",
+ "cast_op_impl_int64.cc",
+ "cast_op_impl_int8.cc",
+ "cast_op_impl_uint16.cc",
+ "cast_op_impl_uint8.cc",
"concat_lib.h",
"concat_lib_cpu.cc",
"concat_lib_cpu.h",
diff --git a/tensorflow/core/kernels/cast_op.cc b/tensorflow/core/kernels/cast_op.cc
index c08995b464..ab82c247d6 100644
--- a/tensorflow/core/kernels/cast_op.cc
+++ b/tensorflow/core/kernels/cast_op.cc
@@ -28,23 +28,13 @@ limitations under the License.
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/work_sharder.h"
+#include "tensorflow/core/kernels/cast_op_impl.h"
+
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
-namespace functor {
-
-template <typename O, typename I>
-struct CastFunctor<CPUDevice, O, I> {
- void operator()(const CPUDevice& d, typename TTypes<O>::Flat o,
- typename TTypes<I>::ConstFlat i) {
- o.device(d) = i.template cast<O>();
- }
-};
-
-} // namespace functor
-
#define CURRY_TYPES2(FN, arg0) \
FN(arg0, bool); \
FN(arg0, uint8); \
@@ -59,30 +49,6 @@ struct CastFunctor<CPUDevice, O, I> {
FN(arg0, std::complex<float>); \
FN(arg0, std::complex<double>)
-#define CURRY_TYPES3(FN, arg0, arg1) \
- FN(arg0, arg1, bool); \
- FN(arg0, arg1, uint8); \
- FN(arg0, arg1, int8); \
- FN(arg0, arg1, uint16); \
- FN(arg0, arg1, int16); \
- FN(arg0, arg1, int32); \
- FN(arg0, arg1, int64); \
- FN(arg0, arg1, Eigen::half); \
- FN(arg0, arg1, float); \
- FN(arg0, arg1, double); \
- FN(arg0, arg1, std::complex<float>); \
- FN(arg0, arg1, std::complex<double>)
-
-#define CAST_CASE(DEVICE, IN, OUT) \
- if (DataTypeToEnum<IN>::value == src_dtype_ && \
- DataTypeToEnum<OUT>::value == dst_dtype_) { \
- work_ = [](OpKernelContext* ctx, const Tensor& inp, Tensor* out) { \
- functor::CastFunctor<DEVICE, OUT, IN> func; \
- func(ctx->eigen_device<DEVICE>(), out->flat<OUT>(), inp.flat<IN>()); \
- }; \
- return Status::OK(); \
- }
-
class CastOpBase : public OpKernel {
public:
explicit CastOpBase(OpKernelConstruction* ctx) : OpKernel(ctx) {
@@ -106,7 +72,6 @@ class CastOpBase : public OpKernel {
DataType dst_dtype_;
std::function<void(OpKernelContext*, const Tensor&, Tensor*)> work_ = nullptr;
- virtual Status Prepare() = 0;
Status Unimplemented() {
return errors::Unimplemented("Cast ", DataTypeString(src_dtype_), " to ",
DataTypeString(dst_dtype_),
@@ -122,88 +87,94 @@ class CpuCastOp : public CastOpBase {
OP_REQUIRES_OK(ctx, Prepare());
}
- protected:
- Status Prepare() override {
+ private:
+ Status Prepare() {
if (src_dtype_ == dst_dtype_) {
work_ = nullptr; // Identity
return Status::OK();
}
- CURRY_TYPES3(CAST_CASE, CPUDevice, bool);
- CURRY_TYPES3(CAST_CASE, CPUDevice, uint8);
- CURRY_TYPES3(CAST_CASE, CPUDevice, int8);
- CURRY_TYPES3(CAST_CASE, CPUDevice, uint16);
- CURRY_TYPES3(CAST_CASE, CPUDevice, int16);
- CURRY_TYPES3(CAST_CASE, CPUDevice, int32);
- CURRY_TYPES3(CAST_CASE, CPUDevice, int64);
- CURRY_TYPES3(CAST_CASE, CPUDevice, Eigen::half);
- CURRY_TYPES3(CAST_CASE, CPUDevice, float);
- CURRY_TYPES3(CAST_CASE, CPUDevice, double);
- CURRY_TYPES3(CAST_CASE, CPUDevice, std::complex<float>);
- CURRY_TYPES3(CAST_CASE, CPUDevice, std::complex<double>);
-
- if (src_dtype_ == DT_BFLOAT16 && dst_dtype_ == DT_FLOAT) {
- work_ = [](OpKernelContext* ctx, const Tensor& inp, Tensor* out) {
- int64 N = out->NumElements();
- auto worker_threads = ctx->device()->tensorflow_cpu_worker_threads();
- auto work = [&inp, &out](int64 start, int64 end) {
- BFloat16ToFloat(inp.flat<bfloat16>().data() + start,
- out->flat<float>().data() + start, end - start);
- };
- Shard(worker_threads->num_threads, worker_threads->workers, N, 2, work);
- };
- return Status::OK();
- }
- if (src_dtype_ == DT_FLOAT && dst_dtype_ == DT_BFLOAT16) {
- work_ = [](OpKernelContext* ctx, const Tensor& inp, Tensor* out) {
- int64 N = out->NumElements();
- auto worker_threads = ctx->device()->tensorflow_cpu_worker_threads();
- auto work = [&inp, &out](int64 start, int64 end) {
- FloatToBFloat16(inp.flat<float>().data() + start,
- out->flat<bfloat16>().data() + start, end - start);
- };
- Shard(worker_threads->num_threads, worker_threads->workers, N, 2, work);
- };
- return Status::OK();
+ if (src_dtype_ == DT_BOOL) {
+ work_ = GetCpuCastFromBool(dst_dtype_);
+ } else if (src_dtype_ == DT_UINT8) {
+ work_ = GetCpuCastFromUint8(dst_dtype_);
+ } else if (src_dtype_ == DT_INT8) {
+ work_ = GetCpuCastFromInt8(dst_dtype_);
+ } else if (src_dtype_ == DT_UINT16) {
+ work_ = GetCpuCastFromUint16(dst_dtype_);
+ } else if (src_dtype_ == DT_INT16) {
+ work_ = GetCpuCastFromInt16(dst_dtype_);
+ } else if (src_dtype_ == DT_INT32) {
+ work_ = GetCpuCastFromInt32(dst_dtype_);
+ } else if (src_dtype_ == DT_INT64) {
+ work_ = GetCpuCastFromInt64(dst_dtype_);
+ } else if (src_dtype_ == DT_HALF) {
+ work_ = GetCpuCastFromHalf(dst_dtype_);
+ } else if (src_dtype_ == DT_FLOAT) {
+ work_ = GetCpuCastFromFloat(dst_dtype_);
+ } else if (src_dtype_ == DT_DOUBLE) {
+ work_ = GetCpuCastFromDouble(dst_dtype_);
+ } else if (src_dtype_ == DT_COMPLEX64) {
+ work_ = GetCpuCastFromComplex64(dst_dtype_);
+ } else if (src_dtype_ == DT_COMPLEX128) {
+ work_ = GetCpuCastFromComplex128(dst_dtype_);
+ } else if (src_dtype_ == DT_BFLOAT16) {
+ work_ = GetCpuCastFromBfloat(dst_dtype_);
}
// TODO(sesse): If CPU casting to or from Eigen::half ever becomes a
// bottleneck, we could probably implement specialized support for
// vectorized versions (not the least based on F16C for Haswell
- // or newer) here.
+ // or newer).
- return Unimplemented();
+ return work_ == nullptr ? Unimplemented() : Status::OK();
}
};
+#if GOOGLE_CUDA
class GpuCastOp : public CastOpBase {
public:
explicit GpuCastOp(OpKernelConstruction* ctx) : CastOpBase(ctx) {
OP_REQUIRES_OK(ctx, Prepare());
}
- protected:
- Status Prepare() override {
+ private:
+ Status Prepare() {
if (src_dtype_ == dst_dtype_) {
work_ = nullptr; // Identity
return Status::OK();
}
- CURRY_TYPES3(CAST_CASE, GPUDevice, bool);
- CURRY_TYPES3(CAST_CASE, GPUDevice, uint8);
- CURRY_TYPES3(CAST_CASE, GPUDevice, int8);
- CURRY_TYPES3(CAST_CASE, GPUDevice, uint16);
- CURRY_TYPES3(CAST_CASE, GPUDevice, int16);
- CURRY_TYPES3(CAST_CASE, GPUDevice, int32);
- CURRY_TYPES3(CAST_CASE, GPUDevice, int64);
- CURRY_TYPES3(CAST_CASE, GPUDevice, Eigen::half);
- CURRY_TYPES3(CAST_CASE, GPUDevice, float);
- CURRY_TYPES3(CAST_CASE, GPUDevice, double);
- CURRY_TYPES3(CAST_CASE, GPUDevice, std::complex<float>);
- CURRY_TYPES3(CAST_CASE, GPUDevice, std::complex<double>);
- CAST_CASE(GPUDevice, float, bfloat16);
- CAST_CASE(GPUDevice, bfloat16, float);
- return Unimplemented();
+ if (src_dtype_ == DT_BOOL) {
+ work_ = GetGpuCastFromBool(dst_dtype_);
+ } else if (src_dtype_ == DT_UINT8) {
+ work_ = GetGpuCastFromUint8(dst_dtype_);
+ } else if (src_dtype_ == DT_INT8) {
+ work_ = GetGpuCastFromInt8(dst_dtype_);
+ } else if (src_dtype_ == DT_UINT16) {
+ work_ = GetGpuCastFromUint16(dst_dtype_);
+ } else if (src_dtype_ == DT_INT16) {
+ work_ = GetGpuCastFromInt16(dst_dtype_);
+ } else if (src_dtype_ == DT_INT32) {
+ work_ = GetGpuCastFromInt32(dst_dtype_);
+ } else if (src_dtype_ == DT_INT64) {
+ work_ = GetGpuCastFromInt64(dst_dtype_);
+ } else if (src_dtype_ == DT_HALF) {
+ work_ = GetGpuCastFromHalf(dst_dtype_);
+ } else if (src_dtype_ == DT_FLOAT) {
+ work_ = GetGpuCastFromFloat(dst_dtype_);
+ } else if (src_dtype_ == DT_DOUBLE) {
+ work_ = GetGpuCastFromDouble(dst_dtype_);
+ } else if (src_dtype_ == DT_COMPLEX64) {
+ work_ = GetGpuCastFromComplex64(dst_dtype_);
+ } else if (src_dtype_ == DT_COMPLEX128) {
+ work_ = GetGpuCastFromComplex128(dst_dtype_);
+ } else if (src_dtype_ == DT_BFLOAT16) {
+ work_ = GetGpuCastFromBfloat(dst_dtype_);
+ }
+
+ return work_ == nullptr ? Unimplemented() : Status::OK();
}
};
+#endif // GOOGLE_CUDA
#undef CAST_CASE
@@ -236,7 +207,6 @@ REGISTER_CAST_GPU(bfloat16, float);
#endif // GOOGLE_CUDA
#undef CURRY_TYPES2
-#undef CURRY_TYPES3
// HostCast differs from Cast in that its input and output are in host memory.
REGISTER_KERNEL_BUILDER(Name("_HostCast").Device(DEVICE_CPU), CpuCastOp);
diff --git a/tensorflow/core/kernels/cast_op_impl.h b/tensorflow/core/kernels/cast_op_impl.h
new file mode 100644
index 0000000000..cb7cc81937
--- /dev/null
+++ b/tensorflow/core/kernels/cast_op_impl.h
@@ -0,0 +1,145 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CAST_OP_IMPL_H_
+#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CAST_OP_IMPL_H_
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/kernels/cast_op.h"
+
+namespace tensorflow {
+
+namespace functor {
+
+template <typename O, typename I>
+struct CastFunctor<Eigen::ThreadPoolDevice, O, I> {
+ void operator()(const Eigen::ThreadPoolDevice& d, typename TTypes<O>::Flat o,
+ typename TTypes<I>::ConstFlat i) {
+ o.device(d) = i.template cast<O>();
+ }
+};
+
+} // namespace functor
+
+#define CURRY_TYPES3(FN, arg0, arg1) \
+ FN(arg0, arg1, bool); \
+ FN(arg0, arg1, uint8); \
+ FN(arg0, arg1, int8); \
+ FN(arg0, arg1, uint16); \
+ FN(arg0, arg1, int16); \
+ FN(arg0, arg1, int32); \
+ FN(arg0, arg1, int64); \
+ FN(arg0, arg1, Eigen::half); \
+ FN(arg0, arg1, float); \
+ FN(arg0, arg1, double); \
+ FN(arg0, arg1, std::complex<float>); \
+ FN(arg0, arg1, std::complex<double>)
+
+#define CAST_CASE(DEVICE, IN, OUT) \
+ if (DataTypeToEnum<OUT>::value == dst_dtype) { \
+ return [](OpKernelContext* ctx, const Tensor& inp, Tensor* out) { \
+ functor::CastFunctor<DEVICE, OUT, IN> func; \
+ func(ctx->eigen_device<DEVICE>(), out->flat<OUT>(), inp.flat<IN>()); \
+ }; \
+ }
+
+// The functions below are implemented in the cast_op_impl_*.cc files.
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetCpuCastFromBool(DataType dst_dtype);
+
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetCpuCastFromUint8(DataType dst_dtype);
+
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetCpuCastFromInt8(DataType dst_dtype);
+
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetCpuCastFromUint16(DataType dst_dtype);
+
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetCpuCastFromInt16(DataType dst_dtype);
+
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetCpuCastFromInt32(DataType dst_dtype);
+
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetCpuCastFromInt64(DataType dst_dtype);
+
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetCpuCastFromHalf(DataType dst_dtype);
+
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetCpuCastFromFloat(DataType dst_dtype);
+
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetCpuCastFromDouble(DataType dst_dtype);
+
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetCpuCastFromComplex64(DataType dst_dtype);
+
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetCpuCastFromComplex128(DataType dst_dtype);
+
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetCpuCastFromBfloat(DataType dst_dtype);
+
+#if GOOGLE_CUDA
+// Same, for GPU.
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetGpuCastFromBool(DataType dst_dtype);
+
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetGpuCastFromUint8(DataType dst_dtype);
+
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetGpuCastFromInt8(DataType dst_dtype);
+
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetGpuCastFromUint16(DataType dst_dtype);
+
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetGpuCastFromInt16(DataType dst_dtype);
+
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetGpuCastFromInt32(DataType dst_dtype);
+
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetGpuCastFromInt64(DataType dst_dtype);
+
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetGpuCastFromHalf(DataType dst_dtype);
+
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetGpuCastFromFloat(DataType dst_dtype);
+
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetGpuCastFromDouble(DataType dst_dtype);
+
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetGpuCastFromComplex64(DataType dst_dtype);
+
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetGpuCastFromComplex128(DataType dst_dtype);
+
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetGpuCastFromBfloat(DataType dst_dtype);
+
+#endif // GOOGLE_CUDA
+
+} // namespace tensorflow
+
+#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CAST_OP_IMPL_H_
diff --git a/tensorflow/core/kernels/cast_op_impl_bfloat.cc b/tensorflow/core/kernels/cast_op_impl_bfloat.cc
new file mode 100644
index 0000000000..a06f815899
--- /dev/null
+++ b/tensorflow/core/kernels/cast_op_impl_bfloat.cc
@@ -0,0 +1,55 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/cast_op_impl.h"
+
+#include "tensorflow/core/util/work_sharder.h"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetCpuCastFromBfloat(DataType dst_dtype) {
+ if (dst_dtype == DT_FLOAT) {
+ return [](OpKernelContext* ctx, const Tensor& inp, Tensor* out) {
+ int64 N = out->NumElements();
+ auto worker_threads = ctx->device()->tensorflow_cpu_worker_threads();
+ auto work = [&inp, &out](int64 start, int64 end) {
+ BFloat16ToFloat(inp.flat<bfloat16>().data() + start,
+ out->flat<float>().data() + start, end - start);
+ };
+ Shard(worker_threads->num_threads, worker_threads->workers, N, 2, work);
+ };
+ }
+ return nullptr;
+}
+
+#if GOOGLE_CUDA
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetGpuCastFromBfloat(DataType dst_dtype) {
+ if (dst_dtype == DT_FLOAT) {
+ return [](OpKernelContext* ctx, const Tensor& inp, Tensor* out) {
+ functor::CastFunctor<GPUDevice, float, bfloat16> func;
+ func(ctx->eigen_device<GPUDevice>(), out->flat<float>(),
+ inp.flat<bfloat16>());
+ };
+ }
+ return nullptr;
+}
+#endif // GOOGLE_CUDA
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cast_op_impl_bool.cc b/tensorflow/core/kernels/cast_op_impl_bool.cc
new file mode 100644
index 0000000000..92fee89a47
--- /dev/null
+++ b/tensorflow/core/kernels/cast_op_impl_bool.cc
@@ -0,0 +1,37 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/cast_op_impl.h"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetCpuCastFromBool(DataType dst_dtype) {
+ CURRY_TYPES3(CAST_CASE, CPUDevice, bool);
+ return nullptr;
+}
+
+#if GOOGLE_CUDA
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetGpuCastFromBool(DataType dst_dtype) {
+ CURRY_TYPES3(CAST_CASE, GPUDevice, bool);
+ return nullptr;
+}
+#endif // GOOGLE_CUDA
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cast_op_impl_complex128.cc b/tensorflow/core/kernels/cast_op_impl_complex128.cc
new file mode 100644
index 0000000000..c428679d7c
--- /dev/null
+++ b/tensorflow/core/kernels/cast_op_impl_complex128.cc
@@ -0,0 +1,37 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/cast_op_impl.h"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetCpuCastFromComplex128(DataType dst_dtype) {
+ CURRY_TYPES3(CAST_CASE, CPUDevice, std::complex<double>);
+ return nullptr;
+}
+
+#if GOOGLE_CUDA
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetGpuCastFromComplex128(DataType dst_dtype) {
+ CURRY_TYPES3(CAST_CASE, GPUDevice, std::complex<double>);
+ return nullptr;
+}
+#endif // GOOGLE_CUDA
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cast_op_impl_complex64.cc b/tensorflow/core/kernels/cast_op_impl_complex64.cc
new file mode 100644
index 0000000000..07b46551b2
--- /dev/null
+++ b/tensorflow/core/kernels/cast_op_impl_complex64.cc
@@ -0,0 +1,37 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/cast_op_impl.h"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetCpuCastFromComplex64(DataType dst_dtype) {
+ CURRY_TYPES3(CAST_CASE, CPUDevice, std::complex<float>);
+ return nullptr;
+}
+
+#if GOOGLE_CUDA
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetGpuCastFromComplex64(DataType dst_dtype) {
+ CURRY_TYPES3(CAST_CASE, GPUDevice, std::complex<float>);
+ return nullptr;
+}
+#endif // GOOGLE_CUDA
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cast_op_impl_double.cc b/tensorflow/core/kernels/cast_op_impl_double.cc
new file mode 100644
index 0000000000..fd20061d21
--- /dev/null
+++ b/tensorflow/core/kernels/cast_op_impl_double.cc
@@ -0,0 +1,37 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/cast_op_impl.h"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetCpuCastFromDouble(DataType dst_dtype) {
+ CURRY_TYPES3(CAST_CASE, CPUDevice, double);
+ return nullptr;
+}
+
+#if GOOGLE_CUDA
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetGpuCastFromDouble(DataType dst_dtype) {
+ CURRY_TYPES3(CAST_CASE, GPUDevice, double);
+ return nullptr;
+}
+#endif // GOOGLE_CUDA
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cast_op_impl_float.cc b/tensorflow/core/kernels/cast_op_impl_float.cc
new file mode 100644
index 0000000000..71e63fbff0
--- /dev/null
+++ b/tensorflow/core/kernels/cast_op_impl_float.cc
@@ -0,0 +1,52 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/cast_op_impl.h"
+
+#include "tensorflow/core/util/work_sharder.h"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetCpuCastFromFloat(DataType dst_dtype) {
+ CURRY_TYPES3(CAST_CASE, CPUDevice, float);
+ if (dst_dtype == DT_BFLOAT16) {
+ return [](OpKernelContext* ctx, const Tensor& inp, Tensor* out) {
+ int64 N = out->NumElements();
+ auto worker_threads = ctx->device()->tensorflow_cpu_worker_threads();
+ auto work = [&inp, &out](int64 start, int64 end) {
+ FloatToBFloat16(inp.flat<float>().data() + start,
+ out->flat<bfloat16>().data() + start, end - start);
+ };
+ Shard(worker_threads->num_threads, worker_threads->workers, N, 2, work);
+ };
+ }
+
+ return nullptr;
+}
+
+#if GOOGLE_CUDA
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetGpuCastFromFloat(DataType dst_dtype) {
+ CURRY_TYPES3(CAST_CASE, GPUDevice, float);
+ CAST_CASE(GPUDevice, float, bfloat16);
+ return nullptr;
+}
+#endif // GOOGLE_CUDA
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cast_op_impl_half.cc b/tensorflow/core/kernels/cast_op_impl_half.cc
new file mode 100644
index 0000000000..e89d4646d7
--- /dev/null
+++ b/tensorflow/core/kernels/cast_op_impl_half.cc
@@ -0,0 +1,37 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/cast_op_impl.h"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetCpuCastFromHalf(DataType dst_dtype) {
+ CURRY_TYPES3(CAST_CASE, CPUDevice, Eigen::half);
+ return nullptr;
+}
+
+#if GOOGLE_CUDA
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetGpuCastFromHalf(DataType dst_dtype) {
+ CURRY_TYPES3(CAST_CASE, GPUDevice, Eigen::half);
+ return nullptr;
+}
+#endif // GOOGLE_CUDA
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cast_op_impl_int16.cc b/tensorflow/core/kernels/cast_op_impl_int16.cc
new file mode 100644
index 0000000000..3c2d6185e3
--- /dev/null
+++ b/tensorflow/core/kernels/cast_op_impl_int16.cc
@@ -0,0 +1,37 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/cast_op_impl.h"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetCpuCastFromInt16(DataType dst_dtype) {
+ CURRY_TYPES3(CAST_CASE, CPUDevice, int16);
+ return nullptr;
+}
+
+#if GOOGLE_CUDA
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetGpuCastFromInt16(DataType dst_dtype) {
+ CURRY_TYPES3(CAST_CASE, GPUDevice, int16);
+ return nullptr;
+}
+#endif // GOOGLE_CUDA
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cast_op_impl_int32.cc b/tensorflow/core/kernels/cast_op_impl_int32.cc
new file mode 100644
index 0000000000..0fc6e16afe
--- /dev/null
+++ b/tensorflow/core/kernels/cast_op_impl_int32.cc
@@ -0,0 +1,37 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/cast_op_impl.h"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetCpuCastFromInt32(DataType dst_dtype) {
+ CURRY_TYPES3(CAST_CASE, CPUDevice, int32);
+ return nullptr;
+}
+
+#if GOOGLE_CUDA
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetGpuCastFromInt32(DataType dst_dtype) {
+ CURRY_TYPES3(CAST_CASE, GPUDevice, int32);
+ return nullptr;
+}
+#endif // GOOGLE_CUDA
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cast_op_impl_int64.cc b/tensorflow/core/kernels/cast_op_impl_int64.cc
new file mode 100644
index 0000000000..b5571b19a5
--- /dev/null
+++ b/tensorflow/core/kernels/cast_op_impl_int64.cc
@@ -0,0 +1,37 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/cast_op_impl.h"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetCpuCastFromInt64(DataType dst_dtype) {
+ CURRY_TYPES3(CAST_CASE, CPUDevice, int64);
+ return nullptr;
+}
+
+#if GOOGLE_CUDA
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetGpuCastFromInt64(DataType dst_dtype) {
+ CURRY_TYPES3(CAST_CASE, GPUDevice, int64);
+ return nullptr;
+}
+#endif // GOOGLE_CUDA
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cast_op_impl_int8.cc b/tensorflow/core/kernels/cast_op_impl_int8.cc
new file mode 100644
index 0000000000..62971fa95c
--- /dev/null
+++ b/tensorflow/core/kernels/cast_op_impl_int8.cc
@@ -0,0 +1,37 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/cast_op_impl.h"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetCpuCastFromInt8(DataType dst_dtype) {
+ CURRY_TYPES3(CAST_CASE, CPUDevice, int8);
+ return nullptr;
+}
+
+#if GOOGLE_CUDA
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetGpuCastFromInt8(DataType dst_dtype) {
+ CURRY_TYPES3(CAST_CASE, GPUDevice, int8);
+ return nullptr;
+}
+#endif // GOOGLE_CUDA
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cast_op_impl_uint16.cc b/tensorflow/core/kernels/cast_op_impl_uint16.cc
new file mode 100644
index 0000000000..529d9758f0
--- /dev/null
+++ b/tensorflow/core/kernels/cast_op_impl_uint16.cc
@@ -0,0 +1,37 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/cast_op_impl.h"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetCpuCastFromUint16(DataType dst_dtype) {
+ CURRY_TYPES3(CAST_CASE, CPUDevice, uint16);
+ return nullptr;
+}
+
+#if GOOGLE_CUDA
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetGpuCastFromUint16(DataType dst_dtype) {
+ CURRY_TYPES3(CAST_CASE, GPUDevice, uint16);
+ return nullptr;
+}
+#endif // GOOGLE_CUDA
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cast_op_impl_uint8.cc b/tensorflow/core/kernels/cast_op_impl_uint8.cc
new file mode 100644
index 0000000000..1a5025b054
--- /dev/null
+++ b/tensorflow/core/kernels/cast_op_impl_uint8.cc
@@ -0,0 +1,37 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/cast_op_impl.h"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetCpuCastFromUint8(DataType dst_dtype) {
+ CURRY_TYPES3(CAST_CASE, CPUDevice, uint8);
+ return nullptr;
+}
+
+#if GOOGLE_CUDA
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetGpuCastFromUint8(DataType dst_dtype) {
+ CURRY_TYPES3(CAST_CASE, GPUDevice, uint8);
+ return nullptr;
+}
+#endif // GOOGLE_CUDA
+
+} // namespace tensorflow