From 5bba158bbeea684c3e87de28a61004dbef28e00d Mon Sep 17 00:00:00 2001 From: Akshay Agrawal Date: Mon, 9 Oct 2017 10:07:05 -0700 Subject: Print numpy value for variables when in Eager mode PiperOrigin-RevId: 171549468 --- tensorflow/python/framework/ops.py | 24 +++++++++++++----------- tensorflow/python/kernel_tests/variables_test.py | 2 +- tensorflow/python/ops/variables.py | 10 +++++++--- 3 files changed, 21 insertions(+), 15 deletions(-) diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index e6e6b9c6ca..0257f094d7 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -174,6 +174,17 @@ def uid(): return c_api.TFE_Py_UID() +def numpy_text(tensor, is_repr=False): + """Human readable representation of a tensor's numpy value.""" + if tensor.dtype.is_numpy_compatible: + text = repr(tensor.numpy()) if is_repr else str(tensor.numpy()) + else: + text = "" + if "\n" in text: + text = "\n" + text + return text + + # NOTE(ebrevdo): Do not subclass this. If you do, I will break you on purpose. class _TensorLike(object): """Internal cls for grouping Tensor, SparseTensor, ..., for is_instance.""" @@ -590,15 +601,6 @@ class _EagerTensorBase(Tensor): # performance-sensitive in some models. return dtypes._INTERN_TABLE[self._datatype_enum()] # pylint: disable=protected-access - def _numpy_text(self, is_repr=False): - if self.dtype.is_numpy_compatible: - numpy_text = repr(self.numpy()) if is_repr else str(self.numpy()) - else: - numpy_text = "" - if "\n" in numpy_text: - numpy_text = "\n" + numpy_text - return numpy_text - def numpy(self): """Returns a numpy array with the same contents as the Tensor. @@ -640,13 +642,13 @@ class _EagerTensorBase(Tensor): raise NotImplementedError() def __str__(self): - return "tf.Tensor(%s, shape=%s, dtype=%s)" % (self._numpy_text(), + return "tf.Tensor(%s, shape=%s, dtype=%s)" % (numpy_text(self), self.shape, self.dtype.name) def __repr__(self): return "" % ( - self._id, self.shape, self.dtype.name, self._numpy_text(is_repr=True)) + self._id, self.shape, self.dtype.name, numpy_text(self, is_repr=True)) @staticmethod def _override_operator(name, func): diff --git a/tensorflow/python/kernel_tests/variables_test.py b/tensorflow/python/kernel_tests/variables_test.py index 7718710c69..f60ebf58f6 100644 --- a/tensorflow/python/kernel_tests/variables_test.py +++ b/tensorflow/python/kernel_tests/variables_test.py @@ -504,7 +504,7 @@ class VariablesTestCase(test.TestCase): self.assertAllClose(np.ones((5, 5), np.float32), var.eval()) def testRepr(self): - var = variables.Variable(np.zeros((5, 5), np.float32), name='noop') + var = variables.Variable(np.zeros((5, 5), np.float32), name="noop") self.assertEqual( "", repr(var)) diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py index a27f26e303..90b4f25d81 100644 --- a/tensorflow/python/ops/variables.py +++ b/tensorflow/python/ops/variables.py @@ -213,9 +213,13 @@ class Variable(object): constraint=constraint) def __repr__(self): - return "" % (self.name, - self.get_shape(), - self.dtype.name) + if context.in_eager_mode(): + return "" % ( + self.name, self.get_shape(), self.dtype.name, + ops.numpy_text(self.read_value(), is_repr=True)) + else: + return "" % ( + self.name, self.get_shape(), self.dtype.name) def _init_from_args(self, initial_value=None, -- cgit v1.2.3