diff options
Diffstat (limited to 'tensorflow/python/framework/ops.py')
-rw-r--r-- | tensorflow/python/framework/ops.py | 145 |
1 files changed, 83 insertions, 62 deletions
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index f9796ca679..8a1bcac0aa 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -39,6 +39,7 @@ from tensorflow.python.framework import registry from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import versions from tensorflow.python.util import compat +from tensorflow.python.platform import logging def _convert_stack(stack): @@ -95,6 +96,22 @@ def _extract_stack(): return ret +def _as_graph_element(obj): + """Convert `obj` to a graph element if possible, otherwise return `None`. + + Args: + obj: Object to convert. + + Returns: + The result of `obj._as_graph_element()` if that method is available; + otherwise `None`. + """ + conv_fn = getattr(obj, "_as_graph_element", None) + if conv_fn and callable(conv_fn): + return conv_fn() + return None + + class Tensor(object): """Represents a value produced by an `Operation`. @@ -680,6 +697,7 @@ class IndexedSlices(object): def __init__(self, values, indices, dense_shape=None): """Creates an `IndexedSlices`.""" + _get_graph_from_inputs([values, indices, dense_shape]) self._values = values self._indices = indices self._dense_shape = dense_shape @@ -719,30 +737,15 @@ class IndexedSlices(object): """The `DType` of elements in this tensor.""" return self.values.dtype - def __str__(self): - return "IndexedSlices(indices=%s, values=%s)" % ( - self._indices, self._values) - - -def assert_same_graph(items, expected_graph=None): - """Asserts all items are from the same graph. + @property + def graph(self): + """The `Graph` that contains the values, indices, and shape tensors.""" + return self._values.graph - Args: - items: List of graph items (e.g., Variable, Tensor, SparseTensor, - Operation, or IndexedSlices). - expected_graph: Expected graph. If not specified, assert all tensors are - from the same graph. - Returns: - items, for chaining. - Raises: - ValueError: If any graphs do not match. - """ - for item in items: - if not expected_graph: - expected_graph = item.graph - elif expected_graph != item.graph: - raise ValueError("Items must be from the same graph.") - return items + def __str__(self): + return "IndexedSlices(indices=%s, values=%s%s)" % ( + self._indices, self._values, + (", dense_shape=%s" % self._dense_shape) if self._dense_shape else "") class SparseTensor(object): @@ -1106,7 +1109,7 @@ class Operation(object): """ if not isinstance(tensor, Tensor): raise TypeError("tensor must be a Tensor: %s" % tensor) - assert_same_graph([self, tensor]) + _assert_same_graph(self, tensor) if dtype is None: dtype = tensor.dtype else: @@ -1138,7 +1141,7 @@ class Operation(object): """ if not isinstance(tensor, Tensor): raise TypeError("tensor must be a Tensor: %s" % tensor) - assert_same_graph([self, tensor]) + _assert_same_graph(self, tensor) if dtype is None: dtype = tensor.dtype else: @@ -1166,7 +1169,7 @@ class Operation(object): """ if not isinstance(op, Operation): raise TypeError("op must be an Operation: %s" % op) - assert_same_graph([self, op]) + _assert_same_graph(self, op) self._control_inputs.append(op) self._recompute_node_def() @@ -1887,9 +1890,7 @@ class Graph(object): else: raise ValueError("allow_tensor and allow_operation can't both be False.") - conv_fn = getattr(obj, "_as_graph_element", None) - if conv_fn and callable(conv_fn): - obj = conv_fn() + obj = _as_graph_element(obj) or obj # If obj appears to be a name... if isinstance(obj, compat.bytes_or_text_types): @@ -2971,6 +2972,21 @@ def get_default_graph(): return _default_graph_stack.get_default() +def _assert_same_graph(original_item, item): + """Fail if the 2 items are from different graphs. + + Args: + original_item: Original item to check against. + item: Item to check. + + Raises: + ValueError: if graphs do not match. + """ + if original_item.graph is not item.graph: + raise ValueError( + "%s must be from the same graph as %s." % (item, original_item)) + + def _get_graph_from_inputs(op_input_list, graph=None): """Returns the appropriate graph to use for the given inputs. @@ -2986,8 +3002,8 @@ def _get_graph_from_inputs(op_input_list, graph=None): "op_input_list", we attempt to use the default graph. Args: - op_input_list: A list of inputs to an operation, which may include Tensor - and Operation objects. + op_input_list: A list of inputs to an operation, which may include `Tensor`, + `Operation`, and other objects that may be converted to a graph element. graph: (Optional) The explicit graph to use. Raises: @@ -3001,37 +3017,35 @@ def _get_graph_from_inputs(op_input_list, graph=None): The appropriate graph to use for the given inputs. """ op_input_list = tuple(op_input_list) # Handle generators correctly - - # 1. If the graph is specified explicitly, we validate that all of the inputs - # are compatible with that graph. - if graph is not None: - if not isinstance(graph, Graph): - raise TypeError("Input graph needs to be a Graph: %s" % graph) - for op_input in op_input_list: - if isinstance(op_input, Operation): - if op_input.graph is not graph: - raise ValueError("Operation %s is not from the passed-in graph" - % op_input) - elif isinstance(op_input, Tensor): - if op_input.graph is not graph: - raise ValueError("Tensor %s is not from the passed-in graph" - % op_input) - return graph - - # 2. Otherwise, we attempt to select a graph from one of the Operation- - # or Tensor-valued inputs. - original_input = None + if graph and not isinstance(graph, Graph): + raise TypeError("Input graph needs to be a Graph: %s" % graph) + + # 1. We validate that all of the inputs are from the same graph. This is + # either the supplied graph parameter, or the first one selected from one + # the graph-element-valued inputs. In the latter case, we hold onto + # that input in original_graph_element so we can provide a more + # informative error if a mismatch is found. + original_graph_element = None for op_input in op_input_list: - if isinstance(op_input, (Operation, Tensor)): - if original_input is None: - original_input = op_input - else: - assert_same_graph([original_input, op_input]) - if original_input is not None: - return original_input.graph + # Determine if this is a valid graph_element. + graph_element = None + if isinstance(op_input, (Operation, Tensor, SparseTensor, IndexedSlices)): + graph_element = op_input + else: + graph_element = _as_graph_element(op_input) - # 3. If all else fails, we use the default graph, which is always there. - return get_default_graph() + if graph_element: + if not graph: + original_graph_element = graph_element + graph = graph_element.graph + elif original_graph_element: + _assert_same_graph(original_graph_element, graph_element) + elif graph_element.graph is not graph: + raise ValueError( + "%s is not from the passed-in graph." % graph_element) + + # 2. If all else fails, we use the default graph, which is always there. + return graph or get_default_graph() class GraphKeys(object): @@ -3115,7 +3129,7 @@ def get_collection(key, scope=None): # pylint: disable=g-doc-return-or-yield @contextlib.contextmanager -def op_scope(values, name, default_name): +def op_scope(values, name, default_name=None): """Returns a context manager for use when defining a Python op. This context manager validates that the given `values` are from the @@ -3140,10 +3154,17 @@ def op_scope(values, name, default_name): default_name: The default name to use if the `name` argument is `None`. Returns: - A context manager for use in defining a Python op. + A context manager for use in defining Python ops. Yields the name scope. + + Raises: + ValueError: if neither `name` nor `default_name` is provided. """ g = _get_graph_from_inputs(values) n = default_name if name is None else name + if n is None: + raise ValueError( + "At least one of name (%s) and default_name (%s) must be provided." % ( + name, default_name)) with g.as_default(), g.name_scope(n) as scope: yield scope # pylint: enable=g-doc-return-or-yield |