diff options
-rw-r--r-- | tensorflow/core/kernels/cast_op.cc | 8 |
1 files changed, 7 insertions, 1 deletions
diff --git a/tensorflow/core/kernels/cast_op.cc b/tensorflow/core/kernels/cast_op.cc index 0478c93280..3a72567655 100644 --- a/tensorflow/core/kernels/cast_op.cc +++ b/tensorflow/core/kernels/cast_op.cc @@ -98,7 +98,13 @@ void CastOpBase::Compute(OpKernelContext* ctx) { ctx->set_output(0, inp); } else { Tensor in; - in.UnsafeCopyFromInternal(inp, src_dtype_, inp.shape()); + if (external_src_dtype_ != src_dtype_) { + // If the type is a quantized type we need to do an UnsafeCopyFromInternal + // since the src_dtype_ is different from external_src_type_. + in.UnsafeCopyFromInternal(inp, src_dtype_, inp.shape()); + } else { + in = inp; + } Tensor* out = nullptr; OP_REQUIRES_OK(ctx, ctx->allocate_output(0, in.shape(), &out)); out->set_dtype(dst_dtype_); |