diff options
author | Saurabh Saxena <srbs@google.com> | 2018-08-16 14:22:17 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-16 14:25:31 -0700 |
commit | 801228d3005b61af2f68223d4769269b1a5a074b (patch) | |
tree | f4cd7d4c3d5aeed2cf0aa851574eaa615be3e746 | |
parent | 60d4441a0e8b67c070485af274949fbc033a2633 (diff) |
Fix zeros_like for TensorLists.
PiperOrigin-RevId: 209046988
-rw-r--r-- | tensorflow/core/kernels/constant_op.cc | 5 | ||||
-rw-r--r-- | tensorflow/core/kernels/list_kernels.h | 5 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/list_ops_test.py | 17 |
3 files changed, 23 insertions, 4 deletions
diff --git a/tensorflow/core/kernels/constant_op.cc b/tensorflow/core/kernels/constant_op.cc index 375819a8a2..426c404f43 100644 --- a/tensorflow/core/kernels/constant_op.cc +++ b/tensorflow/core/kernels/constant_op.cc @@ -259,8 +259,9 @@ class ZerosLikeOp : public OpKernel { errors::InvalidArgument("ZerosLike non-scalar Tensor with " "dtype=DT_VARIANT is not supported.")); const Variant& v = input.scalar<Variant>()(); - Tensor out(ctx->device()->GetAllocator(AllocatorAttributes()), DT_VARIANT, - TensorShape({})); + // DT_VARIANT tensors must be allocated on CPU since they wrap C++ + // objects which can not be efficiently represented in GPU memory. + Tensor out(cpu_allocator(), DT_VARIANT, TensorShape({})); Variant* out_v = &(out.scalar<Variant>()()); OP_REQUIRES_OK(ctx, UnaryOpVariant<Device>( ctx, ZEROS_LIKE_VARIANT_UNARY_OP, v, out_v)); diff --git a/tensorflow/core/kernels/list_kernels.h b/tensorflow/core/kernels/list_kernels.h index 42871c6113..f790fa45f1 100644 --- a/tensorflow/core/kernels/list_kernels.h +++ b/tensorflow/core/kernels/list_kernels.h @@ -266,9 +266,10 @@ Status TensorListZerosLike(OpKernelContext* c, const TensorList& x, #undef DTYPE_CASE default: return errors::InvalidArgument( - "Trying to compute zeros_like for unsupported dtype", - out_tensor.dtype()); + "Trying to compute zeros_like for unsupported dtype ", + DataTypeString(out_tensor.dtype())); } + y->tensors.emplace_back(out_tensor); } return Status::OK(); } diff --git a/tensorflow/python/kernel_tests/list_ops_test.py b/tensorflow/python/kernel_tests/list_ops_test.py index bf82e08551..5893816369 100644 --- a/tensorflow/python/kernel_tests/list_ops_test.py +++ b/tensorflow/python/kernel_tests/list_ops_test.py @@ -421,6 +421,23 @@ class ListOpsTest(test_util.TensorFlowTestCase): "Invalid data type at index 0"): self.evaluate(list_ops.tensor_list_push_back_batch(l_batch, [3, 4])) + @test_util.run_in_graph_and_eager_modes + def testZerosLike(self): + l_empty = list_ops.empty_tensor_list( + element_dtype=dtypes.float32, element_shape=scalar_shape()) + l_empty_zeros = array_ops.zeros_like(l_empty) + t_empty_zeros = list_ops.tensor_list_stack( + l_empty_zeros, element_dtype=dtypes.float32) + + l_full = list_ops.tensor_list_push_back(l_empty, constant_op.constant(1.0)) + l_full = list_ops.tensor_list_push_back(l_full, constant_op.constant(2.0)) + l_full_zeros = array_ops.zeros_like(l_full) + t_full_zeros = list_ops.tensor_list_stack( + l_full_zeros, element_dtype=dtypes.float32) + + self.assertAllEqual(self.evaluate(t_empty_zeros), []) + self.assertAllEqual(self.evaluate(t_full_zeros), [0.0, 0.0]) + if __name__ == "__main__": test.main() |