aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/framework/ops.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/framework/ops.py')
-rw-r--r--tensorflow/python/framework/ops.py80
1 files changed, 39 insertions, 41 deletions
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index ad2e2993c1..ab4455534e 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -617,16 +617,15 @@ class _EagerTensorBase(Tensor):
return dtypes._INTERN_TABLE[self._datatype_enum()] # pylint: disable=protected-access
def numpy(self):
- """Returns a numpy array or a scalar with the same contents as the Tensor.
+ """Returns a numpy array with the same contents as the Tensor.
TODO(ashankar,agarwal): Perhaps this should NOT reference the underlying
buffer but instead always explicitly copy? Note that currently it may or may
not copy based on whether the numpy data is properly aligned or not.
Returns:
- A numpy array or a scalar. Numpy array may share memory with the
- Tensor object. Any changes to one may be reflected in the other. A scalar
- value is returned when self has rank 0.
+ A numpy array that may share memory with the Tensor object. Any changes
+ to one may be reflected in the other.
Raises:
ValueError: if the type of this Tensor is not representable in numpy.
@@ -864,6 +863,10 @@ def convert_to_tensor(value, dtype=None, name=None, preferred_dtype=None):
inputs, which allows those ops to accept numpy arrays, Python lists,
and scalars in addition to `Tensor` objects.
+ Note: This function diverges from default Numpy behavior for `float` and
+ `string` types when `None` is present in a Python list or scalar. Rather
+ than silently converting `None` values, an error will be thrown.
+
Args:
value: An object whose type has a registered `Tensor` conversion function.
dtype: Optional element type for the returned tensor. If missing, the
@@ -1641,15 +1644,13 @@ class Operation(object):
default_colocation_group = [
compat.as_bytes("loc:@%s" % self._node_def.name)
]
- try:
- class_attr = self.get_attr("_class")
- except ValueError:
+ if "_class" not in self._node_def.attr:
# This op has no explicit colocation group, so it is itself its
# own root of a colocation group.
return default_colocation_group
attr_groups = [
- class_name for class_name in class_attr
+ class_name for class_name in self.get_attr("_class")
if class_name.startswith(b"loc:@")
]
@@ -2064,19 +2065,16 @@ class Operation(object):
def _set_attr(self, attr_name, attr_value):
"""Private method used to set an attribute in the node_def."""
- if _USE_C_API:
- buf = c_api.TF_NewBufferFromString(
- compat.as_bytes(attr_value.SerializeToString()))
- try:
- with errors.raise_exception_on_not_ok_status() as status:
- # pylint: disable=protected-access
- c_api.SetAttr(self._graph._c_graph, self._c_op, attr_name, buf,
- status)
- # pylint: enable=protected-access
- finally:
- c_api.TF_DeleteBuffer(buf)
- else:
- self._node_def.attr[attr_name].CopyFrom(attr_value)
+ if not _USE_C_API:
+ assert "_set_attr not supported with _USE_C_API == False"
+ return
+ buf = c_api.TF_NewBufferFromString(
+ compat.as_bytes(attr_value.SerializeToString()))
+ try:
+ with errors.raise_exception_on_not_ok_status() as status:
+ c_api.SetAttr(self._graph._c_graph, self._c_op, attr_name, buf, status) # pylint: disable=protected-access
+ finally:
+ c_api.TF_DeleteBuffer(buf)
def get_attr(self, name):
"""Returns the value of the attr of this op with the given `name`.
@@ -2090,24 +2088,25 @@ class Operation(object):
Raises:
ValueError: If this op does not have an attr with the given `name`.
"""
- fields = ["s", "i", "f", "b", "type", "shape", "tensor", "func"]
- if self._c_op:
+ if _USE_C_API:
try:
- with c_api_util.tf_buffer() as buf:
- with errors.raise_exception_on_not_ok_status() as status:
- c_api.TF_OperationGetAttrValueProto(self._c_op, name, buf, status)
- data = c_api.TF_GetBuffer(buf)
- except errors.InvalidArgumentError as e:
- # Convert to ValueError for backwards compatibility.
- raise ValueError(str(e))
- x = attr_value_pb2.AttrValue()
- x.ParseFromString(data)
- else:
- if name not in self._node_def.attr:
- raise ValueError(
- "No attr named '" + name + "' in " + str(self._node_def))
- x = self._node_def.attr[name]
+ # TODO(b/65162920): remove this try/except block when all attrs are
+ # implemented to use the _set_attr method instead of node_def.attr.
+ with errors.raise_exception_on_not_ok_status() as status:
+ metadata = c_api.TF_OperationGetAttrMetadata(self._c_op, name, status)
+ with errors.raise_exception_on_not_ok_status() as status:
+ if metadata.type == c_api.TF_ATTR_INT and metadata.is_list == 0:
+ return c_api.TF_OperationGetAttrInt(self._c_op, name, status)
+ except errors.InvalidArgumentError:
+ # Colocation ops are failing to find attrs begininning with "_*". They
+ # should fall through to the not-CAPI logic until the attribute is set
+ # via the C-API always.
+ pass
+ fields = ["s", "i", "f", "b", "type", "shape", "tensor", "func"]
+ if name not in self._node_def.attr:
+ raise ValueError("No attr named '" + name + "' in " + str(self._node_def))
+ x = self._node_def.attr[name]
# Treat an empty oneof value as an empty list.
if not x.WhichOneof("value"):
return []
@@ -3107,10 +3106,9 @@ class Graph(object):
ret._set_device(colocation_op.device) # pylint: disable=protected-access
all_colocation_groups = sorted(set(all_colocation_groups))
- # pylint: disable=protected-access
- ret._set_attr("_class", attr_value_pb2.AttrValue(
- list=attr_value_pb2.AttrValue.ListValue(s=all_colocation_groups)))
- # pylint: enable=protected-access
+ ret.node_def.attr["_class"].CopyFrom(
+ attr_value_pb2.AttrValue(list=attr_value_pb2.AttrValue.ListValue(
+ s=all_colocation_groups)))
# Sets "container" attribute if
# (1) self._container is not None