author    Peter Hawkins <phawkins@google.com>  2017-09-26 15:42:32 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2017-09-26 15:47:02 -0700
commit 725206e677a9f1e343319293a347862335ff776b (patch)
tree   ccf85489f300bc57509ab840939cecde19217f56
parent 122ad249a8928a5136d4fd48d75be85f154a8c4c (diff)
[TF:XLA] Register the _HostCast operator on XlaDevice subclasses.

Declare CpuCastOp and CastOpBase in the cast_op.h header so they can be
used from XlaDevice.

PiperOrigin-RevId: 170121111
-rw-r--r-- tensorflow/compiler/jit/BUILD            |   1
-rw-r--r-- tensorflow/compiler/jit/xla_device_ops.h |   4
-rw-r--r-- tensorflow/core/kernels/cast_op.cc       | 129
-rw-r--r-- tensorflow/core/kernels/cast_op.h        |  29
4 files changed, 91 insertions(+), 72 deletions(-)
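Editor's note, for context before the hunks: _HostCast is the host-memory
variant of Cast. Its op interface is sketched below from the SrcT/DstT
attributes that CastOpBase reads in cast_op.cc further down; treat this as
an inferred sketch, not the verbatim registration from tensorflow/core/ops.

// Sketch of the _HostCast op interface, inferred from the SrcT/DstT
// attributes CastOpBase reads in cast_op.cc below; not the verbatim
// op registration.
REGISTER_OP("_HostCast")
    .Input("x: SrcT")
    .Output("y: DstT")
    .Attr("SrcT: type")
    .Attr("DstT: type");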
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index e366db248a..13bebf43bc 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -154,6 +154,7 @@ cc_library(
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/core:tensorflow_opensource",
+ "//tensorflow/core/kernels:cast_op",
"//tensorflow/core/kernels:constant_op",
"//tensorflow/core/kernels:control_flow_ops",
"//tensorflow/core/kernels:identity_op",
diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h
index 8699006ebc..498d25cf56 100644
--- a/tensorflow/compiler/jit/xla_device_ops.h
+++ b/tensorflow/compiler/jit/xla_device_ops.h
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/kernels/cast_op.h"
#include "tensorflow/core/kernels/constant_op.h"
#include "tensorflow/core/kernels/control_flow_ops.h"
#include "tensorflow/core/kernels/identity_op.h"
@@ -53,6 +54,9 @@ class XlaDeviceDummyOp : public OpKernel {
Name("_HostSend").Device(DEVICE).HostMemory("tensor"), SendOp); \
REGISTER_KERNEL_BUILDER( \
Name("_HostRecv").Device(DEVICE).HostMemory("tensor"), RecvOp); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("_HostCast").Device(DEVICE).HostMemory("x").HostMemory("y"), \
+ CpuCastOp); \
REGISTER_KERNEL_BUILDER(Name("NoOp").Device(DEVICE), NoOp); \
REGISTER_KERNEL_BUILDER( \
Name("Const").Device(DEVICE).TypeConstraint("dtype", TYPES), \
diff --git a/tensorflow/core/kernels/cast_op.cc b/tensorflow/core/kernels/cast_op.cc
index 8bad488482..f16abb2b79 100644
--- a/tensorflow/core/kernels/cast_op.cc
+++ b/tensorflow/core/kernels/cast_op.cc
@@ -52,86 +52,71 @@ typedef Eigen::SyclDevice SYCLDevice;
FN(arg0, std::complex<float>); \
FN(arg0, std::complex<double>)
-class CastOpBase : public OpKernel {
- public:
- explicit CastOpBase(OpKernelConstruction* ctx) : OpKernel(ctx) {
- OP_REQUIRES_OK(ctx, ctx->GetAttr("SrcT", &src_dtype_));
- OP_REQUIRES_OK(ctx, ctx->GetAttr("DstT", &dst_dtype_));
+CastOpBase::CastOpBase(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("SrcT", &src_dtype_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("DstT", &dst_dtype_));
+}
+
+void CastOpBase::Compute(OpKernelContext* ctx) {
+ const Tensor& inp = ctx->input(0);
+ if (work_ == nullptr) {
+ ctx->set_output(0, inp);
+ } else {
+ Tensor* out = nullptr;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, inp.shape(), &out));
+ work_(ctx, inp, out);
}
+}
- void Compute(OpKernelContext* ctx) override {
- const Tensor& inp = ctx->input(0);
- if (work_ == nullptr) {
- ctx->set_output(0, inp);
- } else {
- Tensor* out = nullptr;
- OP_REQUIRES_OK(ctx, ctx->allocate_output(0, inp.shape(), &out));
- work_(ctx, inp, out);
- }
- }
+Status CastOpBase::Unimplemented() {
+ return errors::Unimplemented("Cast ", DataTypeString(src_dtype_), " to ",
+ DataTypeString(dst_dtype_), " is not supported");
+}
- protected:
- DataType src_dtype_;
- DataType dst_dtype_;
- std::function<void(OpKernelContext*, const Tensor&, Tensor*)> work_ = nullptr;
+CpuCastOp::CpuCastOp(OpKernelConstruction* ctx) : CastOpBase(ctx) {
+ OP_REQUIRES_OK(ctx, Prepare());
+}
- Status Unimplemented() {
- return errors::Unimplemented("Cast ", DataTypeString(src_dtype_), " to ",
- DataTypeString(dst_dtype_),
- " is not supported");
+Status CpuCastOp::Prepare() {
+ if (src_dtype_ == dst_dtype_) {
+ work_ = nullptr; // Identity
+ return Status::OK();
}
-
- TF_DISALLOW_COPY_AND_ASSIGN(CastOpBase);
-};
-
-class CpuCastOp : public CastOpBase {
- public:
- explicit CpuCastOp(OpKernelConstruction* ctx) : CastOpBase(ctx) {
- OP_REQUIRES_OK(ctx, Prepare());
+ if (src_dtype_ == DT_BOOL) {
+ work_ = GetCpuCastFromBool(dst_dtype_);
+ } else if (src_dtype_ == DT_UINT8) {
+ work_ = GetCpuCastFromUint8(dst_dtype_);
+ } else if (src_dtype_ == DT_INT8) {
+ work_ = GetCpuCastFromInt8(dst_dtype_);
+ } else if (src_dtype_ == DT_UINT16) {
+ work_ = GetCpuCastFromUint16(dst_dtype_);
+ } else if (src_dtype_ == DT_INT16) {
+ work_ = GetCpuCastFromInt16(dst_dtype_);
+ } else if (src_dtype_ == DT_INT32) {
+ work_ = GetCpuCastFromInt32(dst_dtype_);
+ } else if (src_dtype_ == DT_INT64) {
+ work_ = GetCpuCastFromInt64(dst_dtype_);
+ } else if (src_dtype_ == DT_HALF) {
+ work_ = GetCpuCastFromHalf(dst_dtype_);
+ } else if (src_dtype_ == DT_FLOAT) {
+ work_ = GetCpuCastFromFloat(dst_dtype_);
+ } else if (src_dtype_ == DT_DOUBLE) {
+ work_ = GetCpuCastFromDouble(dst_dtype_);
+ } else if (src_dtype_ == DT_COMPLEX64) {
+ work_ = GetCpuCastFromComplex64(dst_dtype_);
+ } else if (src_dtype_ == DT_COMPLEX128) {
+ work_ = GetCpuCastFromComplex128(dst_dtype_);
+ } else if (src_dtype_ == DT_BFLOAT16) {
+ work_ = GetCpuCastFromBfloat(dst_dtype_);
}
- private:
- Status Prepare() {
- if (src_dtype_ == dst_dtype_) {
- work_ = nullptr; // Identity
- return Status::OK();
- }
- if (src_dtype_ == DT_BOOL) {
- work_ = GetCpuCastFromBool(dst_dtype_);
- } else if (src_dtype_ == DT_UINT8) {
- work_ = GetCpuCastFromUint8(dst_dtype_);
- } else if (src_dtype_ == DT_INT8) {
- work_ = GetCpuCastFromInt8(dst_dtype_);
- } else if (src_dtype_ == DT_UINT16) {
- work_ = GetCpuCastFromUint16(dst_dtype_);
- } else if (src_dtype_ == DT_INT16) {
- work_ = GetCpuCastFromInt16(dst_dtype_);
- } else if (src_dtype_ == DT_INT32) {
- work_ = GetCpuCastFromInt32(dst_dtype_);
- } else if (src_dtype_ == DT_INT64) {
- work_ = GetCpuCastFromInt64(dst_dtype_);
- } else if (src_dtype_ == DT_HALF) {
- work_ = GetCpuCastFromHalf(dst_dtype_);
- } else if (src_dtype_ == DT_FLOAT) {
- work_ = GetCpuCastFromFloat(dst_dtype_);
- } else if (src_dtype_ == DT_DOUBLE) {
- work_ = GetCpuCastFromDouble(dst_dtype_);
- } else if (src_dtype_ == DT_COMPLEX64) {
- work_ = GetCpuCastFromComplex64(dst_dtype_);
- } else if (src_dtype_ == DT_COMPLEX128) {
- work_ = GetCpuCastFromComplex128(dst_dtype_);
- } else if (src_dtype_ == DT_BFLOAT16) {
- work_ = GetCpuCastFromBfloat(dst_dtype_);
- }
-
- // TODO(sesse): If CPU casting to or from Eigen::half ever becomes a
- // bottleneck, we could probably implement specialized support for
- // vectorized versions (not the least based on F16C for Haswell
- // or newer).
+ // TODO(sesse): If CPU casting to or from Eigen::half ever becomes a
+ // bottleneck, we could probably implement specialized support for
+ // vectorized versions (not the least based on F16C for Haswell
+ // or newer).
- return work_ == nullptr ? Unimplemented() : Status::OK();
- }
-};
+ return work_ == nullptr ? Unimplemented() : Status::OK();
+}
#if GOOGLE_CUDA
class GpuCastOp : public CastOpBase {
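Editor's note: the refactoring above moves code but keeps the dispatch
pattern intact: Prepare() selects a std::function for the (SrcT, DstT) pair
once at construction, and Compute() either forwards the input unchanged
(identity cast, work_ == nullptr) or invokes the stored functor. A minimal,
framework-free sketch of the same pattern; all names here are illustrative,
not TensorFlow API.

#include <functional>
#include <iostream>

// "Select a worker once, dispatch per call", as in CastOpBase/CpuCastOp.
struct Caster {
  // An empty work_ means the cast is an identity and the input is
  // forwarded unchanged, mirroring CastOpBase::Compute.
  std::function<double(int)> work_;

  explicit Caster(bool identity) {
    if (!identity) {
      work_ = [](int x) { return static_cast<double>(x); };
    }
  }

  double Compute(int x) { return work_ ? work_(x) : x; }
};

int main() {
  Caster cast(/*identity=*/false);
  std::cout << cast.Compute(3) << "\n";  // prints 3
}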
diff --git a/tensorflow/core/kernels/cast_op.h b/tensorflow/core/kernels/cast_op.h
index 5c24f164a4..379b5b5e81 100644
--- a/tensorflow/core/kernels/cast_op.h
+++ b/tensorflow/core/kernels/cast_op.h
@@ -18,11 +18,40 @@ limitations under the License.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bfloat16.h"
+#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
+
+// Common base class of Cast kernels
+class CastOpBase : public OpKernel {
+ public:
+ explicit CastOpBase(OpKernelConstruction* ctx);
+
+ void Compute(OpKernelContext* ctx) override;
+
+ protected:
+ DataType src_dtype_;
+ DataType dst_dtype_;
+ std::function<void(OpKernelContext*, const Tensor&, Tensor*)> work_ = nullptr;
+
+ Status Unimplemented();
+
+ TF_DISALLOW_COPY_AND_ASSIGN(CastOpBase);
+};
+
+// CPU implementation of Cast
+class CpuCastOp : public CastOpBase {
+ public:
+ explicit CpuCastOp(OpKernelConstruction* ctx);
+
+ private:
+ Status Prepare();
+};
+
namespace functor {
template <typename Device, typename Tout, typename Tin>
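Editor's note: with CastOpBase and CpuCastOp now declared in the header, a
new device backend can reuse the same skeleton that CpuCastOp and GpuCastOp
follow: subclass, select work_ in a Prepare()-style helper, and fail
construction with Unimplemented() for unsupported type pairs. A hedged
sketch under that assumption; MyDeviceCastOp and GetMyDeviceCastFromFloat
are hypothetical names, not TensorFlow symbols.

#include "tensorflow/core/kernels/cast_op.h"

namespace tensorflow {

// Hypothetical functor lookup for a new backend; it stands in for the
// GetCpuCastFrom* family used by CpuCastOp.
std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
GetMyDeviceCastFromFloat(DataType dst_dtype);

// Hypothetical subclass showing the extension point the header exposes.
class MyDeviceCastOp : public CastOpBase {
 public:
  explicit MyDeviceCastOp(OpKernelConstruction* ctx) : CastOpBase(ctx) {
    OP_REQUIRES_OK(ctx, Prepare());
  }

 private:
  Status Prepare() {
    if (src_dtype_ == dst_dtype_) {
      work_ = nullptr;  // Identity: CastOpBase::Compute forwards the input.
      return Status::OK();
    }
    if (src_dtype_ == DT_FLOAT) {
      work_ = GetMyDeviceCastFromFloat(dst_dtype_);
    }
    return work_ == nullptr ? Unimplemented() : Status::OK();
  }
};

}  // namespace tensorflow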