aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Akshay Agrawal <akshayka@google.com>2017-10-09 10:07:05 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-09 10:11:37 -0700
commit5bba158bbeea684c3e87de28a61004dbef28e00d (patch)
treef0e2343c66d8846edb101276c97e1e41091fbfd3
parent7e2b50d8490f573b470ca97bd06a4677830db738 (diff)
Print numpy value for variables when in Eager mode
PiperOrigin-RevId: 171549468
-rw-r--r--tensorflow/python/framework/ops.py24
-rw-r--r--tensorflow/python/kernel_tests/variables_test.py2
-rw-r--r--tensorflow/python/ops/variables.py10
3 files changed, 21 insertions, 15 deletions
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index e6e6b9c6ca..0257f094d7 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -174,6 +174,17 @@ def uid():
return c_api.TFE_Py_UID()
+def numpy_text(tensor, is_repr=False):
+ """Human readable representation of a tensor's numpy value."""
+ if tensor.dtype.is_numpy_compatible:
+ text = repr(tensor.numpy()) if is_repr else str(tensor.numpy())
+ else:
+ text = "<unprintable>"
+ if "\n" in text:
+ text = "\n" + text
+ return text
+
+
# NOTE(ebrevdo): Do not subclass this. If you do, I will break you on purpose.
class _TensorLike(object):
"""Internal cls for grouping Tensor, SparseTensor, ..., for is_instance."""
@@ -590,15 +601,6 @@ class _EagerTensorBase(Tensor):
# performance-sensitive in some models.
return dtypes._INTERN_TABLE[self._datatype_enum()] # pylint: disable=protected-access
- def _numpy_text(self, is_repr=False):
- if self.dtype.is_numpy_compatible:
- numpy_text = repr(self.numpy()) if is_repr else str(self.numpy())
- else:
- numpy_text = "<unprintable>"
- if "\n" in numpy_text:
- numpy_text = "\n" + numpy_text
- return numpy_text
-
def numpy(self):
"""Returns a numpy array with the same contents as the Tensor.
@@ -640,13 +642,13 @@ class _EagerTensorBase(Tensor):
raise NotImplementedError()
def __str__(self):
- return "tf.Tensor(%s, shape=%s, dtype=%s)" % (self._numpy_text(),
+ return "tf.Tensor(%s, shape=%s, dtype=%s)" % (numpy_text(self),
self.shape,
self.dtype.name)
def __repr__(self):
return "<tf.Tensor: id=%s, shape=%s, dtype=%s, numpy=%s>" % (
- self._id, self.shape, self.dtype.name, self._numpy_text(is_repr=True))
+ self._id, self.shape, self.dtype.name, numpy_text(self, is_repr=True))
@staticmethod
def _override_operator(name, func):
diff --git a/tensorflow/python/kernel_tests/variables_test.py b/tensorflow/python/kernel_tests/variables_test.py
index 7718710c69..f60ebf58f6 100644
--- a/tensorflow/python/kernel_tests/variables_test.py
+++ b/tensorflow/python/kernel_tests/variables_test.py
@@ -504,7 +504,7 @@ class VariablesTestCase(test.TestCase):
self.assertAllClose(np.ones((5, 5), np.float32), var.eval())
def testRepr(self):
- var = variables.Variable(np.zeros((5, 5), np.float32), name='noop')
+ var = variables.Variable(np.zeros((5, 5), np.float32), name="noop")
self.assertEqual(
"<tf.Variable 'noop:0' shape=(5, 5) dtype=float32_ref>",
repr(var))
diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py
index a27f26e303..90b4f25d81 100644
--- a/tensorflow/python/ops/variables.py
+++ b/tensorflow/python/ops/variables.py
@@ -213,9 +213,13 @@ class Variable(object):
constraint=constraint)
def __repr__(self):
- return "<tf.Variable '%s' shape=%s dtype=%s>" % (self.name,
- self.get_shape(),
- self.dtype.name)
+ if context.in_eager_mode():
+ return "<tf.Variable '%s' shape=%s dtype=%s, numpy=%s>" % (
+ self.name, self.get_shape(), self.dtype.name,
+ ops.numpy_text(self.read_value(), is_repr=True))
+ else:
+ return "<tf.Variable '%s' shape=%s dtype=%s>" % (
+ self.name, self.get_shape(), self.dtype.name)
def _init_from_args(self,
initial_value=None,