diff options
Diffstat (limited to 'tensorflow/python/framework/ops.py')
-rw-r--r-- | tensorflow/python/framework/ops.py | 24 |
1 files changed, 13 insertions, 11 deletions
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index e6e6b9c6ca..0257f094d7 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -174,6 +174,17 @@ def uid(): return c_api.TFE_Py_UID() +def numpy_text(tensor, is_repr=False): + """Human readable representation of a tensor's numpy value.""" + if tensor.dtype.is_numpy_compatible: + text = repr(tensor.numpy()) if is_repr else str(tensor.numpy()) + else: + text = "<unprintable>" + if "\n" in text: + text = "\n" + text + return text + + # NOTE(ebrevdo): Do not subclass this. If you do, I will break you on purpose. class _TensorLike(object): """Internal cls for grouping Tensor, SparseTensor, ..., for is_instance.""" @@ -590,15 +601,6 @@ class _EagerTensorBase(Tensor): # performance-sensitive in some models. return dtypes._INTERN_TABLE[self._datatype_enum()] # pylint: disable=protected-access - def _numpy_text(self, is_repr=False): - if self.dtype.is_numpy_compatible: - numpy_text = repr(self.numpy()) if is_repr else str(self.numpy()) - else: - numpy_text = "<unprintable>" - if "\n" in numpy_text: - numpy_text = "\n" + numpy_text - return numpy_text - def numpy(self): """Returns a numpy array with the same contents as the Tensor. @@ -640,13 +642,13 @@ class _EagerTensorBase(Tensor): raise NotImplementedError() def __str__(self): - return "tf.Tensor(%s, shape=%s, dtype=%s)" % (self._numpy_text(), + return "tf.Tensor(%s, shape=%s, dtype=%s)" % (numpy_text(self), self.shape, self.dtype.name) def __repr__(self): return "<tf.Tensor: id=%s, shape=%s, dtype=%s, numpy=%s>" % ( - self._id, self.shape, self.dtype.name, self._numpy_text(is_repr=True)) + self._id, self.shape, self.dtype.name, numpy_text(self, is_repr=True)) @staticmethod def _override_operator(name, func): |