diff options
author | 2017-10-09 10:07:05 -0700 | |
---|---|---|
committer | 2017-10-09 10:11:37 -0700 | |
commit | 5bba158bbeea684c3e87de28a61004dbef28e00d (patch) | |
tree | f0e2343c66d8846edb101276c97e1e41091fbfd3 | |
parent | 7e2b50d8490f573b470ca97bd06a4677830db738 (diff) |
Print numpy value for variables when in Eager mode
PiperOrigin-RevId: 171549468
-rw-r--r-- | tensorflow/python/framework/ops.py | 24 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/variables_test.py | 2 | ||||
-rw-r--r-- | tensorflow/python/ops/variables.py | 10 |
3 files changed, 21 insertions, 15 deletions
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index e6e6b9c6ca..0257f094d7 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -174,6 +174,17 @@ def uid(): return c_api.TFE_Py_UID() +def numpy_text(tensor, is_repr=False): + """Human readable representation of a tensor's numpy value.""" + if tensor.dtype.is_numpy_compatible: + text = repr(tensor.numpy()) if is_repr else str(tensor.numpy()) + else: + text = "<unprintable>" + if "\n" in text: + text = "\n" + text + return text + + # NOTE(ebrevdo): Do not subclass this. If you do, I will break you on purpose. class _TensorLike(object): """Internal cls for grouping Tensor, SparseTensor, ..., for is_instance.""" @@ -590,15 +601,6 @@ class _EagerTensorBase(Tensor): # performance-sensitive in some models. return dtypes._INTERN_TABLE[self._datatype_enum()] # pylint: disable=protected-access - def _numpy_text(self, is_repr=False): - if self.dtype.is_numpy_compatible: - numpy_text = repr(self.numpy()) if is_repr else str(self.numpy()) - else: - numpy_text = "<unprintable>" - if "\n" in numpy_text: - numpy_text = "\n" + numpy_text - return numpy_text - def numpy(self): """Returns a numpy array with the same contents as the Tensor. @@ -640,13 +642,13 @@ class _EagerTensorBase(Tensor): raise NotImplementedError() def __str__(self): - return "tf.Tensor(%s, shape=%s, dtype=%s)" % (self._numpy_text(), + return "tf.Tensor(%s, shape=%s, dtype=%s)" % (numpy_text(self), self.shape, self.dtype.name) def __repr__(self): return "<tf.Tensor: id=%s, shape=%s, dtype=%s, numpy=%s>" % ( - self._id, self.shape, self.dtype.name, self._numpy_text(is_repr=True)) + self._id, self.shape, self.dtype.name, numpy_text(self, is_repr=True)) @staticmethod def _override_operator(name, func): diff --git a/tensorflow/python/kernel_tests/variables_test.py b/tensorflow/python/kernel_tests/variables_test.py index 7718710c69..f60ebf58f6 100644 --- a/tensorflow/python/kernel_tests/variables_test.py +++ b/tensorflow/python/kernel_tests/variables_test.py @@ -504,7 +504,7 @@ class VariablesTestCase(test.TestCase): self.assertAllClose(np.ones((5, 5), np.float32), var.eval()) def testRepr(self): - var = variables.Variable(np.zeros((5, 5), np.float32), name='noop') + var = variables.Variable(np.zeros((5, 5), np.float32), name="noop") self.assertEqual( "<tf.Variable 'noop:0' shape=(5, 5) dtype=float32_ref>", repr(var)) diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py index a27f26e303..90b4f25d81 100644 --- a/tensorflow/python/ops/variables.py +++ b/tensorflow/python/ops/variables.py @@ -213,9 +213,13 @@ class Variable(object): constraint=constraint) def __repr__(self): - return "<tf.Variable '%s' shape=%s dtype=%s>" % (self.name, - self.get_shape(), - self.dtype.name) + if context.in_eager_mode(): + return "<tf.Variable '%s' shape=%s dtype=%s, numpy=%s>" % ( + self.name, self.get_shape(), self.dtype.name, + ops.numpy_text(self.read_value(), is_repr=True)) + else: + return "<tf.Variable '%s' shape=%s dtype=%s>" % ( + self.name, self.get_shape(), self.dtype.name) def _init_from_args(self, initial_value=None, |