aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Saurabh Saxena <srbs@google.com>2018-08-16 14:22:17 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-16 14:25:31 -0700
commit801228d3005b61af2f68223d4769269b1a5a074b (patch)
treef4cd7d4c3d5aeed2cf0aa851574eaa615be3e746
parent60d4441a0e8b67c070485af274949fbc033a2633 (diff)
Fix zeros_like for TensorLists.
PiperOrigin-RevId: 209046988
-rw-r--r--tensorflow/core/kernels/constant_op.cc5
-rw-r--r--tensorflow/core/kernels/list_kernels.h5
-rw-r--r--tensorflow/python/kernel_tests/list_ops_test.py17
3 files changed, 23 insertions, 4 deletions
diff --git a/tensorflow/core/kernels/constant_op.cc b/tensorflow/core/kernels/constant_op.cc
index 375819a8a2..426c404f43 100644
--- a/tensorflow/core/kernels/constant_op.cc
+++ b/tensorflow/core/kernels/constant_op.cc
@@ -259,8 +259,9 @@ class ZerosLikeOp : public OpKernel {
errors::InvalidArgument("ZerosLike non-scalar Tensor with "
"dtype=DT_VARIANT is not supported."));
const Variant& v = input.scalar<Variant>()();
- Tensor out(ctx->device()->GetAllocator(AllocatorAttributes()), DT_VARIANT,
- TensorShape({}));
+ // DT_VARIANT tensors must be allocated on CPU since they wrap C++
+ // objects which can not be efficiently represented in GPU memory.
+ Tensor out(cpu_allocator(), DT_VARIANT, TensorShape({}));
Variant* out_v = &(out.scalar<Variant>()());
OP_REQUIRES_OK(ctx, UnaryOpVariant<Device>(
ctx, ZEROS_LIKE_VARIANT_UNARY_OP, v, out_v));
diff --git a/tensorflow/core/kernels/list_kernels.h b/tensorflow/core/kernels/list_kernels.h
index 42871c6113..f790fa45f1 100644
--- a/tensorflow/core/kernels/list_kernels.h
+++ b/tensorflow/core/kernels/list_kernels.h
@@ -266,9 +266,10 @@ Status TensorListZerosLike(OpKernelContext* c, const TensorList& x,
#undef DTYPE_CASE
default:
return errors::InvalidArgument(
- "Trying to compute zeros_like for unsupported dtype",
- out_tensor.dtype());
+ "Trying to compute zeros_like for unsupported dtype ",
+ DataTypeString(out_tensor.dtype()));
}
+ y->tensors.emplace_back(out_tensor);
}
return Status::OK();
}
diff --git a/tensorflow/python/kernel_tests/list_ops_test.py b/tensorflow/python/kernel_tests/list_ops_test.py
index bf82e08551..5893816369 100644
--- a/tensorflow/python/kernel_tests/list_ops_test.py
+++ b/tensorflow/python/kernel_tests/list_ops_test.py
@@ -421,6 +421,23 @@ class ListOpsTest(test_util.TensorFlowTestCase):
"Invalid data type at index 0"):
self.evaluate(list_ops.tensor_list_push_back_batch(l_batch, [3, 4]))
+ @test_util.run_in_graph_and_eager_modes
+ def testZerosLike(self):
+ l_empty = list_ops.empty_tensor_list(
+ element_dtype=dtypes.float32, element_shape=scalar_shape())
+ l_empty_zeros = array_ops.zeros_like(l_empty)
+ t_empty_zeros = list_ops.tensor_list_stack(
+ l_empty_zeros, element_dtype=dtypes.float32)
+
+ l_full = list_ops.tensor_list_push_back(l_empty, constant_op.constant(1.0))
+ l_full = list_ops.tensor_list_push_back(l_full, constant_op.constant(2.0))
+ l_full_zeros = array_ops.zeros_like(l_full)
+ t_full_zeros = list_ops.tensor_list_stack(
+ l_full_zeros, element_dtype=dtypes.float32)
+
+ self.assertAllEqual(self.evaluate(t_empty_zeros), [])
+ self.assertAllEqual(self.evaluate(t_full_zeros), [0.0, 0.0])
+
if __name__ == "__main__":
test.main()