diff options
author | 2018-08-20 15:45:00 -0700 | |
---|---|---|
committer | 2018-08-20 15:48:03 -0700 | |
commit | 10a167c6c14767020387fe919328ef411a5bc0af (patch) | |
tree | 5991e4e90e7d236aca66ad8a7ae51cd44f12ba43 /tensorflow/core/kernels/cast_op.cc | |
parent | 135bd2714387eb5285fc9c621d9a8dc042d5f435 (diff) |
Only UnsafeCopyFromInternal if src_type is different than external_src_dtype.
PiperOrigin-RevId: 209498329
Diffstat (limited to 'tensorflow/core/kernels/cast_op.cc')
-rw-r--r-- | tensorflow/core/kernels/cast_op.cc | 8 |
1 files changed, 7 insertions, 1 deletions
diff --git a/tensorflow/core/kernels/cast_op.cc b/tensorflow/core/kernels/cast_op.cc index 0478c93280..3a72567655 100644 --- a/tensorflow/core/kernels/cast_op.cc +++ b/tensorflow/core/kernels/cast_op.cc @@ -98,7 +98,13 @@ void CastOpBase::Compute(OpKernelContext* ctx) { ctx->set_output(0, inp); } else { Tensor in; - in.UnsafeCopyFromInternal(inp, src_dtype_, inp.shape()); + if (external_src_dtype_ != src_dtype_) { + // If the type is a quantized type we need to do an UnsafeCopyFromInternal + // since the src_dtype_ is different from external_src_type_. + in.UnsafeCopyFromInternal(inp, src_dtype_, inp.shape()); + } else { + in = inp; + } Tensor* out = nullptr; OP_REQUIRES_OK(ctx, ctx->allocate_output(0, in.shape(), &out)); out->set_dtype(dst_dtype_); |