diff options
author | 2017-09-26 15:42:32 -0700 | |
---|---|---|
committer | 2017-09-26 15:47:02 -0700 | |
commit | 725206e677a9f1e343319293a347862335ff776b (patch) | |
tree | ccf85489f300bc57509ab840939cecde19217f56 | |
parent | 122ad249a8928a5136d4fd48d75be85f154a8c4c (diff) |
[TF:XLA] Register the _HostCast operator on XlaDevice subclasses.
Declare CpuCastOp and CastOpBase in the cast_op.h header so they can be used from XlaDevice.
PiperOrigin-RevId: 170121111
-rw-r--r-- | tensorflow/compiler/jit/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/compiler/jit/xla_device_ops.h | 4 | ||||
-rw-r--r-- | tensorflow/core/kernels/cast_op.cc | 129 | ||||
-rw-r--r-- | tensorflow/core/kernels/cast_op.h | 29 |
4 files changed, 91 insertions(+), 72 deletions(-)
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index e366db248a..13bebf43bc 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -154,6 +154,7 @@ cc_library( "//tensorflow/core:protos_all_cc", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:tensorflow_opensource", + "//tensorflow/core/kernels:cast_op", "//tensorflow/core/kernels:constant_op", "//tensorflow/core/kernels:control_flow_ops", "//tensorflow/core/kernels:identity_op", diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h index 8699006ebc..498d25cf56 100644 --- a/tensorflow/compiler/jit/xla_device_ops.h +++ b/tensorflow/compiler/jit/xla_device_ops.h @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/kernels/cast_op.h" #include "tensorflow/core/kernels/constant_op.h" #include "tensorflow/core/kernels/control_flow_ops.h" #include "tensorflow/core/kernels/identity_op.h" @@ -53,6 +54,9 @@ class XlaDeviceDummyOp : public OpKernel { Name("_HostSend").Device(DEVICE).HostMemory("tensor"), SendOp); \ REGISTER_KERNEL_BUILDER( \ Name("_HostRecv").Device(DEVICE).HostMemory("tensor"), RecvOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("_HostCast").Device(DEVICE).HostMemory("x").HostMemory("y"), \ + CpuCastOp); \ REGISTER_KERNEL_BUILDER(Name("NoOp").Device(DEVICE), NoOp); \ REGISTER_KERNEL_BUILDER( \ Name("Const").Device(DEVICE).TypeConstraint("dtype", TYPES), \ diff --git a/tensorflow/core/kernels/cast_op.cc b/tensorflow/core/kernels/cast_op.cc index 8bad488482..f16abb2b79 100644 --- a/tensorflow/core/kernels/cast_op.cc +++ b/tensorflow/core/kernels/cast_op.cc @@ -52,86 +52,71 @@ typedef Eigen::SyclDevice SYCLDevice; FN(arg0, std::complex<float>); \ FN(arg0, std::complex<double>) -class CastOpBase : public OpKernel { - public: - explicit CastOpBase(OpKernelConstruction* ctx) : 
OpKernel(ctx) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("SrcT", &src_dtype_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("DstT", &dst_dtype_)); +CastOpBase::CastOpBase(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("SrcT", &src_dtype_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("DstT", &dst_dtype_)); +} + +void CastOpBase::Compute(OpKernelContext* ctx) { + const Tensor& inp = ctx->input(0); + if (work_ == nullptr) { + ctx->set_output(0, inp); + } else { + Tensor* out = nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, inp.shape(), &out)); + work_(ctx, inp, out); } +} - void Compute(OpKernelContext* ctx) override { - const Tensor& inp = ctx->input(0); - if (work_ == nullptr) { - ctx->set_output(0, inp); - } else { - Tensor* out = nullptr; - OP_REQUIRES_OK(ctx, ctx->allocate_output(0, inp.shape(), &out)); - work_(ctx, inp, out); - } - } +Status CastOpBase::Unimplemented() { + return errors::Unimplemented("Cast ", DataTypeString(src_dtype_), " to ", + DataTypeString(dst_dtype_), " is not supported"); +} - protected: - DataType src_dtype_; - DataType dst_dtype_; - std::function<void(OpKernelContext*, const Tensor&, Tensor*)> work_ = nullptr; +CpuCastOp::CpuCastOp(OpKernelConstruction* ctx) : CastOpBase(ctx) { + OP_REQUIRES_OK(ctx, Prepare()); +} - Status Unimplemented() { - return errors::Unimplemented("Cast ", DataTypeString(src_dtype_), " to ", - DataTypeString(dst_dtype_), - " is not supported"); +Status CpuCastOp::Prepare() { + if (src_dtype_ == dst_dtype_) { + work_ = nullptr; // Identity + return Status::OK(); } - - TF_DISALLOW_COPY_AND_ASSIGN(CastOpBase); -}; - -class CpuCastOp : public CastOpBase { - public: - explicit CpuCastOp(OpKernelConstruction* ctx) : CastOpBase(ctx) { - OP_REQUIRES_OK(ctx, Prepare()); + if (src_dtype_ == DT_BOOL) { + work_ = GetCpuCastFromBool(dst_dtype_); + } else if (src_dtype_ == DT_UINT8) { + work_ = GetCpuCastFromUint8(dst_dtype_); + } else if (src_dtype_ == DT_INT8) { + work_ = GetCpuCastFromInt8(dst_dtype_); + 
} else if (src_dtype_ == DT_UINT16) { + work_ = GetCpuCastFromUint16(dst_dtype_); + } else if (src_dtype_ == DT_INT16) { + work_ = GetCpuCastFromInt16(dst_dtype_); + } else if (src_dtype_ == DT_INT32) { + work_ = GetCpuCastFromInt32(dst_dtype_); + } else if (src_dtype_ == DT_INT64) { + work_ = GetCpuCastFromInt64(dst_dtype_); + } else if (src_dtype_ == DT_HALF) { + work_ = GetCpuCastFromHalf(dst_dtype_); + } else if (src_dtype_ == DT_FLOAT) { + work_ = GetCpuCastFromFloat(dst_dtype_); + } else if (src_dtype_ == DT_DOUBLE) { + work_ = GetCpuCastFromDouble(dst_dtype_); + } else if (src_dtype_ == DT_COMPLEX64) { + work_ = GetCpuCastFromComplex64(dst_dtype_); + } else if (src_dtype_ == DT_COMPLEX128) { + work_ = GetCpuCastFromComplex128(dst_dtype_); + } else if (src_dtype_ == DT_BFLOAT16) { + work_ = GetCpuCastFromBfloat(dst_dtype_); } - private: - Status Prepare() { - if (src_dtype_ == dst_dtype_) { - work_ = nullptr; // Identity - return Status::OK(); - } - if (src_dtype_ == DT_BOOL) { - work_ = GetCpuCastFromBool(dst_dtype_); - } else if (src_dtype_ == DT_UINT8) { - work_ = GetCpuCastFromUint8(dst_dtype_); - } else if (src_dtype_ == DT_INT8) { - work_ = GetCpuCastFromInt8(dst_dtype_); - } else if (src_dtype_ == DT_UINT16) { - work_ = GetCpuCastFromUint16(dst_dtype_); - } else if (src_dtype_ == DT_INT16) { - work_ = GetCpuCastFromInt16(dst_dtype_); - } else if (src_dtype_ == DT_INT32) { - work_ = GetCpuCastFromInt32(dst_dtype_); - } else if (src_dtype_ == DT_INT64) { - work_ = GetCpuCastFromInt64(dst_dtype_); - } else if (src_dtype_ == DT_HALF) { - work_ = GetCpuCastFromHalf(dst_dtype_); - } else if (src_dtype_ == DT_FLOAT) { - work_ = GetCpuCastFromFloat(dst_dtype_); - } else if (src_dtype_ == DT_DOUBLE) { - work_ = GetCpuCastFromDouble(dst_dtype_); - } else if (src_dtype_ == DT_COMPLEX64) { - work_ = GetCpuCastFromComplex64(dst_dtype_); - } else if (src_dtype_ == DT_COMPLEX128) { - work_ = GetCpuCastFromComplex128(dst_dtype_); - } else if (src_dtype_ == 
DT_BFLOAT16) { - work_ = GetCpuCastFromBfloat(dst_dtype_); - } - - // TODO(sesse): If CPU casting to or from Eigen::half ever becomes a - // bottleneck, we could probably implement specialized support for - // vectorized versions (not the least based on F16C for Haswell - // or newer). + // TODO(sesse): If CPU casting to or from Eigen::half ever becomes a + // bottleneck, we could probably implement specialized support for + // vectorized versions (not the least based on F16C for Haswell + // or newer). - return work_ == nullptr ? Unimplemented() : Status::OK(); - } -}; + return work_ == nullptr ? Unimplemented() : Status::OK(); +} #if GOOGLE_CUDA class GpuCastOp : public CastOpBase { diff --git a/tensorflow/core/kernels/cast_op.h b/tensorflow/core/kernels/cast_op.h index 5c24f164a4..379b5b5e81 100644 --- a/tensorflow/core/kernels/cast_op.h +++ b/tensorflow/core/kernels/cast_op.h @@ -18,11 +18,40 @@ limitations under the License. #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/bfloat16.h" +#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/types.h" #include "tensorflow/core/platform/cpu_info.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { + +// Common base class of Cast kernels +class CastOpBase : public OpKernel { + public: + explicit CastOpBase(OpKernelConstruction* ctx); + + void Compute(OpKernelContext* ctx) override; + + protected: + DataType src_dtype_; + DataType dst_dtype_; + std::function<void(OpKernelContext*, const Tensor&, Tensor*)> work_ = nullptr; + + Status Unimplemented(); + + TF_DISALLOW_COPY_AND_ASSIGN(CastOpBase); +}; + +// CPU implementation of Cast +class CpuCastOp : public CastOpBase { + public: + explicit CpuCastOp(OpKernelConstruction* ctx); + + private: + Status Prepare(); +}; + namespace functor { template <typename Device, typename Tout, typename Tin> |