aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/cast_op.cc
diff options
context:
space:
mode:
authorGravatar Suharsh Sivakumar <suharshs@google.com>2018-08-20 15:45:00 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-20 15:48:03 -0700
commit10a167c6c14767020387fe919328ef411a5bc0af (patch)
tree5991e4e90e7d236aca66ad8a7ae51cd44f12ba43 /tensorflow/core/kernels/cast_op.cc
parent135bd2714387eb5285fc9c621d9a8dc042d5f435 (diff)
Only UnsafeCopyFromInternal if src_type is different than external_src_dtype.
PiperOrigin-RevId: 209498329
Diffstat (limited to 'tensorflow/core/kernels/cast_op.cc')
-rw-r--r--tensorflow/core/kernels/cast_op.cc8
1 files changed, 7 insertions, 1 deletions
diff --git a/tensorflow/core/kernels/cast_op.cc b/tensorflow/core/kernels/cast_op.cc
index 0478c93280..3a72567655 100644
--- a/tensorflow/core/kernels/cast_op.cc
+++ b/tensorflow/core/kernels/cast_op.cc
@@ -98,7 +98,13 @@ void CastOpBase::Compute(OpKernelContext* ctx) {
ctx->set_output(0, inp);
} else {
Tensor in;
- in.UnsafeCopyFromInternal(inp, src_dtype_, inp.shape());
+ if (external_src_dtype_ != src_dtype_) {
+ // If the type is a quantized type we need to do an UnsafeCopyFromInternal
+ // since the src_dtype_ is different from external_src_type_.
+ in.UnsafeCopyFromInternal(inp, src_dtype_, inp.shape());
+ } else {
+ in = inp;
+ }
Tensor* out = nullptr;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, in.shape(), &out));
out->set_dtype(dst_dtype_);