diff options
author | 2016-08-15 13:05:12 -0800 | |
---|---|---|
committer | 2016-08-15 14:17:19 -0700 | |
commit | 6da951a4e0b8744b331a91ac2c398efcc4af6bed (patch) | |
tree | 70c76a1c2ea6085ef3137fd7866f3749a869b8ae /tensorflow/core | |
parent | 25236c29a246d7eebaf0be37fa38ad2782699b21 (diff) |
Shard cast_op implementations to improve build times.
Change: 130321753
Diffstat (limited to 'tensorflow/core')
-rw-r--r-- | tensorflow/core/kernels/BUILD | 14 | ||||
-rw-r--r-- | tensorflow/core/kernels/cast_op.cc | 160 | ||||
-rw-r--r-- | tensorflow/core/kernels/cast_op_impl.h | 145 | ||||
-rw-r--r-- | tensorflow/core/kernels/cast_op_impl_bfloat.cc | 55 | ||||
-rw-r--r-- | tensorflow/core/kernels/cast_op_impl_bool.cc | 37 | ||||
-rw-r--r-- | tensorflow/core/kernels/cast_op_impl_complex128.cc | 37 | ||||
-rw-r--r-- | tensorflow/core/kernels/cast_op_impl_complex64.cc | 37 | ||||
-rw-r--r-- | tensorflow/core/kernels/cast_op_impl_double.cc | 37 | ||||
-rw-r--r-- | tensorflow/core/kernels/cast_op_impl_float.cc | 52 | ||||
-rw-r--r-- | tensorflow/core/kernels/cast_op_impl_half.cc | 37 | ||||
-rw-r--r-- | tensorflow/core/kernels/cast_op_impl_int16.cc | 37 | ||||
-rw-r--r-- | tensorflow/core/kernels/cast_op_impl_int32.cc | 37 | ||||
-rw-r--r-- | tensorflow/core/kernels/cast_op_impl_int64.cc | 37 | ||||
-rw-r--r-- | tensorflow/core/kernels/cast_op_impl_int8.cc | 37 | ||||
-rw-r--r-- | tensorflow/core/kernels/cast_op_impl_uint16.cc | 37 | ||||
-rw-r--r-- | tensorflow/core/kernels/cast_op_impl_uint8.cc | 37 |
16 files changed, 738 insertions, 95 deletions
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index a61c2d8783..7f8b2b439d 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -1830,6 +1830,20 @@ filegroup( "bounds_check.h", "cast_op.cc", "cast_op.h", + "cast_op_impl.h", + "cast_op_impl_bfloat.cc", + "cast_op_impl_bool.cc", + "cast_op_impl_complex128.cc", + "cast_op_impl_complex64.cc", + "cast_op_impl_double.cc", + "cast_op_impl_float.cc", + "cast_op_impl_half.cc", + "cast_op_impl_int16.cc", + "cast_op_impl_int32.cc", + "cast_op_impl_int64.cc", + "cast_op_impl_int8.cc", + "cast_op_impl_uint16.cc", + "cast_op_impl_uint8.cc", "concat_lib.h", "concat_lib_cpu.cc", "concat_lib_cpu.h", diff --git a/tensorflow/core/kernels/cast_op.cc b/tensorflow/core/kernels/cast_op.cc index c08995b464..ab82c247d6 100644 --- a/tensorflow/core/kernels/cast_op.cc +++ b/tensorflow/core/kernels/cast_op.cc @@ -28,23 +28,13 @@ limitations under the License. #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/work_sharder.h" +#include "tensorflow/core/kernels/cast_op_impl.h" + namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; -namespace functor { - -template <typename O, typename I> -struct CastFunctor<CPUDevice, O, I> { - void operator()(const CPUDevice& d, typename TTypes<O>::Flat o, - typename TTypes<I>::ConstFlat i) { - o.device(d) = i.template cast<O>(); - } -}; - -} // namespace functor - #define CURRY_TYPES2(FN, arg0) \ FN(arg0, bool); \ FN(arg0, uint8); \ @@ -59,30 +49,6 @@ struct CastFunctor<CPUDevice, O, I> { FN(arg0, std::complex<float>); \ FN(arg0, std::complex<double>) -#define CURRY_TYPES3(FN, arg0, arg1) \ - FN(arg0, arg1, bool); \ - FN(arg0, arg1, uint8); \ - FN(arg0, arg1, int8); \ - FN(arg0, arg1, uint16); \ - FN(arg0, arg1, int16); \ - FN(arg0, arg1, int32); \ - FN(arg0, arg1, int64); \ - FN(arg0, arg1, Eigen::half); \ - FN(arg0, arg1, float); \ - FN(arg0, arg1, double); \ - FN(arg0, arg1, 
std::complex<float>); \ - FN(arg0, arg1, std::complex<double>) - -#define CAST_CASE(DEVICE, IN, OUT) \ - if (DataTypeToEnum<IN>::value == src_dtype_ && \ - DataTypeToEnum<OUT>::value == dst_dtype_) { \ - work_ = [](OpKernelContext* ctx, const Tensor& inp, Tensor* out) { \ - functor::CastFunctor<DEVICE, OUT, IN> func; \ - func(ctx->eigen_device<DEVICE>(), out->flat<OUT>(), inp.flat<IN>()); \ - }; \ - return Status::OK(); \ - } - class CastOpBase : public OpKernel { public: explicit CastOpBase(OpKernelConstruction* ctx) : OpKernel(ctx) { @@ -106,7 +72,6 @@ class CastOpBase : public OpKernel { DataType dst_dtype_; std::function<void(OpKernelContext*, const Tensor&, Tensor*)> work_ = nullptr; - virtual Status Prepare() = 0; Status Unimplemented() { return errors::Unimplemented("Cast ", DataTypeString(src_dtype_), " to ", DataTypeString(dst_dtype_), @@ -122,88 +87,94 @@ class CpuCastOp : public CastOpBase { OP_REQUIRES_OK(ctx, Prepare()); } - protected: - Status Prepare() override { + private: + Status Prepare() { if (src_dtype_ == dst_dtype_) { work_ = nullptr; // Identity return Status::OK(); } - CURRY_TYPES3(CAST_CASE, CPUDevice, bool); - CURRY_TYPES3(CAST_CASE, CPUDevice, uint8); - CURRY_TYPES3(CAST_CASE, CPUDevice, int8); - CURRY_TYPES3(CAST_CASE, CPUDevice, uint16); - CURRY_TYPES3(CAST_CASE, CPUDevice, int16); - CURRY_TYPES3(CAST_CASE, CPUDevice, int32); - CURRY_TYPES3(CAST_CASE, CPUDevice, int64); - CURRY_TYPES3(CAST_CASE, CPUDevice, Eigen::half); - CURRY_TYPES3(CAST_CASE, CPUDevice, float); - CURRY_TYPES3(CAST_CASE, CPUDevice, double); - CURRY_TYPES3(CAST_CASE, CPUDevice, std::complex<float>); - CURRY_TYPES3(CAST_CASE, CPUDevice, std::complex<double>); - - if (src_dtype_ == DT_BFLOAT16 && dst_dtype_ == DT_FLOAT) { - work_ = [](OpKernelContext* ctx, const Tensor& inp, Tensor* out) { - int64 N = out->NumElements(); - auto worker_threads = ctx->device()->tensorflow_cpu_worker_threads(); - auto work = [&inp, &out](int64 start, int64 end) { - 
BFloat16ToFloat(inp.flat<bfloat16>().data() + start, - out->flat<float>().data() + start, end - start); - }; - Shard(worker_threads->num_threads, worker_threads->workers, N, 2, work); - }; - return Status::OK(); - } - if (src_dtype_ == DT_FLOAT && dst_dtype_ == DT_BFLOAT16) { - work_ = [](OpKernelContext* ctx, const Tensor& inp, Tensor* out) { - int64 N = out->NumElements(); - auto worker_threads = ctx->device()->tensorflow_cpu_worker_threads(); - auto work = [&inp, &out](int64 start, int64 end) { - FloatToBFloat16(inp.flat<float>().data() + start, - out->flat<bfloat16>().data() + start, end - start); - }; - Shard(worker_threads->num_threads, worker_threads->workers, N, 2, work); - }; - return Status::OK(); + if (src_dtype_ == DT_BOOL) { + work_ = GetCpuCastFromBool(dst_dtype_); + } else if (src_dtype_ == DT_UINT8) { + work_ = GetCpuCastFromUint8(dst_dtype_); + } else if (src_dtype_ == DT_INT8) { + work_ = GetCpuCastFromInt8(dst_dtype_); + } else if (src_dtype_ == DT_UINT16) { + work_ = GetCpuCastFromUint16(dst_dtype_); + } else if (src_dtype_ == DT_INT16) { + work_ = GetCpuCastFromInt16(dst_dtype_); + } else if (src_dtype_ == DT_INT32) { + work_ = GetCpuCastFromInt32(dst_dtype_); + } else if (src_dtype_ == DT_INT64) { + work_ = GetCpuCastFromInt64(dst_dtype_); + } else if (src_dtype_ == DT_HALF) { + work_ = GetCpuCastFromHalf(dst_dtype_); + } else if (src_dtype_ == DT_FLOAT) { + work_ = GetCpuCastFromFloat(dst_dtype_); + } else if (src_dtype_ == DT_DOUBLE) { + work_ = GetCpuCastFromDouble(dst_dtype_); + } else if (src_dtype_ == DT_COMPLEX64) { + work_ = GetCpuCastFromComplex64(dst_dtype_); + } else if (src_dtype_ == DT_COMPLEX128) { + work_ = GetCpuCastFromComplex128(dst_dtype_); + } else if (src_dtype_ == DT_BFLOAT16) { + work_ = GetCpuCastFromBfloat(dst_dtype_); } // TODO(sesse): If CPU casting to or from Eigen::half ever becomes a // bottleneck, we could probably implement specialized support for // vectorized versions (not the least based on F16C for Haswell - 
// or newer) here. + // or newer). - return Unimplemented(); + return work_ == nullptr ? Unimplemented() : Status::OK(); } }; +#if GOOGLE_CUDA class GpuCastOp : public CastOpBase { public: explicit GpuCastOp(OpKernelConstruction* ctx) : CastOpBase(ctx) { OP_REQUIRES_OK(ctx, Prepare()); } - protected: - Status Prepare() override { + private: + Status Prepare() { if (src_dtype_ == dst_dtype_) { work_ = nullptr; // Identity return Status::OK(); } - CURRY_TYPES3(CAST_CASE, GPUDevice, bool); - CURRY_TYPES3(CAST_CASE, GPUDevice, uint8); - CURRY_TYPES3(CAST_CASE, GPUDevice, int8); - CURRY_TYPES3(CAST_CASE, GPUDevice, uint16); - CURRY_TYPES3(CAST_CASE, GPUDevice, int16); - CURRY_TYPES3(CAST_CASE, GPUDevice, int32); - CURRY_TYPES3(CAST_CASE, GPUDevice, int64); - CURRY_TYPES3(CAST_CASE, GPUDevice, Eigen::half); - CURRY_TYPES3(CAST_CASE, GPUDevice, float); - CURRY_TYPES3(CAST_CASE, GPUDevice, double); - CURRY_TYPES3(CAST_CASE, GPUDevice, std::complex<float>); - CURRY_TYPES3(CAST_CASE, GPUDevice, std::complex<double>); - CAST_CASE(GPUDevice, float, bfloat16); - CAST_CASE(GPUDevice, bfloat16, float); - return Unimplemented(); + if (src_dtype_ == DT_BOOL) { + work_ = GetGpuCastFromBool(dst_dtype_); + } else if (src_dtype_ == DT_UINT8) { + work_ = GetGpuCastFromUint8(dst_dtype_); + } else if (src_dtype_ == DT_INT8) { + work_ = GetGpuCastFromInt8(dst_dtype_); + } else if (src_dtype_ == DT_UINT16) { + work_ = GetGpuCastFromUint16(dst_dtype_); + } else if (src_dtype_ == DT_INT16) { + work_ = GetGpuCastFromInt16(dst_dtype_); + } else if (src_dtype_ == DT_INT32) { + work_ = GetGpuCastFromInt32(dst_dtype_); + } else if (src_dtype_ == DT_INT64) { + work_ = GetGpuCastFromInt64(dst_dtype_); + } else if (src_dtype_ == DT_HALF) { + work_ = GetGpuCastFromHalf(dst_dtype_); + } else if (src_dtype_ == DT_FLOAT) { + work_ = GetGpuCastFromFloat(dst_dtype_); + } else if (src_dtype_ == DT_DOUBLE) { + work_ = GetGpuCastFromDouble(dst_dtype_); + } else if (src_dtype_ == DT_COMPLEX64) { + work_ = 
GetGpuCastFromComplex64(dst_dtype_); + } else if (src_dtype_ == DT_COMPLEX128) { + work_ = GetGpuCastFromComplex128(dst_dtype_); + } else if (src_dtype_ == DT_BFLOAT16) { + work_ = GetGpuCastFromBfloat(dst_dtype_); + } + + return work_ == nullptr ? Unimplemented() : Status::OK(); } }; +#endif // GOOGLE_CUDA #undef CAST_CASE @@ -236,7 +207,6 @@ REGISTER_CAST_GPU(bfloat16, float); #endif // GOOGLE_CUDA #undef CURRY_TYPES2 -#undef CURRY_TYPES3 // HostCast differs from Cast in that its input and output are in host memory. REGISTER_KERNEL_BUILDER(Name("_HostCast").Device(DEVICE_CPU), CpuCastOp); diff --git a/tensorflow/core/kernels/cast_op_impl.h b/tensorflow/core/kernels/cast_op_impl.h new file mode 100644 index 0000000000..cb7cc81937 --- /dev/null +++ b/tensorflow/core/kernels/cast_op_impl.h @@ -0,0 +1,145 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CAST_OP_IMPL_H_ +#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CAST_OP_IMPL_H_ + +#define EIGEN_USE_THREADS + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/kernels/cast_op.h" + +namespace tensorflow { + +namespace functor { + +template <typename O, typename I> +struct CastFunctor<Eigen::ThreadPoolDevice, O, I> { + void operator()(const Eigen::ThreadPoolDevice& d, typename TTypes<O>::Flat o, + typename TTypes<I>::ConstFlat i) { + o.device(d) = i.template cast<O>(); + } +}; + +} // namespace functor + +#define CURRY_TYPES3(FN, arg0, arg1) \ + FN(arg0, arg1, bool); \ + FN(arg0, arg1, uint8); \ + FN(arg0, arg1, int8); \ + FN(arg0, arg1, uint16); \ + FN(arg0, arg1, int16); \ + FN(arg0, arg1, int32); \ + FN(arg0, arg1, int64); \ + FN(arg0, arg1, Eigen::half); \ + FN(arg0, arg1, float); \ + FN(arg0, arg1, double); \ + FN(arg0, arg1, std::complex<float>); \ + FN(arg0, arg1, std::complex<double>) + +#define CAST_CASE(DEVICE, IN, OUT) \ + if (DataTypeToEnum<OUT>::value == dst_dtype) { \ + return [](OpKernelContext* ctx, const Tensor& inp, Tensor* out) { \ + functor::CastFunctor<DEVICE, OUT, IN> func; \ + func(ctx->eigen_device<DEVICE>(), out->flat<OUT>(), inp.flat<IN>()); \ + }; \ + } + +// The functions below are implemented in the cast_op_impl_*.cc files. 
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)> +GetCpuCastFromBool(DataType dst_dtype); + +std::function<void(OpKernelContext*, const Tensor&, Tensor*)> +GetCpuCastFromUint8(DataType dst_dtype); + +std::function<void(OpKernelContext*, const Tensor&, Tensor*)> +GetCpuCastFromInt8(DataType dst_dtype); + +std::function<void(OpKernelContext*, const Tensor&, Tensor*)> +GetCpuCastFromUint16(DataType dst_dtype); + +std::function<void(OpKernelContext*, const Tensor&, Tensor*)> +GetCpuCastFromInt16(DataType dst_dtype); + +std::function<void(OpKernelContext*, const Tensor&, Tensor*)> +GetCpuCastFromInt32(DataType dst_dtype); + +std::function<void(OpKernelContext*, const Tensor&, Tensor*)> +GetCpuCastFromInt64(DataType dst_dtype); + +std::function<void(OpKernelContext*, const Tensor&, Tensor*)> +GetCpuCastFromHalf(DataType dst_dtype); + +std::function<void(OpKernelContext*, const Tensor&, Tensor*)> +GetCpuCastFromFloat(DataType dst_dtype); + +std::function<void(OpKernelContext*, const Tensor&, Tensor*)> +GetCpuCastFromDouble(DataType dst_dtype); + +std::function<void(OpKernelContext*, const Tensor&, Tensor*)> +GetCpuCastFromComplex64(DataType dst_dtype); + +std::function<void(OpKernelContext*, const Tensor&, Tensor*)> +GetCpuCastFromComplex128(DataType dst_dtype); + +std::function<void(OpKernelContext*, const Tensor&, Tensor*)> +GetCpuCastFromBfloat(DataType dst_dtype); + +#if GOOGLE_CUDA +// Same, for GPU. 
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)> +GetGpuCastFromBool(DataType dst_dtype); + +std::function<void(OpKernelContext*, const Tensor&, Tensor*)> +GetGpuCastFromUint8(DataType dst_dtype); + +std::function<void(OpKernelContext*, const Tensor&, Tensor*)> +GetGpuCastFromInt8(DataType dst_dtype); + +std::function<void(OpKernelContext*, const Tensor&, Tensor*)> +GetGpuCastFromUint16(DataType dst_dtype); + +std::function<void(OpKernelContext*, const Tensor&, Tensor*)> +GetGpuCastFromInt16(DataType dst_dtype); + +std::function<void(OpKernelContext*, const Tensor&, Tensor*)> +GetGpuCastFromInt32(DataType dst_dtype); + +std::function<void(OpKernelContext*, const Tensor&, Tensor*)> +GetGpuCastFromInt64(DataType dst_dtype); + +std::function<void(OpKernelContext*, const Tensor&, Tensor*)> +GetGpuCastFromHalf(DataType dst_dtype); + +std::function<void(OpKernelContext*, const Tensor&, Tensor*)> +GetGpuCastFromFloat(DataType dst_dtype); + +std::function<void(OpKernelContext*, const Tensor&, Tensor*)> +GetGpuCastFromDouble(DataType dst_dtype); + +std::function<void(OpKernelContext*, const Tensor&, Tensor*)> +GetGpuCastFromComplex64(DataType dst_dtype); + +std::function<void(OpKernelContext*, const Tensor&, Tensor*)> +GetGpuCastFromComplex128(DataType dst_dtype); + +std::function<void(OpKernelContext*, const Tensor&, Tensor*)> +GetGpuCastFromBfloat(DataType dst_dtype); + +#endif // GOOGLE_CUDA + +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CAST_OP_IMPL_H_ diff --git a/tensorflow/core/kernels/cast_op_impl_bfloat.cc b/tensorflow/core/kernels/cast_op_impl_bfloat.cc new file mode 100644 index 0000000000..a06f815899 --- /dev/null +++ b/tensorflow/core/kernels/cast_op_impl_bfloat.cc @@ -0,0 +1,55 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/kernels/cast_op_impl.h" + +#include "tensorflow/core/util/work_sharder.h" + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +std::function<void(OpKernelContext*, const Tensor&, Tensor*)> +GetCpuCastFromBfloat(DataType dst_dtype) { + if (dst_dtype == DT_FLOAT) { + return [](OpKernelContext* ctx, const Tensor& inp, Tensor* out) { + int64 N = out->NumElements(); + auto worker_threads = ctx->device()->tensorflow_cpu_worker_threads(); + auto work = [&inp, &out](int64 start, int64 end) { + BFloat16ToFloat(inp.flat<bfloat16>().data() + start, + out->flat<float>().data() + start, end - start); + }; + Shard(worker_threads->num_threads, worker_threads->workers, N, 2, work); + }; + } + return nullptr; +} + +#if GOOGLE_CUDA +std::function<void(OpKernelContext*, const Tensor&, Tensor*)> +GetGpuCastFromBfloat(DataType dst_dtype) { + if (dst_dtype == DT_FLOAT) { + return [](OpKernelContext* ctx, const Tensor& inp, Tensor* out) { + functor::CastFunctor<GPUDevice, float, bfloat16> func; + func(ctx->eigen_device<GPUDevice>(), out->flat<float>(), + inp.flat<bfloat16>()); + }; + } + return nullptr; +} +#endif // GOOGLE_CUDA + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/cast_op_impl_bool.cc b/tensorflow/core/kernels/cast_op_impl_bool.cc new file mode 100644 index 0000000000..92fee89a47 --- /dev/null +++ b/tensorflow/core/kernels/cast_op_impl_bool.cc @@ -0,0 +1,37 @@ +/* Copyright 2016 The TensorFlow 
Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/kernels/cast_op_impl.h" + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +std::function<void(OpKernelContext*, const Tensor&, Tensor*)> +GetCpuCastFromBool(DataType dst_dtype) { + CURRY_TYPES3(CAST_CASE, CPUDevice, bool); + return nullptr; +} + +#if GOOGLE_CUDA +std::function<void(OpKernelContext*, const Tensor&, Tensor*)> +GetGpuCastFromBool(DataType dst_dtype) { + CURRY_TYPES3(CAST_CASE, GPUDevice, bool); + return nullptr; +} +#endif // GOOGLE_CUDA + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/cast_op_impl_complex128.cc b/tensorflow/core/kernels/cast_op_impl_complex128.cc new file mode 100644 index 0000000000..c428679d7c --- /dev/null +++ b/tensorflow/core/kernels/cast_op_impl_complex128.cc @@ -0,0 +1,37 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/kernels/cast_op_impl.h" + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +std::function<void(OpKernelContext*, const Tensor&, Tensor*)> +GetCpuCastFromComplex128(DataType dst_dtype) { + CURRY_TYPES3(CAST_CASE, CPUDevice, std::complex<double>); + return nullptr; +} + +#if GOOGLE_CUDA +std::function<void(OpKernelContext*, const Tensor&, Tensor*)> +GetGpuCastFromComplex128(DataType dst_dtype) { + CURRY_TYPES3(CAST_CASE, GPUDevice, std::complex<double>); + return nullptr; +} +#endif // GOOGLE_CUDA + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/cast_op_impl_complex64.cc b/tensorflow/core/kernels/cast_op_impl_complex64.cc new file mode 100644 index 0000000000..07b46551b2 --- /dev/null +++ b/tensorflow/core/kernels/cast_op_impl_complex64.cc @@ -0,0 +1,37 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/kernels/cast_op_impl.h" + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +std::function<void(OpKernelContext*, const Tensor&, Tensor*)> +GetCpuCastFromComplex64(DataType dst_dtype) { + CURRY_TYPES3(CAST_CASE, CPUDevice, std::complex<float>); + return nullptr; +} + +#if GOOGLE_CUDA +std::function<void(OpKernelContext*, const Tensor&, Tensor*)> +GetGpuCastFromComplex64(DataType dst_dtype) { + CURRY_TYPES3(CAST_CASE, GPUDevice, std::complex<float>); + return nullptr; +} +#endif // GOOGLE_CUDA + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/cast_op_impl_double.cc b/tensorflow/core/kernels/cast_op_impl_double.cc new file mode 100644 index 0000000000..fd20061d21 --- /dev/null +++ b/tensorflow/core/kernels/cast_op_impl_double.cc @@ -0,0 +1,37 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/kernels/cast_op_impl.h" + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +std::function<void(OpKernelContext*, const Tensor&, Tensor*)> +GetCpuCastFromDouble(DataType dst_dtype) { + CURRY_TYPES3(CAST_CASE, CPUDevice, double); + return nullptr; +} + +#if GOOGLE_CUDA +std::function<void(OpKernelContext*, const Tensor&, Tensor*)> +GetGpuCastFromDouble(DataType dst_dtype) { + CURRY_TYPES3(CAST_CASE, GPUDevice, double); + return nullptr; +} +#endif // GOOGLE_CUDA + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/cast_op_impl_float.cc b/tensorflow/core/kernels/cast_op_impl_float.cc new file mode 100644 index 0000000000..71e63fbff0 --- /dev/null +++ b/tensorflow/core/kernels/cast_op_impl_float.cc @@ -0,0 +1,52 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/kernels/cast_op_impl.h" + +#include "tensorflow/core/util/work_sharder.h" + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +std::function<void(OpKernelContext*, const Tensor&, Tensor*)> +GetCpuCastFromFloat(DataType dst_dtype) { + CURRY_TYPES3(CAST_CASE, CPUDevice, float); + if (dst_dtype == DT_BFLOAT16) { + return [](OpKernelContext* ctx, const Tensor& inp, Tensor* out) { + int64 N = out->NumElements(); + auto worker_threads = ctx->device()->tensorflow_cpu_worker_threads(); + auto work = [&inp, &out](int64 start, int64 end) { + FloatToBFloat16(inp.flat<float>().data() + start, + out->flat<bfloat16>().data() + start, end - start); + }; + Shard(worker_threads->num_threads, worker_threads->workers, N, 2, work); + }; + } + + return nullptr; +} + +#if GOOGLE_CUDA +std::function<void(OpKernelContext*, const Tensor&, Tensor*)> +GetGpuCastFromFloat(DataType dst_dtype) { + CURRY_TYPES3(CAST_CASE, GPUDevice, float); + CAST_CASE(GPUDevice, float, bfloat16); + return nullptr; +} +#endif // GOOGLE_CUDA + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/cast_op_impl_half.cc b/tensorflow/core/kernels/cast_op_impl_half.cc new file mode 100644 index 0000000000..e89d4646d7 --- /dev/null +++ b/tensorflow/core/kernels/cast_op_impl_half.cc @@ -0,0 +1,37 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/kernels/cast_op_impl.h" + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +std::function<void(OpKernelContext*, const Tensor&, Tensor*)> +GetCpuCastFromHalf(DataType dst_dtype) { + CURRY_TYPES3(CAST_CASE, CPUDevice, Eigen::half); + return nullptr; +} + +#if GOOGLE_CUDA +std::function<void(OpKernelContext*, const Tensor&, Tensor*)> +GetGpuCastFromHalf(DataType dst_dtype) { + CURRY_TYPES3(CAST_CASE, GPUDevice, Eigen::half); + return nullptr; +} +#endif // GOOGLE_CUDA + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/cast_op_impl_int16.cc b/tensorflow/core/kernels/cast_op_impl_int16.cc new file mode 100644 index 0000000000..3c2d6185e3 --- /dev/null +++ b/tensorflow/core/kernels/cast_op_impl_int16.cc @@ -0,0 +1,37 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/kernels/cast_op_impl.h" + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +std::function<void(OpKernelContext*, const Tensor&, Tensor*)> +GetCpuCastFromInt16(DataType dst_dtype) { + CURRY_TYPES3(CAST_CASE, CPUDevice, int16); + return nullptr; +} + +#if GOOGLE_CUDA +std::function<void(OpKernelContext*, const Tensor&, Tensor*)> +GetGpuCastFromInt16(DataType dst_dtype) { + CURRY_TYPES3(CAST_CASE, GPUDevice, int16); + return nullptr; +} +#endif // GOOGLE_CUDA + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/cast_op_impl_int32.cc b/tensorflow/core/kernels/cast_op_impl_int32.cc new file mode 100644 index 0000000000..0fc6e16afe --- /dev/null +++ b/tensorflow/core/kernels/cast_op_impl_int32.cc @@ -0,0 +1,37 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/kernels/cast_op_impl.h" + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +std::function<void(OpKernelContext*, const Tensor&, Tensor*)> +GetCpuCastFromInt32(DataType dst_dtype) { + CURRY_TYPES3(CAST_CASE, CPUDevice, int32); + return nullptr; +} + +#if GOOGLE_CUDA +std::function<void(OpKernelContext*, const Tensor&, Tensor*)> +GetGpuCastFromInt32(DataType dst_dtype) { + CURRY_TYPES3(CAST_CASE, GPUDevice, int32); + return nullptr; +} +#endif // GOOGLE_CUDA + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/cast_op_impl_int64.cc b/tensorflow/core/kernels/cast_op_impl_int64.cc new file mode 100644 index 0000000000..b5571b19a5 --- /dev/null +++ b/tensorflow/core/kernels/cast_op_impl_int64.cc @@ -0,0 +1,37 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/kernels/cast_op_impl.h" + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +std::function<void(OpKernelContext*, const Tensor&, Tensor*)> +GetCpuCastFromInt64(DataType dst_dtype) { + CURRY_TYPES3(CAST_CASE, CPUDevice, int64); + return nullptr; +} + +#if GOOGLE_CUDA +std::function<void(OpKernelContext*, const Tensor&, Tensor*)> +GetGpuCastFromInt64(DataType dst_dtype) { + CURRY_TYPES3(CAST_CASE, GPUDevice, int64); + return nullptr; +} +#endif // GOOGLE_CUDA + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/cast_op_impl_int8.cc b/tensorflow/core/kernels/cast_op_impl_int8.cc new file mode 100644 index 0000000000..62971fa95c --- /dev/null +++ b/tensorflow/core/kernels/cast_op_impl_int8.cc @@ -0,0 +1,37 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/kernels/cast_op_impl.h" + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +std::function<void(OpKernelContext*, const Tensor&, Tensor*)> +GetCpuCastFromInt8(DataType dst_dtype) { + CURRY_TYPES3(CAST_CASE, CPUDevice, int8); + return nullptr; +} + +#if GOOGLE_CUDA +std::function<void(OpKernelContext*, const Tensor&, Tensor*)> +GetGpuCastFromInt8(DataType dst_dtype) { + CURRY_TYPES3(CAST_CASE, GPUDevice, int8); + return nullptr; +} +#endif // GOOGLE_CUDA + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/cast_op_impl_uint16.cc b/tensorflow/core/kernels/cast_op_impl_uint16.cc new file mode 100644 index 0000000000..529d9758f0 --- /dev/null +++ b/tensorflow/core/kernels/cast_op_impl_uint16.cc @@ -0,0 +1,37 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/kernels/cast_op_impl.h" + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +std::function<void(OpKernelContext*, const Tensor&, Tensor*)> +GetCpuCastFromUint16(DataType dst_dtype) { + CURRY_TYPES3(CAST_CASE, CPUDevice, uint16); + return nullptr; +} + +#if GOOGLE_CUDA +std::function<void(OpKernelContext*, const Tensor&, Tensor*)> +GetGpuCastFromUint16(DataType dst_dtype) { + CURRY_TYPES3(CAST_CASE, GPUDevice, uint16); + return nullptr; +} +#endif // GOOGLE_CUDA + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/cast_op_impl_uint8.cc b/tensorflow/core/kernels/cast_op_impl_uint8.cc new file mode 100644 index 0000000000..1a5025b054 --- /dev/null +++ b/tensorflow/core/kernels/cast_op_impl_uint8.cc @@ -0,0 +1,37 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/kernels/cast_op_impl.h" + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +std::function<void(OpKernelContext*, const Tensor&, Tensor*)> +GetCpuCastFromUint8(DataType dst_dtype) { + CURRY_TYPES3(CAST_CASE, CPUDevice, uint8); + return nullptr; +} + +#if GOOGLE_CUDA +std::function<void(OpKernelContext*, const Tensor&, Tensor*)> +GetGpuCastFromUint8(DataType dst_dtype) { + CURRY_TYPES3(CAST_CASE, GPUDevice, uint8); + return nullptr; +} +#endif // GOOGLE_CUDA + +} // namespace tensorflow |