diff options
Diffstat (limited to 'tensorflow/python/framework/ops.py')
-rw-r--r-- | tensorflow/python/framework/ops.py | 2985 |
1 files changed, 2985 insertions, 0 deletions
def _convert_stack(stack):
  """Turns the output of _extract_stack() into traceback-style entries.

  Args:
    stack: A list of n 4-tuples, (filename, lineno, name, frame_globals).

  Returns:
    A list of n 4-tuples (filename, lineno, name, code). `code` is the
    stripped source text that linecache returns for (filename, lineno), or
    None when linecache returns no text for that location.
  """
  converted = []
  for filename, lineno, name, frame_globals in stack:
    # Let linecache validate its cached copy of the file before reading it.
    linecache.checkcache(filename)
    source_line = linecache.getline(filename, lineno, frame_globals)
    code = source_line.strip() if source_line else None
    converted.append((filename, lineno, name, code))
  return converted
class Tensor(object):
  """Represents a value produced by an `Operation`.

  A `Tensor` is a symbolic handle to one of the outputs of an
  `Operation`. It does not hold the values of that operation's output,
  but instead provides a means of computing those values in a
  TensorFlow [`Session`](client.md#Session).

  This class has two primary purposes:

  1. A `Tensor` can be passed as an input to another `Operation`.
     This builds a dataflow connection between operations, which
     enables TensorFlow to execute an entire `Graph` that represents a
     large, multi-step computation.

  2. After the graph has been launched in a session, the value of the
     `Tensor` can be computed by passing it to
     [`Session.run()`](client.md#Session.run).
     `t.eval()` is a shortcut for calling
     `tf.get_default_session().run(t)`.

  In the following example, `c`, `d`, and `e` are symbolic `Tensor`
  objects, whereas `result` is a numpy array that stores a concrete
  value:

  ```python
  # Build a dataflow graph.
  c = tf.constant([[1.0, 2.0], [3.0, 4.0]])
  d = tf.constant([[1.0, 1.0], [0.0, 1.0]])
  e = tf.matmul(c, d)

  # Construct a `Session` to execute the graph.
  sess = tf.Session()

  # Execute the graph and store the value that `e` represents in `result`.
  result = sess.run(e)
  ```

  @@dtype
  @@name
  @@value_index
  @@graph
  @@op
  @@consumers

  @@eval

  @@get_shape
  @@set_shape

  """

  # Names of the Python special methods that _override_operator() is willing
  # to install on this class.
  OVERLOADABLE_OPERATORS = {
      # Binary.
      "__add__", "__radd__",
      "__sub__", "__rsub__",
      "__mul__", "__rmul__",
      "__div__", "__rdiv__",
      "__truediv__", "__rtruediv__",
      "__mod__", "__rmod__",
      "__lt__", "__le__",
      "__gt__", "__ge__",
      "__and__", "__rand__",
      "__or__", "__ror__",
      "__xor__", "__rxor__",
      "__getitem__",
      # Unary.
      "__invert__",
      "__neg__", "__abs__"}

  def __init__(self, op, value_index, dtype):
    """Creates a new `Tensor`.

    Args:
      op: An `Operation`. `Operation` that computes this tensor.
      value_index: An `int`. Index of the operation's endpoint that produces
        this tensor.
      dtype: A `types.DType`. Type of data stored in this tensor.

    Raises:
      TypeError: If the op is not an `Operation`.
    """
    if not isinstance(op, Operation):
      raise TypeError("op needs to be an Operation: %s" % op)
    self._op = op
    self._value_index = value_index
    self._dtype = types.as_dtype(dtype)
    # The shape starts out unknown and is refined through set_shape().
    self._shape = tensor_shape.unknown_shape()
    # Operations that take this Tensor as an input. Kept so that the
    # computation graph can be navigated from producer to consumers.
    self._consumers = []

  @property
  def op(self):
    """The `Operation` that produces this tensor as an output."""
    return self._op

  @property
  def dtype(self):
    """The `DType` of elements in this tensor."""
    return self._dtype

  @property
  def graph(self):
    """The `Graph` that contains this tensor."""
    return self._op.graph

  @property
  def name(self):
    """The string name of this tensor."""
    op_name = self._op.name
    if not op_name:
      raise ValueError("Operation was not named: %s" % self._op)
    return "%s:%d" % (op_name, self._value_index)

  @property
  def device(self):
    """The name of the device on which this tensor will be produced, or None."""
    return self._op.device

  def _shape_as_list(self):
    """Returns the dimension values as a list, or None if the rank is unknown."""
    if self._shape.ndims is None:
      return None
    return [dim.value for dim in self._shape.dims]

  def get_shape(self):
    """Returns the `TensorShape` that represents the shape of this tensor.

    The shape is computed using shape inference functions that are
    registered for each `Operation` type using `tf.RegisterShape`.
    See [`TensorShape`](framework.md#TensorShape) for more details of what a shape
    represents.

    The inferred shape of a tensor is used to provide shape
    information without having to launch the graph in a session. This
    can be used for debugging, and providing early error messages. For
    example:

    ```python
    c = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])

    print c.get_shape()
    ==> TensorShape([Dimension(2), Dimension(3)])

    d = tf.constant([[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]])

    print d.get_shape()
    ==> TensorShape([Dimension(4), Dimension(2)])

    # Raises a ValueError, because `c` and `d` do not have compatible
    # inner dimensions.
    e = tf.matmul(c, d)

    f = tf.matmul(c, d, transpose_a=True, transpose_b=True)

    print f.get_shape()
    ==> TensorShape([Dimension(3), Dimension(4)])
    ```

    In some cases, the inferred shape may have unknown dimensions. If
    the caller has additional information about the values of these
    dimensions, `Tensor.set_shape()` can be used to augment the
    inferred shape.

    Returns:
      A `TensorShape` representing the shape of this tensor.
    """
    return self._shape

  def set_shape(self, shape):
    """Updates the shape of this tensor.

    This method can be called multiple times, and will merge the given
    `shape` with the current shape of this tensor. It can be used to
    provide additional information about the shape of this tensor that
    cannot be inferred from the graph alone. For example, this can be used
    to provide additional information about the shapes of images:

    ```python
    _, image_data = tf.TFRecordReader(...).read(...)
    image = tf.image.decode_png(image_data, channels=3)

    # The height and width dimensions of `image` are data dependent, and
    # cannot be computed without executing the op.
    print image.get_shape()
    ==> TensorShape([Dimension(None), Dimension(None), Dimension(3)])

    # We know that each image in this dataset is 28 x 28 pixels.
    image.set_shape([28, 28, 3])
    print image.get_shape()
    ==> TensorShape([Dimension(28), Dimension(28), Dimension(3)])
    ```

    Args:
      shape: A `TensorShape` representing the shape of this tensor.

    Raises:
      ValueError: If `shape` is not compatible with the current shape of
        this tensor.
    """
    merged_shape = self._shape.merge_with(shape)
    self._shape = merged_shape

  @property
  def value_index(self):
    """The index of this tensor in the outputs of its `Operation`."""
    return self._value_index

  def consumers(self):
    """Returns a list of `Operation`s that consume this tensor.

    Returns:
      A list of `Operation`s. This is the tensor's own internal list, not a
      copy.
    """
    return self._consumers

  def _add_consumer(self, consumer):
    """Add a consumer to this tensor.

    Args:
      consumer: an Operation.

    Raises:
      TypeError: if the consumer is not an Operation.
    """
    if not isinstance(consumer, Operation):
      raise TypeError("Consumer must be an Operation: %s" % consumer)
    self._consumers.append(consumer)

  def _as_node_def_input(self):
    """Return a value to use for the NodeDef "input" attribute.

    The returned string can be used in a NodeDef "input" attribute
    to indicate that the NodeDef uses this Tensor as input.

    Raises:
      ValueError: if this Tensor's Operation does not have a name.

    Returns:
      a string.
    """
    op_name = self._op.name
    if not op_name:
      raise ValueError("Operation was not named: %s" % self._op)
    # Output 0 is referred to by the bare operation name.
    if self._value_index == 0:
      return op_name
    return "%s:%d" % (op_name, self._value_index)

  def __str__(self):
    tensor_name = self.name
    shape_part = ""
    if self.get_shape().ndims is not None:
      shape_part = ", shape=%s" % self.get_shape()
    dtype_part = (", dtype=%s" % self._dtype.name) if self._dtype else ""
    device_part = (", device=%s" % self.device) if self.device else ""
    return "Tensor(\"%s\"%s%s%s)" % (
        tensor_name, shape_part, dtype_part, device_part)

  def __hash__(self):
    # Identity-based hash, needed for Python's collection membership
    # operators.
    return id(self)

  def __eq__(self, other):
    # Identity-based equality, needed for Python's collection membership
    # operators.
    return id(self) == id(other)

  # NOTE(mrry): This enables the Tensor's overloaded "right" binary
  # operators to run when the left operand is an ndarray, because it
  # accords the Tensor class higher priority than an ndarray, or a
  # numpy matrix.
  # TODO(mrry): Convert this to using numpy's __numpy_ufunc__
  # mechanism, which allows more control over how Tensors interact
  # with ndarrays.
  __array_priority__ = 100

  @staticmethod
  def _override_operator(operator, func):
    """Overrides (string) operator on Tensors to call func.

    Args:
      operator: the string name of the operator to override.
      func: the function that replaces the overridden operator.

    Raises:
      ValueError: If operator has already been overwritten,
        or if operator is not allowed to be overwritten.
    """
    existing = getattr(Tensor, operator, None)
    # An existing attribute is tolerated only when it is a default
    # method-wrapper, which will be true for the comparison operators.
    default_wrapper_type = type(all.__call__)
    if existing is not None and not isinstance(existing, default_wrapper_type):
      raise ValueError("operator %s cannot be overwritten again." % operator)
    if operator not in Tensor.OVERLOADABLE_OPERATORS:
      raise ValueError("Overriding %s is disallowed" % operator)
    setattr(Tensor, operator, func)

  def __iter__(self):
    """Dummy method to prevent iteration. Do not call.

    NOTE(mrry): If we register __getitem__ as an overloaded operator,
    Python will valiantly attempt to iterate over the Tensor from 0 to
    infinity.  Declaring this method prevents this unintended
    behavior.

    Raises:
      TypeError: when invoked.
    """
    raise TypeError("'Tensor' object is not iterable")

  def eval(self, feed_dict=None, session=None):
    """Evaluates this tensor in a `Session`.

    Calling this method will execute all preceding operations that
    produce the inputs needed for the operation that produces this
    tensor.

    *N.B.* Before invoking `Tensor.eval()`, its graph must have been
    launched in a session, and either a default session must be
    available, or `session` must be specified explicitly.

    Args:
      feed_dict: A dictionary that maps `Tensor` objects to feed values.
        See [`Session.run()`](client.md#Session.run) for a description of
        the valid feed values.
      session: (Optional.) The `Session` to be used to evaluate this tensor. If
        none, the default session will be used.

    Returns:
      A numpy array corresponding to the value of this tensor.

    """
    return _eval_using_default_session(self, feed_dict, self.graph, session)
def convert_to_tensor_or_indexed_slices(value, dtype=None, name=None):
  """Converts the given object to a `Tensor` or an `IndexedSlices`.

  If `value` is an `IndexedSlices` it is returned
  unmodified. Otherwise, it is converted to a `Tensor` using
  `convert_to_tensor()`.

  Args:
    value: An `IndexedSlices` or an object that can be consumed by
      `convert_to_tensor()`.
    dtype: (Optional.) The required `DType` of the returned `Tensor` or
      `IndexedSlices`.
    name: (Optional.) A name to use if a new `Tensor` is created.

  Returns:
    A `Tensor` or an `IndexedSlices` based on `value`.

  Raises:
    ValueError: If `value` is an `IndexedSlices` and `dtype` is not
      compatible with its element type.
  """
  if not isinstance(value, IndexedSlices):
    return convert_to_tensor(value, dtype, name)
  if dtype:
    # Use types.as_dtype, the same normalizer the rest of this module uses
    # (the previous spelling here was types.AsDType), and normalize only once.
    required_dtype = types.as_dtype(dtype)
    if not required_dtype.is_compatible_with(value.dtype):
      raise ValueError(
          "Tensor conversion requested dtype %s for Tensor with dtype %s: %r"
          % (required_dtype.name, value.dtype.name, str(value)))
  return value
class IndexedSlices(object):
  """A sparse representation of a set of tensor slices at given indices.

  This class is a simple wrapper for a pair of `Tensor` objects:

  * `values`: A `Tensor` of any dtype with shape `[D0, D1, ..., Dn]`.
  * `indices`: A 1-D integer `Tensor` with shape `[D0]`.

  An `IndexedSlices` is typically used to represent a subset of a larger
  tensor `dense` of shape `[LARGE0, D1, .. , DN]` where `LARGE0 >> D0`.
  The values in `indices` are the indices in the first dimension of
  the slices that have been extracted from the larger tensor.

  The dense tensor `dense` represented by an `IndexedSlices` `slices` has

  ```python
  dense[slices.indices[i], :, :, :, ...] = slices.values[i, :, :, :, ...]
  ```

  The `IndexedSlices` class is used principally in the definition of
  gradients for operations that have sparse gradients
  (e.g. [`tf.gather`](array_ops.md#gather)).

  Contrast this representation with
  [`SparseTensor`](sparse_ops.md#SparseTensor),
  which uses multi-dimensional indices and scalar values.

  @@__init__

  @@values
  @@indices
  @@dense_shape

  @@name
  @@dtype
  @@device
  @@op
  @@graph
  """

  def __init__(self, values, indices, dense_shape=None):
    """Creates an `IndexedSlices`.

    The arguments are stored as given; no conversion or validation is
    performed here.

    Args:
      values: A `Tensor` containing the values of the slices.
      indices: A 1-D `Tensor` containing the indices of the slices.
      dense_shape: (Optional.) A 1-D `Tensor` containing the shape of the
        corresponding dense tensor.
    """
    self._values = values
    self._indices = indices
    self._dense_shape = dense_shape

  @property
  def values(self):
    """A `Tensor` containing the values of the slices."""
    return self._values

  @property
  def indices(self):
    """A 1-D `Tensor` containing the indices of the slices."""
    return self._indices

  @property
  def dense_shape(self):
    """A 1-D `Tensor` containing the shape of the corresponding dense tensor."""
    return self._dense_shape

  @property
  def name(self):
    """The name of this `IndexedSlices`."""
    return self.values.name

  @property
  def device(self):
    """The name of the device on which `values` will be produced, or `None`."""
    return self.values.device

  @property
  def op(self):
    """The `Operation` that produces `values` as an output."""
    return self.values.op

  @property
  def dtype(self):
    """The `DType` of elements in this tensor."""
    return self.values.dtype

  @property
  def graph(self):
    """The `Graph` that contains `values`.

    Provided so that an `IndexedSlices` can be used wherever graph items
    are expected to expose a `graph` attribute.
    """
    return self.values.graph

  def __str__(self):
    return "IndexedSlices(indices=%s, values=%s)" % (
        self._indices, self._values)
class SparseTensor(object):
  """Represents a sparse tensor.

  TensorFlow represents a sparse tensor as three separate dense tensors:
  `indices`, `values`, and `shape`.  In Python, the three tensors are
  collected into a `SparseTensor` class for ease of use.  If you have separate
  `indices`, `values`, and `shape` tensors, wrap them in a `SparseTensor`
  object before passing to the Ops below.

  Concretely, the sparse tensor `SparseTensor(indices, values, shape)` is

  * `indices`: A 2-D int64 tensor of shape `[N, ndims]`.
  * `values`: A 1-D tensor of any type and shape `[N]`.
  * `shape`: A 1-D int64 tensor of shape `[ndims]`.

  where `N` and `ndims` are the number of values, and number of dimensions in
  the `SparseTensor` respectively.

  The corresponding dense tensor satisfies

  ```python
  dense.shape = shape
  dense[tuple(indices[i])] = values[i]
  ```

  By convention, `indices` should be sorted in row-major order (or equivalently
  lexicographic order on the tuples `indices[i]`).  This is not enforced when
  `SparseTensor` objects are constructed, but most Ops assume correct ordering.
  If the ordering is wrong, it can be fixed by calling `sparse_reorder` on the
  misordered `SparseTensor`.

  Example: The sparse tensor

  ```python
    SparseTensor(values=[1, 2], indices=[[0, 0], [1, 2]], shape=[3, 4])
  ```

  represents the dense tensor

  ```python
    [[1, 0, 0, 0]
     [0, 0, 2, 0]
     [0, 0, 0, 0]]
  ```

  @@__init__
  @@indices
  @@values
  @@dtype
  @@shape
  @@graph
  """

  def __init__(self, indices, values, shape):
    """Creates a `SparseTensor`.

    Each argument is passed through `convert_to_tensor()` before being
    stored.

    Args:
      indices: A 2-D int64 tensor of shape `[N, ndims]`.
      values: A 1-D tensor of any type and shape `[N]`.
      shape: A 1-D int64 tensor of shape `[ndims]`.

    Returns:
      A `SparseTensor`
    """
    with op_scope([indices, values, shape], None, "SparseTensor"):
      indices_t = convert_to_tensor(indices, name="indices")
      values_t = convert_to_tensor(values, name="values")
      shape_t = convert_to_tensor(shape, name="shape")
    self._indices = indices_t
    self._values = values_t
    self._shape = shape_t

    # Statically known shapes of the three component tensors, constrained to
    # the expected ranks (2, 1 and 1 respectively).
    indices_static = indices_t.get_shape().with_rank(2)
    values_static = values_t.get_shape().with_rank(1)
    shape_static = shape_t.get_shape().with_rank(1)

    # Assert number of rows in indices match the number of elements in values.
    indices_static[0].merge_with(values_static[0])
    # Assert number of columns in indices matches the number of elements in
    # shape.
    indices_static[1].merge_with(shape_static[0])

  @property
  def indices(self):
    """The indices of non-zero values in the represented dense tensor.

    Returns:
      A 2-D Tensor of int64 with shape `[N, ndims]`, where `N` is the
        number of non-zero values in the tensor, and `ndims` is the rank.
    """
    return self._indices

  @property
  def values(self):
    """The non-zero values in the represented dense tensor.

    Returns:
      A 1-D Tensor of any data type.
    """
    return self._values

  @property
  def dtype(self):
    """The `DType` of elements in this tensor."""
    return self._values.dtype

  @property
  def shape(self):
    """A 1-D Tensor of int64 representing the shape of the dense tensor."""
    return self._shape

  @property
  def graph(self):
    """The `Graph` that contains the index, value, and shape tensors."""
    return self._indices.graph

  def __str__(self):
    components = (self._indices, self._values, self._shape)
    return "SparseTensor(indices=%s, values=%s, shape=%s)" % components
+ + An `Operation` is a node in a TensorFlow `Graph` that takes zero or + more `Tensor` objects as input, and produces zero or more `Tensor` + objects as output. Objects of type `Operation` are created by + calling a Python op constructor (such as [`tf.matmul()`](math_ops.md#matmul)) + or [`Graph.create_op()`](framework.md#Graph.create_op). + + For example `c = tf.matmul(a, b)` creates an `Operation` of type + "MatMul" that takes tensors `a` and `b` as input, and produces `c` + as output. + + After the graph has been launched in a session, an `Operation` can + be executed by passing it to [`Session.run()`](client.md#Session.run). + `op.run()` is a shortcut for calling `tf.get_default_session().run(op)`. + + @@name + @@type + @@inputs + @@control_inputs + @@outputs + @@device + @@graph + + @@run + + @@get_attr + @@traceback + """ + + def __init__(self, node_def, g, inputs=None, output_types=None, + control_inputs=None, input_types=None, original_op=None, + op_def=None): + """Creates an `Operation`. + + NOTE: This constructor validates the name of the Operation (passed + as "node_def.name"). Valid Operation names match the following + regular expression: + + [A-Za-z0-9.][A-Za-z0-9_.\\-/]* + + Args: + node_def: graph_pb2.NodeDef. NodeDef for the Operation. + Used for attributes of graph_pb2.NodeDef, typically "name", + "op", and "device". The "input" attribute is irrelevant here + as it will be computed when generating the model. + g: Graph. The parent graph. + inputs: list of Tensor objects. The inputs to this Operation. + output_types: list of types_pb2.DataType. List of the types of the + Tensors computed by this operation. The length of this list indicates + the number of output endpoints of the Operation. + control_inputs: list of operations or tensors from which to have a + control dependency. + input_types: List of types_pb2.DataType representing the + types of the Tensors accepted by the Operation. By default + uses [x.dtype.base_dtype for x in inputs]. 
Operations that expect + reference-typed inputs must specify these explicitly. + original_op: Optional. Used to associate the new Operation with an + existing Operation (for example, a replica with the op that was + replicated). + op_def: Optional. The op_def_pb2.OpDef proto that describes the + op type that this Operation represents. + + Raises: + TypeError: if control inputs are not Operations or Tensors, + or if node_def is not a NodeDef, + or if g is not a Graph, + or if inputs are not Tensors, + or if inputs and input_types are incompatible. + ValueError: if the node_def name is not valid. + """ + if not isinstance(node_def, graph_pb2.NodeDef): + raise TypeError("node_def needs to be a NodeDef: %s" % node_def) + if node_def.ByteSize() >= (1 << 31) or node_def.ByteSize() < 0: + raise ValueError( + "Cannot create an Operation with a NodeDef larger than 2GB.") + if not _VALID_OP_NAME_REGEX.match(node_def.name): + raise ValueError("'%s' is not a valid node name" % node_def.name) + if not isinstance(g, Graph): + raise TypeError("g needs to be a Graph: %s" % g) + self._node_def = copy.deepcopy(node_def) + self._graph = g + if inputs is None: + inputs = [] + self._inputs = inputs + for a in self._inputs: + if not isinstance(a, Tensor): + raise TypeError("input needs to be a Tensor: %s" % a) + # Mark that we consume the inputs. + a._add_consumer(self) # pylint: disable=protected-access + if output_types is None: + output_types = [] + self._output_types = output_types + self._outputs = [Tensor(self, i, output_types[i]) + for i in xrange(len(output_types))] + if input_types is None: + input_types = [i.dtype.base_dtype for i in self._inputs] + else: + if not all(x.is_compatible_with(i.dtype) + for i, x in zip(self._inputs, input_types)): + raise TypeError("Inputs are not compatible with input types") + self._input_types = input_types + + # Build the list of control inputs. 
+ self._control_inputs = [] + if control_inputs: + for c in control_inputs: + c_op = None + if isinstance(c, Operation): + c_op = c + elif isinstance(c, (Tensor, IndexedSlices)): + c_op = c.op + else: + raise TypeError("Control input must be an Operation, " + "a Tensor, or IndexedSlices: %s" % c) + self._control_inputs.append(c_op) + + self._original_op = original_op + self._op_def = op_def + self._traceback = _extract_stack() + # Add this op to the current control flow context: + self._control_flow_context = g._get_control_flow_context() + if g._get_control_flow_context() is not None: + g._get_control_flow_context().AddOp(self) + # NOTE(keveman): Control flow context's AddOp could be creating new ops and + # setting op.inputs[index] = new_op. Thus the new ops' id could be larger + # than this op's id even though this op depend on them. Therefore, delaying + # assigning id to this op until all ops this could be dependent on are + # created. + self._id_value = self._graph._next_id() # pylint: disable=protected-access + self._recompute_node_def() + + def values(self): + """DEPRECATED: Use outputs.""" + return tuple(self.outputs) + + def _get_control_flow_context(self): + """Returns the current control flow context. + + Returns: + A context object. + """ + return self._control_flow_context + + @property + def name(self): + """The full name of this operation.""" + return self._node_def.name + + @property + def _id(self): + """The unique integer id of this operation.""" + return self._id_value + + @property + def device(self): + """The name of the device to which this op has been assigned, if any. + + Returns: + The string name of the device to which this op has been + assigned, or None if it has not been assigned to a device. + """ + dev = self._node_def.device + return None if not dev else dev + + def _set_device(self, device): + """Set the device of this operation. + + Args: + device: string or device.. The device to set. 
def _add_input(self, tensor, dtype=None):
  """Appends `tensor` to this operation's data inputs.

  Args:
    tensor: the Tensor to add as an input.
    dtype: types.DType: type of the input; defaults to
      the tensor's dtype.

  Raises:
    TypeError: if tensor is not a Tensor,
      or if input tensor type is not convertible to dtype.
    ValueError: if the Tensor is from a different graph.
  """
  if not isinstance(tensor, Tensor):
    raise TypeError("tensor must be a Tensor: %s" % tensor)
  assert_same_graph([self, tensor])
  if dtype is not None:
    # An explicit dtype must be compatible with what the tensor produces.
    input_dtype = types.as_dtype(dtype)
    if not input_dtype.is_compatible_with(tensor.dtype):
      raise TypeError(
          "Cannot convert a tensor of type %s to an input of type %s"
          % (tensor.dtype.name, input_dtype.name))
  else:
    input_dtype = tensor.dtype
  self._inputs.append(tensor)
  self._input_types.append(input_dtype)
  # Register this op as a consumer of the tensor, then refresh the NodeDef.
  tensor._add_consumer(self)  # pylint: disable=protected-access
  self._recompute_node_def()
def _add_control_input(self, op):
  """Registers `op` as an additional control input of this operation.

  Args:
    op: the Operation to add as control input.

  Raises:
    TypeError: if op is not an Operation.
    ValueError: if op is from a different graph.
  """
  if isinstance(op, Operation):
    assert_same_graph([self, op])
    self._control_inputs.append(op)
    # Keep the serialized NodeDef inputs in sync with the new dependency.
    self._recompute_node_def()
  else:
    raise TypeError("op must be an Operation: %s" % op)
def _recompute_node_def(self):
  """Rebuilds the `input` field of this op's `NodeDef`.

  The field is cleared, then filled with the data inputs (in order) followed
  by the control inputs, each written as "^<op name>".
  """
  node_inputs = self._node_def.input
  del node_inputs[:]
  # pylint: disable=protected-access
  data_inputs = [tensor._as_node_def_input() for tensor in self._inputs]
  # pylint: enable=protected-access
  node_inputs.extend(data_inputs)
  if self._control_inputs:
    control_refs = ["^%s" % control_op.name
                    for control_op in self._control_inputs]
    node_inputs.extend(control_refs)
def get_attr(self, name):
  """Returns the value of the attr of this op with the given `name`.

  Args:
    name: The name of the attr to fetch.

  Returns:
    The value of the attr, as a Python object. List-valued attrs are
    returned as a Python list; an attr with no value set, or a list attr
    with no populated field, yields `[]`.

  Raises:
    ValueError: If this op does not have an attr with the given `name`.
    AssertionError: If the attr holds a value that is none of
      s/i/f/b/type/shape/tensor (or a list of those).
  """
  fields = ["s", "i", "f", "b", "type", "shape", "tensor"]
  if name not in self._node_def.attr:
    raise ValueError("No attr named '" + name + "' in " +
                     str(self._node_def))
  x = self._node_def.attr[name]
  # Treat an empty oneof value as an empty list.
  if not x.WhichOneof("value"):
    return []
  if x.HasField("list"):
    # Return the first populated repeated field of the list value.
    for f in fields:
      if getattr(x.list, f):
        return list(getattr(x.list, f))
    return []
  for f in fields:
    if x.HasField(f):
      return getattr(x, f)
  # Raised explicitly instead of `assert False`: asserts are stripped under
  # `python -O`, which would make this method silently return None.
  raise AssertionError("Unsupported field type in " + str(x))
class RegisterGradient(object):
  """A decorator for registering the gradient function for an op type.

  This decorator is only used when defining a new op type. For an op
  with `m` inputs and `n` outputs, the gradient function is a function
  that takes the original `Operation` and `n` `Tensor` objects
  (representing the gradients with respect to each output of the op),
  and returns `m` `Tensor` objects (representing the partial gradients
  with respect to each input of the op).

  For example, assuming that operations of type `"Sub"` take two
  inputs `x` and `y`, and return a single output `x - y`, the
  following gradient function would be registered:

  ```python
  @tf.RegisterGradient("Sub")
  def _sub_grad(unused_op, grad):
    return grad, tf.Neg(grad)
  ```

  The decorator argument `op_type` is the string type of an
  operation. This corresponds to the `OpDef.name` field for the proto
  that defines the operation.

  @@__init__
  """

  def __init__(self, op_type):
    """Creates a new decorator with `op_type` as the Operation type.

    Args:
      op_type: The string type of an operation. This corresponds to the
        `OpDef.name` field for the proto that defines the operation.

    Raises:
      TypeError: If `op_type` is not a string.
    """
    if not isinstance(op_type, basestring):
      raise TypeError("op_type must be a string")
    self._op_type = op_type

  def __call__(self, f):
    """Registers the function `f` as gradient function for `op_type`.

    Returns `f` unchanged so the decorated name still refers to it.
    """
    _gradient_registry.register(f, self._op_type)
    return f
def get_gradient_function(op):
  """Looks up the registered gradient function for `op`.

  An op without inputs yields None. Otherwise the lookup key is the op's
  "_gradient_op_type" attr when it has one, and `op.type` when it does not.
  """
  if op.inputs:
    try:
      gradient_op_type = op.get_attr("_gradient_op_type")
    except ValueError:
      # No override attr on this op: fall back to its own type.
      gradient_op_type = op.type
    return _gradient_registry.lookup(gradient_op_type)
  return None
def set_shapes_for_outputs(op):
  """Uses the registered shape functions to set the shapes for op's outputs.

  The shape function registered for `op.type` is preferred; the default
  shape-function registry is consulted only when that lookup fails.

  Args:
    op: The operation whose outputs should have their shapes set.

  Raises:
    RuntimeError: If no shape function is registered for `op.type` in either
      registry, or if the shape function returns a number of shapes that
      differs from the number of outputs of `op`.
  """
  try:
    shape_func = _shape_registry.lookup(op.type)
  except LookupError:
    try:
      shape_func = _default_shape_function_registry.lookup(op.type)
    except LookupError:
      raise RuntimeError("No shape function registered for standard op: %s"
                         % op.type)
  shapes = shape_func(op)
  if len(op.outputs) != len(shapes):
    # "returned" is what the shape function produced; "expecting" is the
    # number of outputs the op actually has.
    raise RuntimeError(
        "Shape function for op %s returned %d shapes but expecting %d" %
        (op, len(shapes), len(op.outputs)))
  for output, s in zip(op.outputs, shapes):
    output.set_shape(s)
To add an + operation to the default graph, simply call one of the functions that defines + a new `Operation`: + + ``` + c = tf.constant(4.0) + assert c.graph is tf.get_default_graph() + ``` + + Another typical usage involves the + [`Graph.as_default()`](framework.md#Graph.as_default) + context manager, which overrides the current default graph for the + lifetime of the context: + + ```python + g = tf.Graph() + with g.as_default(): + # Define operations and tensors in `g`. + c = tf.constant(30.0) + assert c.graph is g + ``` + + Important note: This class *is not* thread-safe for graph construction. All + operations should be created from a single thread, or external + synchronization must be provided. Unless otherwise specified, all methods + are not thread-safe. + + @@__init__ + @@as_default + @@as_graph_def + @@finalize + @@finalized + + @@control_dependencies + @@device + @@name_scope + + A `Graph` instance supports an arbitrary number of "collections" + that are identified by name. For convenience when building a large + graph, collections can store groups of related objects: for + example, the `tf.Variable` uses a collection (named + [`tf.GraphKeys.VARIABLES`](framework.md#GraphKeys)) for all variables that are + created during the construction of a graph. The caller may define + additional collections by specifying a new name. + + @@add_to_collection + @@get_collection + + @@as_graph_element + @@get_operation_by_name + @@get_tensor_by_name + @@get_operations + + @@get_default_device + @@seed + @@unique_name + @@version + + @@create_op + @@gradient_override_map + """ + + def __init__(self): + """Creates a new, empty Graph.""" + self._nodes_by_id = dict() + self._next_node_id = [dict()] + self._next_id_counter = 0 + self._nodes_by_name = dict() + # Current name stack: a pair of uniquified names and plain names. + self._name_stack = ("", "") + # Maps a name used in the graph to the next id to use for that name. 
def _check_not_finalized(self):
  """Guards mutating methods against use on a finalized graph.

  Raises:
    RuntimeError: If the graph is marked as finalized.
  """
  if not self._finalized:
    return
  raise RuntimeError("Graph is finalized and cannot be modified.")
def as_graph_def(self, from_version=None):
  """Returns a serialized `GraphDef` representation of this graph.

  This method is thread-safe.

  Args:
    from_version: Optional. If this is set, returns a `GraphDef`
      containing only the nodes that were added to this graph since
      its `version` property had the given value.

  Returns:
    A
    [`GraphDef`](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/graph.proto)
    protocol buffer.

  Raises:
    ValueError: If the accumulated size of the emitted nodes reaches 2GB.
  """
  graph_def = graph_pb2.GraphDef()
  total_bytes = 0
  # Emit nodes in increasing id order.
  for op_id in sorted(self._nodes_by_id):
    if from_version is not None and not op_id > from_version:
      continue
    node_def = self._nodes_by_id[op_id].node_def
    graph_def.node.extend([node_def])
    total_bytes += node_def.ByteSize()
    if total_bytes < 0 or total_bytes >= (1 << 31):
      raise ValueError("GraphDef cannot be larger than 2GB.")
  return graph_def
+ + Raises: + TypeError: if any of the inputs is not a `Tensor`. + + Returns: + An `Operation` object. + + """ + self._check_not_finalized() + for idx, a in enumerate(inputs): + if not isinstance(a, Tensor): + raise TypeError("Input #%d is not a tensor: %s" % (idx, a)) + if name is None: + name = op_type + # If a names ends with a '/' it is a "name scope" and we use it as-is, + # after removing the trailing '/'. + if name and name[-1] == "/": + name = name[:-1] + else: + name = self.unique_name(name) + + node_def = _NodeDef( + op_type, name, device=self._default_device or None, attrs=attrs) + + # Apply a kernel label if one has been specified for this op_type. + try: + kernel_label = self._op_to_kernel_label_map[op_type] + node_def.attr["_kernel"].CopyFrom( + attr_value_pb2.AttrValue(s=kernel_label)) + except KeyError: + pass + + # Apply the overriding op_type for gradients if one has been + # specified for this op_type. + try: + mapped_op_type = self._gradient_override_map[op_type] + node_def.attr["_gradient_op_type"].CopyFrom( + attr_value_pb2.AttrValue(s=mapped_op_type)) + except KeyError: + pass + + control_inputs = self._control_dependencies_for_inputs(inputs) + ret = Operation(node_def, self, inputs=inputs, output_types=dtypes, + control_inputs=control_inputs, input_types=input_types, + original_op=self._default_original_op, op_def=op_def) + if compute_shapes: + set_shapes_for_outputs(ret) + self._add_op(ret) + self._record_op_seen_by_control_dependencies(ret) + # Apply any device functions in reverse order, so that the most recently + # pushed function has the first chance to apply a device to the op. + # We apply here because the result can depend on the Operation's + # signature, which is computed in the Operation constructor. 
def as_graph_element(self, obj, allow_tensor=True, allow_operation=True):
  """Returns the object referred to by `obj`, as an `Operation` or `Tensor`.

  This function validates that `obj` represents an element of this
  graph, and gives an informative error message if it is not.

  This function is the canonical way to get/validate an object of
  one of the allowed types from an external argument reference in the
  Session API.

  This method may be called concurrently from multiple threads.

  Args:
    obj: A `Tensor`, an `Operation`, or the name of a tensor or operation.
      Can also be any object with an `_as_graph_element()` method that returns
      a value of one of these types.
    allow_tensor: If true, `obj` may refer to a `Tensor`.
    allow_operation: If true, `obj` may refer to an `Operation`.

  Returns:
    The `Tensor` or `Operation` in the Graph corresponding to `obj`.

  Raises:
    TypeError: If `obj` is not a type we support attempting to convert
      to types.
    ValueError: If `obj` is of an appropriate type but invalid. For
      example, an invalid string. Also raised when both `allow_tensor`
      and `allow_operation` are False.
    KeyError: If `obj` is not an object in the graph.
  """

  # The vast majority of this function is figuring
  # out what an API user might be doing wrong, so
  # that we can give helpful error messages.
  #
  # Ideally, it would be nice to split it up, but we
  # need context to generate nice error messages.

  if allow_tensor and allow_operation:
    types_str = "Tensor or Operation"
  elif allow_tensor:
    types_str = "Tensor"
  elif allow_operation:
    types_str = "Operation"
  else:
    raise ValueError("allow_tensor and allow_operation can't both be False.")

  conv_fn = getattr(obj, "_as_graph_element", None)
  if conv_fn and callable(conv_fn):
    obj = conv_fn()

  # If obj appears to be a name...
  if isinstance(obj, basestring):
    name = obj

    if ":" in name and allow_tensor:
      # Looks like a Tensor name and can be a Tensor.
      try:
        op_name, out_n = name.split(":")
        out_n = int(out_n)
      except ValueError:
        # Raised by the 2-way unpacking (wrong number of ":" separators) or
        # by int() (non-numeric output index). Anything else, such as
        # KeyboardInterrupt, must propagate untouched.
        raise ValueError("The name %s looks a like a Tensor name, but is "
                         "not a valid one. Tensor names must be of the "
                         "form \"<op_name>:<output_index>\"." % repr(name))
      if op_name in self._nodes_by_name:
        op = self._nodes_by_name[op_name]
      else:
        raise KeyError("The name %s refers to a Tensor which does not "
                       "exist. The operation, %s, does not exist in the "
                       "graph." % (repr(name), repr(op_name)))
      try:
        return op.outputs[out_n]
      except IndexError:
        # Only an out-of-range output index means "no such tensor".
        raise KeyError("The name %s refers to a Tensor which does not "
                       "exist. The operation, %s, exists but only has "
                       "%s outputs."
                       % (repr(name), repr(op_name), len(op.outputs)))

    elif ":" in name and not allow_tensor:
      # Looks like a Tensor name but can't be a Tensor.
      raise ValueError("Name %s appears to refer to a Tensor, not a %s."
                       % (repr(name), types_str))

    elif ":" not in name and allow_operation:
      # Looks like an Operation name and can be an Operation.
      if name not in self._nodes_by_name:
        raise KeyError("The name %s refers to an Operation not in the "
                       "graph." % repr(name))
      return self._nodes_by_name[name]

    elif ":" not in name and not allow_operation:
      # Looks like an Operation name but can't be an Operation.
      if name in self._nodes_by_name:
        # Yep, it's an Operation name
        err_msg = ("The name %s refers to an Operation, not a %s."
                   % (repr(name), types_str))
      else:
        err_msg = ("The name %s looks like an (invalid) Operation name, "
                   "not a %s." % (repr(name), types_str))
      err_msg += (" Tensor names must be of the form "
                  "\"<op_name>:<output_index>\".")
      raise ValueError(err_msg)

  elif isinstance(obj, Tensor) and allow_tensor:
    # Actually obj is just the object it's referring to.
    return obj
  elif isinstance(obj, Operation) and allow_operation:
    # Actually obj is just the object it's referring to.
    return obj
  else:
    # We give up!
    raise TypeError("Can not convert a %s into a %s."
                    % (type(obj).__name__, types_str))
def get_collection(self, name, scope=None):
  """Returns a list of values in the collection with the given `name`.

  Args:
    name: The key for the collection. For example, the `GraphKeys` class
      contains many standard names for collections.
    scope: (Optional.) If supplied, the resulting list is filtered to include
      only items whose name begins with this string. Items that have no
      `name` attribute are left out of a filtered result.

  Returns:
    The list of values in the collection with the given `name`, or
    an empty list if no value has been added to that collection. The
    list contains the values in the order under which they were
    collected.
  """
  values = self._collections.get(name, list())
  if scope is None:
    # NOTE: for an existing collection this is the graph's own list, not a
    # copy; callers that mutate it change the collection.
    return values
  return [item for item in values
          if hasattr(item, 'name') and item.name.startswith(scope)]
+ * A scope previously captured from a `with g.name_scope(...) as + scope:` statement will be treated as an "absolute" name scope, which + makes it possible to re-enter existing scopes. + * A value of `None` or the empty string will reset the current name scope + to the top-level (empty) name scope. + + For example: + + ```python + with tf.Graph().as_default() as g: + c = tf.constant(5.0, name="c") + assert c_1.name == "c" + c_1 = tf.constant(6.0, name="c") + assert c_1.name == "c_1" + + # Creates a scope called "nested" + with g.name_scope("nested") as scope: + nested_c = tf.constant(10.0, name="c") + assert nested_c.name == "nested/c" + + # Creates a nested scope called "inner". + with g.name_scope("inner"): + nested_inner_c = tf.constant(20.0, name="c") + assert nested_inner_c.name == "nested/inner/c" + + # Create a nested scope called "inner_1". + with g.name_scope("inner"): + nested_inner_1_c = tf.constant(30.0, name="c") + assert nested_inner_1_c.name == "nested/inner_1/c" + + # Treats `scope` as an absolute name scope, and + # switches to the "nested/" scope. + with g.name_scope(scope): + nested_d = tf.constant(40.0, name="d") + assert nested_d.name == "nested/d" + + with g.name_scope(""): + e = tf.constant(50.0, name="e") + assert e.name == "e" + ``` + + The name of the scope itself can be captured by `with + g.name_scope(...) as scope:`, which stores the name of the scope + in the variable `scope`. This value can be used to name an + operation that represents the overall result of executing the ops + in a scope. For example: + + ```python + inputs = tf.constant(...) + with g.name_scope('my_layer') as scope: + weights = tf.Variable(..., name="weights") + biases = tf.Variable(..., name="biases") + affine = tf.matmul(inputs, weights) + biases + output = tf.nn.relu(affine, name=scope) + ``` + + + Args: + name: A name for the scope. + + Returns: + A context manager that installs `name` as a new name scope. 
+ """ + try: + old_stack = self._name_stack + if not name: # Both for name=None nad name="" we re-set to empty scope. + new_stack = (None, None) + elif name and name[-1] == "/": + new_stack = (name[:-1], name[:-1]) + else: + new_stack = (self.unique_name(name), self._plain_name(name)) + self._name_stack = new_stack + yield "" if new_stack[0] is None else new_stack[0] + "/" + finally: + self._name_stack = old_stack + # pylint: enable=g-doc-return-or-yield + + def unique_name(self, name): + """Return a unique Operation name for "name". + + Note: You rarely need to call unique_name() directly. Most of the time you + just need to create "with g.name_scope()" blocks to generate structured + names. + + `unique_name` is used to generate structured names, separated by "/", + to help identify Operations when debugging a Graph. Operation names + are displayed in error messages reported by the TensorFlow runtime, + and in various visualization tools such as TensorBoard. + + Args: + name: The name for an `Operation`. + + Returns: + A string to be passed to `create_op()` that will be used + to name the operation being created. + """ + if self._name_stack[0]: + name = self._name_stack[0] + "/" + name + i = self._names_in_use.get(name, 0) + # Increment the number for "name". + self._names_in_use[name] = i + 1 + if i > 0: + base_name = name + # Make sure the composed name is not already used. + while name in self._names_in_use: + name = "%s_%d" % (base_name, i) + i += 1 + # Mark the composed name as used in case someone wants + # to call unique_name("name_1"). + self._names_in_use[name] = 1 + return name + + # TODO(mdevin): remove + def _plain_name(self, name): + """Return the fully scoped 'name'. + + Args: + name: a string. + + Returns: + 'name' scoped in the current name stack, without any uniquified + elements. 
# TODO(mdevin): remove
def _plain_name(self, name):
  """Return the fully scoped 'name'.

  Args:
    name: a string.

  Returns:
    'name' scoped in the current name stack, without any uniquified
    elements.
  """
  plain_prefix = self._name_stack[1]
  return plain_prefix + "/" + name if plain_prefix else name


def _set_default_device(self, dev):
  """Set the default device properties.

  Args:
    dev: string or Device.
  """
  device_string = _device_string(dev)
  self._default_device = device_string


def get_default_device(self):
  """Returns the default device.

  Returns:
    A string.
  """
  return self._default_device


def _push_default_device_function(self, device_function):
  """Pushes the given function onto the stack of device functions.

  See Graph.device for more details.

  Args:
    device_function: The function to be pushed onto the stack of device
      functions.
  """
  self._device_function_stack.append(device_function)


def _pop_default_device_function(self, device_function):
  """Pops the given function from the stack of device functions.

  See Graph.device for more details.

  Args:
    device_function: The function to be popped from the stack of device
      functions.

  Raises:
    ValueError: if the device_function to be popped is not top of the stack,
      or if the stack is empty.
  """
  function_stack = self._device_function_stack
  if not function_stack:
    raise ValueError("Tried to pop, but the device function stack is empty")
  # Compare by identity: only the exact function object that was pushed last
  # may be popped.
  if function_stack[-1] is not device_function:
    raise ValueError("Tried to pop device function, but it was not on top "
                     "of the stack")
  function_stack.pop()
@contextlib.contextmanager
def device(self, device_name_or_function):
  """Returns a context manager that specifies the default device to use.

  The `device_name_or_function` argument may either be a device name
  string, a device function, or None:

  * If it is a device name string, all operations constructed in
    this context will be assigned to the device with that name.
  * If it is a function, it will be treated as function from
    Operation objects to device name strings, and invoked each time
    a new Operation is created. The Operation will be assigned to
    the device with the returned name.
  * If it is None, the default device will be cleared.

  For example:

  ```python
  with g.device('/gpu:0'):
    # All operations constructed in this context will be placed
    # on GPU 0.
    with g.device(None):
      # All operations constructed in this context will have no
      # assigned device.

  # Defines a function from `Operation` to device string.
  def matmul_on_gpu(n):
    if n.type == "MatMul":
      return "/gpu:0"
    else:
      return "/cpu:0"

  with g.device(matmul_on_gpu):
    # All operations of type "MatMul" constructed in this context
    # will be placed on GPU 0; all other operations will be placed
    # on CPU 0.
  ```

  Args:
    device_name_or_function: The device name or function to use in
      the context.

  Returns:
    A context manager that specifies the default device to use for newly
    created ops.
  """
  # Setup happens before the `try` in both branches: cleanup must only run
  # for state that was actually installed, otherwise a failure during setup
  # would be masked by a second error raised from the `finally` clause.
  if callable(device_name_or_function):
    self._push_default_device_function(device_name_or_function)
    try:
      yield
    finally:
      self._pop_default_device_function(device_name_or_function)
  else:
    old_dev = self.get_default_device()
    self._set_default_device(_device_string(device_name_or_function))
    try:
      yield
    finally:
      self._set_default_device(old_dev)


class _ControlDependenciesController(object):
  """Context manager for `control_dependencies()`.

  Holds the control inputs of one `with` block together with the set of ops
  recorded through `add_op()` while the controller is registered.
  """

  def __init__(self, graph, control_inputs):
    self._graph = graph
    self._control_inputs = control_inputs
    self._seen_nodes = set()

# pylint: disable=protected-access
  def __enter__(self):
    self._graph._push_control_dependencies_controller(self)

  def __exit__(self, unused_type, unused_value, unused_traceback):
    self._graph._pop_control_dependencies_controller(self)
# pylint: enable=protected-access

  @property
  def control_inputs(self):
    return self._control_inputs

  def add_op(self, op):
    self._seen_nodes.add(op)

  def op_in_group(self, op):
    return op in self._seen_nodes
def _push_control_dependencies_controller(self, controller):
  """Registers `controller` as the innermost control-dependencies context."""
  self._control_dependencies_stack.append(controller)


def _pop_control_dependencies_controller(self, controller):
  """Unregisters `controller`, which must be the innermost context."""
  controller_stack = self._control_dependencies_stack
  assert controller_stack[-1] is controller
  controller_stack.pop()


def _current_control_dependencies(self):
  """Returns the set of control inputs of all registered controllers."""
  active_inputs = set()
  for controller in self._control_dependencies_stack:
    active_inputs.update(controller.control_inputs)
  return active_inputs
def _control_dependencies_for_inputs(self, input_tensors):
  """For an op that takes `input_tensors` as inputs, compute control inputs.

  The returned control dependencies should yield an execution that
  is equivalent to adding all control inputs in
  self._control_dependencies_stack to a newly created op. However,
  this function attempts to prune the returned control dependencies
  by observing that nodes created within the same `with
  control_dependencies(...):` block may have data dependencies that make
  the explicit approach redundant.

  Args:
    input_tensors: The direct data dependencies for an op to be created.

  Returns:
    A list of control inputs for the op to be created.
  """
  input_ops = set(tensor.op for tensor in input_tensors)
  control_inputs = []
  for controller in self._control_dependencies_stack:
    # If any input op was itself created under this controller, the new op
    # is dominated by that input and needs none of this controller's
    # control inputs.
    if any(controller.op_in_group(op) for op in input_ops):
      continue
    # Don't add a control input if we already have a data dependency on it.
    # NOTE(mrry): We do not currently track transitive data dependencies,
    #   so we may add redundant control inputs.
    control_inputs.extend(
        c for c in controller.control_inputs if c not in input_ops)
  return control_inputs


def _record_op_seen_by_control_dependencies(self, op):
  """Record that the given op depends on all registered control dependencies.

  Args:
    op: An Operation.
  """
  for registered_controller in self._control_dependencies_stack:
    registered_controller.add_op(op)
def control_dependencies(self, control_inputs):
  """Returns a context manager that specifies control dependencies.

  Use with the `with` keyword to specify that all operations constructed
  within the context should have control dependencies on
  `control_inputs`. For example:

  ```python
  with g.control_dependencies([a, b, c]):
    # `d` and `e` will only run after `a`, `b`, and `c` have executed.
    d = ...
    e = ...
  ```

  Multiple calls to `control_dependencies()` can be nested, and in
  that case a new `Operation` will have control dependencies on the union
  of `control_inputs` from all active contexts.

  ```python
  with g.control_dependencies([a, b]):
    # Ops declared here run after `a` and `b`.
    with g.control_dependencies([c, d]):
      # Ops declared here run after `a`, `b`, `c`, and `d`.
  ```

  *N.B.* The control dependencies context applies *only* to ops that
  are constructed within the context. Merely using an op or tensor
  in the context does not add a control dependency. The following
  example illustrates this point:

  ```python
  # WRONG
  def my_func(pred, tensor):
    t = tf.matmul(tensor, tensor)
    with tf.control_dependencies([pred]):
      # The matmul op is created outside the context, so no control
      # dependency will be added.
      return t

  # RIGHT
  def my_func(pred, tensor):
    with tf.control_dependencies([pred]):
      # The matmul op is created in the context, so a control dependency
      # will be added.
      return tf.matmul(tensor, tensor)
  ```

  Args:
    control_inputs: A list of `Operation` or `Tensor` objects, which
      must be executed or computed before running the operations
      defined in the context.

  Returns:
    A context manager that specifies control dependencies for all
    operations constructed within the context.

  Raises:
    TypeError: If `control_inputs` is not a list of `Operation` or
      `Tensor` objects.
  """
  # Convert the inputs to ops and drop any that are already active in an
  # enclosing context or repeated within `control_inputs`.
  # NOTE(mrry): Other than deduplication, we do not currently track direct
  #   or indirect dependencies between control_inputs, which may result in
  #   redundant control inputs.
  already_active = self._current_control_dependencies()
  new_control_ops = []
  for candidate in control_inputs:
    if isinstance(candidate, Tensor):
      # A tensor stands for the op that produces it.
      candidate = candidate.op
    elif not isinstance(candidate, Operation):
      raise TypeError("Control input must be Operation or Tensor: %s"
                      % candidate)
    if candidate in already_active:
      continue
    new_control_ops.append(candidate)
    already_active.add(candidate)
  return self._ControlDependenciesController(self, new_control_ops)
# pylint: disable=g-doc-return-or-yield
@contextlib.contextmanager
def _kernel_label_map(self, op_to_kernel_label_map):
  """EXPERIMENTAL: A context manager for setting kernel labels.

  This context manager can be used to select particular
  implementations of kernels within the scope of the context.

  For example:

      with ops.Graph().as_default() as g:
        f_1 = Foo()  # Uses the default registered kernel for the Foo op.
        with g._kernel_label_map({"Foo": "v_2"}):
          f_2 = Foo()  # Uses the registered kernel with label "v_2"
                       # for the Foo op.
          with g._kernel_label_map({"Foo": "v_3"}):
            f_3 = Foo()  # Uses the registered kernel with label "v_3"
                         # for the Foo op.
            with g._kernel_label_map({"Foo": ""}):
              f_4 = Foo()  # Uses the default registered kernel
                           # for the Foo op.

  Args:
    op_to_kernel_label_map: A dictionary mapping op type strings to
      kernel label strings.

  Returns:
    A context manager that sets the kernel label to be used for one or more
    ops created in that context.

  Raises:
    TypeError: If op_to_kernel_label_map is not a dictionary mapping
      strings to strings. The graph's labels are left untouched in that case.
  """
  if not isinstance(op_to_kernel_label_map, dict):
    raise TypeError("op_to_kernel_label_map must be a dictionary mapping "
                    "strings to strings")
  try:
    string_types = basestring  # Python 2: matches both str and unicode.
  except NameError:
    string_types = str  # Python 3 has no basestring.
  # Snapshot the request so that installing and restoring operate on the
  # same entries even if the caller mutates the dictionary inside the block.
  requested = list(op_to_kernel_label_map.items())
  # Validate every entry BEFORE installing anything: raising halfway through
  # installation would leave the earlier labels set with nothing to undo
  # them, because the restoring `finally` below has not been entered yet.
  for op_type, label in requested:
    if not (isinstance(op_type, string_types)
            and isinstance(label, string_types)):
      raise TypeError("op_to_kernel_label_map must be a dictionary mapping "
                      "strings to strings")
  # The saved_labels dictionary stores any currently-set labels that
  # will be overridden by this context manager.
  saved_labels = {}
  for op_type, label in requested:
    if op_type in self._op_to_kernel_label_map:
      saved_labels[op_type] = self._op_to_kernel_label_map[op_type]
    self._op_to_kernel_label_map[op_type] = label
  try:
    yield  # The code within the context runs here.
  finally:
    # Remove the labels set for this context, and restore any saved labels.
    for op_type, _ in requested:
      if op_type in saved_labels:
        self._op_to_kernel_label_map[op_type] = saved_labels[op_type]
      else:
        del self._op_to_kernel_label_map[op_type]
# pylint: enable=g-doc-return-or-yield
# pylint: disable=g-doc-return-or-yield
@contextlib.contextmanager
def gradient_override_map(self, op_type_map):
  """EXPERIMENTAL: A context manager for overriding gradient functions.

  This context manager can be used to override the gradient function
  that will be used for ops within the scope of the context.

  For example:

  ```python
  @tf.RegisterGradient("CustomSquare")
  def _custom_square_grad(op, inputs):
    # ...

  with tf.Graph().as_default() as g:
    c = tf.constant(5.0)
    s_1 = tf.square(c)  # Uses the default gradient for tf.square.
    with g.gradient_override_map({"Square": "CustomSquare"}):
      s_2 = tf.square(c)  # Uses _custom_square_grad to compute the
                          # gradient of s_2.
  ```

  Args:
    op_type_map: A dictionary mapping op type strings to alternative op
      type strings.

  Returns:
    A context manager that sets the alternative op type to be used for one
    or more ops created in that context.

  Raises:
    TypeError: If `op_type_map` is not a dictionary mapping strings to
      strings. The graph's overrides are left untouched in that case.
  """
  if not isinstance(op_type_map, dict):
    raise TypeError("op_type_map must be a dictionary mapping "
                    "strings to strings")
  try:
    string_types = basestring  # Python 2: matches both str and unicode.
  except NameError:
    string_types = str  # Python 3 has no basestring.
  # Snapshot the request so that installing and restoring operate on the
  # same entries even if the caller mutates the dictionary inside the block.
  requested = list(op_type_map.items())
  # Validate every entry BEFORE installing anything: raising halfway through
  # installation would leave the earlier overrides set with nothing to undo
  # them, because the restoring `finally` below has not been entered yet.
  for op_type, mapped_op_type in requested:
    if not (isinstance(op_type, string_types)
            and isinstance(mapped_op_type, string_types)):
      raise TypeError("op_type_map must be a dictionary mapping "
                      "strings to strings")
  # The saved_mappings dictionary stores any currently-set mappings that
  # will be overridden by this context manager.
  saved_mappings = {}
  for op_type, mapped_op_type in requested:
    if op_type in self._gradient_override_map:
      saved_mappings[op_type] = self._gradient_override_map[op_type]
    self._gradient_override_map[op_type] = mapped_op_type
  try:
    yield  # The code within the context runs here.
  finally:
    # Remove the mappings set for this context, and restore any saved ones.
    for op_type, _ in requested:
      if op_type in saved_mappings:
        self._gradient_override_map[op_type] = saved_mappings[op_type]
      else:
        del self._gradient_override_map[op_type]
# pylint: enable=g-doc-return-or-yield
def device(dev):
  """Wrapper for `Graph.device()` using the default graph.

  See [`Graph.device()`](framework.md#Graph.device) for more details.

  Args:
    dev: The device name or function to use in the context.

  Returns:
    A context manager that specifies the default device to use for newly
    created ops.
  """
  return get_default_graph().device(dev)


def name_scope(name):
  """Wrapper for `Graph.name_scope()` using the default graph.

  See [`Graph.name_scope()`](framework.md#Graph.name_scope) for more details.

  Args:
    name: A name for the scope.

  Returns:
    A context manager that installs `name` as a new name scope in the
    default graph.
  """
  return get_default_graph().name_scope(name)


def control_dependencies(control_inputs):
  """Wrapper for `Graph.control_dependencies()` using the default graph.

  See [`Graph.control_dependencies()`](framework.md#Graph.control_dependencies)
  for more details.

  Args:
    control_inputs: A list of `Operation` or `Tensor` objects, which
      must be executed or computed before running the operations
      defined in the context.

  Returns:
    A context manager that specifies control dependencies for all
    operations constructed within the context.
  """
  return get_default_graph().control_dependencies(control_inputs)


class _DefaultStack(threading.local):
  """A thread-local stack of objects for providing implicit defaults."""

  def __init__(self):
    super(_DefaultStack, self).__init__()
    self.stack = []

  def get_default(self):
    """Returns the innermost default, or None if the stack is empty."""
    return self.stack[-1] if self.stack else None

  def reset(self):
    """Discards every default registered on the current thread's stack."""
    self.stack = []

  @contextlib.contextmanager
  def get_controller(self, default):
    """A context manager for manipulating a default stack.

    Args:
      default: The object to install as the innermost default for the
        lifetime of the context.

    Yields:
      `default`.
    """
    # Push before entering the `try`, so the cleanup below only ever runs
    # for an entry that was actually pushed.
    self.stack.append(default)
    try:
      yield default
    finally:
      # Contexts must be exited in strict LIFO order.
      assert self.stack[-1] is default
      self.stack.pop()


_default_session_stack = _DefaultStack()
def default_session(session):
  """Python "with" handler for defining a default session.

  This function provides a means of registering a session for handling
  Tensor.eval() and Operation.run() calls. It is primarily intended for use
  by session.Session, but can be used with any object that implements
  the Session.run() interface.

  Use with the "with" keyword to specify that Tensor.eval() and Operation.run()
  invocations within the scope of a block should be executed by a particular
  session.

  The default session applies to the current thread only, so it is always
  possible to inspect the call stack and determine the scope of a default
  session. If you create a new thread, and wish to use the default session
  in that thread, you must explicitly add a "with ops.default_session(sess):"
  block in that thread's function.

  Example:
    The following code examples are equivalent:

    # 1. Using the Session object directly:
    sess = ...
    c = tf.constant(5.0)
    sess.run(c)

    # 2. Using default_session():
    sess = ...
    with ops.default_session(sess):
      c = tf.constant(5.0)
      result = c.eval()

    # 3. Overriding default_session():
    sess = ...
    with ops.default_session(sess):
      c = tf.constant(5.0)
      with ops.default_session(...):
        c.eval(session=sess)

  Args:
    session: The session to be installed as the default session.

  Returns:
    A context manager for the default session.
  """
  # Only a weak reference goes on the stack, so the stack entry itself does
  # not keep the session alive.
  session_ref = weakref.ref(session)
  return _default_session_stack.get_controller(session_ref)
def get_default_session():
  """Returns the default session for the current thread.

  The returned `Session` will be the innermost session on which a
  `Session` or `Session.as_default()` context has been entered.

  *N.B.* The default session is a property of the current thread. If you
  create a new thread, and wish to use the default session in that
  thread, you must explicitly add a `with sess.as_default():` in that
  thread's function.

  Returns:
    The default `Session` being used in the current thread, or `None` if no
    default session has been registered.

  Raises:
    RuntimeError: If the registered default session is no longer alive.
  """
  ref = _default_session_stack.get_default()
  if ref is None:
    # No default session has been registered.
    return None
  # The stack holds references that must be called to obtain the session.
  session = ref()
  if session is None:
    # This should never happen with the current session implementations.
    raise RuntimeError("Default session has been garbage collected.")
  return session


def _eval_using_default_session(tensors, feed_dict, graph, session=None):
  """Uses the default session to evaluate one or more tensors.

  Args:
    tensors: A single Tensor, or a list of Tensor objects.
    feed_dict: A dictionary that maps Tensor objects (or tensor names) to lists,
      numpy ndarrays, TensorProtos, or strings.
    graph: The graph in which the tensors are defined.
    session: (Optional) A different session to use to evaluate "tensors".

  Returns:
    Either a single numpy ndarray if "tensors" is a single tensor; or a list
    of numpy ndarrays that each correspond to the respective element in
    "tensors".

  Raises:
    ValueError: If no default session is available; the default session
      does not have "graph" as its graph; or if "session" is specified,
      and it does not have "graph" as its graph.
  """
  if session is None:
    session = get_default_session()
    if session is None:
      # The hint names `default_session`, the context manager this module
      # actually provides (the same wording _run_using_default_session uses).
      raise ValueError("Cannot evaluate tensor using eval(): No default "
                       "session is registered. Use 'with "
                       "default_session(sess)' or pass an explicit session to "
                       "eval(session=sess)")
    if session.graph is not graph:
      raise ValueError("Cannot use the default session to evaluate tensor: "
                       "the tensor's graph is different from the session's "
                       "graph. Pass an explicit session to "
                       "eval(session=sess).")
  else:
    if session.graph is not graph:
      raise ValueError("Cannot use the given session to evaluate tensor: "
                       "the tensor's graph is different from the session's "
                       "graph.")
  return session.run(tensors, feed_dict)
def _run_using_default_session(operation, feed_dict, graph, session=None):
  """Uses the default session to run "operation".

  Args:
    operation: The Operation to be run.
    feed_dict: A dictionary that maps Tensor objects (or tensor names) to lists,
      numpy ndarrays, TensorProtos, or strings.
    graph: The graph in which "operation" is defined.
    session: (Optional) A different session to use to run "operation".

  Raises:
    ValueError: If no default session is available; the default session
      does not have "graph" as its graph; or if "session" is specified,
      and it does not have "graph" as its graph.
  """
  using_default = session is None
  if using_default:
    session = get_default_session()
    if session is None:
      raise ValueError("Cannot execute operation using Run(): No default "
                       "session is registered. Use 'with "
                       "default_session(sess)' or pass an explicit session to "
                       "Run(session=sess)")
  if session.graph is not graph:
    # Same check for both sources of the session; only the wording differs.
    if using_default:
      raise ValueError("Cannot use the default session to execute operation: "
                       "the operation's graph is different from the "
                       "session's graph. Pass an explicit session to "
                       "Run(session=sess).")
    raise ValueError("Cannot use the given session to execute operation: "
                     "the operation's graph is different from the session's "
                     "graph.")
  session.run(operation, feed_dict)
class _DefaultGraphStack(_DefaultStack):
  """A thread-local stack of objects for providing an implicit default graph."""

  def __init__(self):
    super(_DefaultGraphStack, self).__init__()
    self._global_default_graph = None

  def get_default(self):
    """Override that returns a global default if the stack is empty."""
    innermost = super(_DefaultGraphStack, self).get_default()
    if innermost is not None:
      return innermost
    return self._GetGlobalDefaultGraph()

  def _GetGlobalDefaultGraph(self):
    """Returns the fallback graph, creating it on first use."""
    if self._global_default_graph is None:
      # TODO(mrry): Perhaps log that the default graph is being used, or
      #   provide some other feedback to prevent confusion when a mixture of
      #   the global default graph and an explicit graph are combined in the
      #   same process.
      self._global_default_graph = Graph()
    return self._global_default_graph

  def reset(self):
    """Clears the stack and forgets the fallback graph."""
    super(_DefaultGraphStack, self).reset()
    self._global_default_graph = None

_default_graph_stack = _DefaultGraphStack()


def reset_default_graph():
  """Clears the default graph stack and resets the global default graph.

  *N.B.* The default graph is a property of the current thread. This
  function applies only to the current thread.
  """
  _default_graph_stack.reset()


def get_default_graph():
  """Returns the default graph for the current thread.

  The returned graph will be the innermost graph on which a
  `Graph.as_default()` context has been entered, or a global default
  graph if none has been explicitly created.

  *N.B.* The default graph is a property of the current thread. If you
  create a new thread, and wish to use the default graph in that
  thread, you must explicitly add a `with g.as_default():` in that
  thread's function.

  Returns:
    The default `Graph` being used in the current thread.
  """
  return _default_graph_stack.get_default()
def _get_graph_from_inputs(op_input_list, graph=None):
  """Returns the appropriate graph to use for the given inputs.

  This library method provides a consistent algorithm for choosing the graph
  in which an Operation should be constructed:

  1. If the "graph" is specified explicitly, we validate that all of the inputs
     in "op_input_list" are compatible with that graph.
  2. Otherwise, we attempt to select a graph from the first Operation-
     or Tensor-valued input in "op_input_list", and validate that all other
     such inputs are in the same graph.
  3. If the graph was not specified and it could not be inferred from
     "op_input_list", we attempt to use the default graph.

  Args:
    op_input_list: A list of inputs to an operation, which may include Tensor
      and Operation objects.
    graph: (Optional) The explicit graph to use.

  Raises:
    TypeError: If op_input_list is not a list or tuple, or if graph is not a
      Graph.
    ValueError: If a graph is explicitly passed and not all inputs are from it,
      or if the inputs are from multiple graphs, or we could not find a graph
      and there was no default graph.

  Returns:
    The appropriate graph to use for the given inputs.
  """
  if not isinstance(op_input_list, (list, tuple)):
    raise TypeError("The op_input_list must be a list or tuple")

  # 1. An explicit graph wins, provided every Operation/Tensor input
  #    belongs to it.
  if graph is not None:
    if not isinstance(graph, Graph):
      raise TypeError("Input graph needs to be a Graph: %s" % graph)
    for candidate in op_input_list:
      if isinstance(candidate, Operation):
        if candidate.graph is not graph:
          raise ValueError("Operation %s is not from the passed-in graph"
                           % candidate)
      elif isinstance(candidate, Tensor):
        if candidate.graph is not graph:
          raise ValueError("Tensor %s is not from the passed-in graph"
                           % candidate)
    return graph

  # 2. Otherwise take the graph of the first Operation- or Tensor-valued
  #    input, checking every later such input against that first one.
  graph_elements = [candidate for candidate in op_input_list
                    if isinstance(candidate, (Operation, Tensor))]
  if graph_elements:
    original_input = graph_elements[0]
    for other_input in graph_elements[1:]:
      assert_same_graph([original_input, other_input])
    return original_input.graph

  # 3. If all else fails, we use the default graph, which is always there.
  return get_default_graph()
class GraphKeys(object):
  """Standard names to use for graph collections.

  The standard library uses various well-known names to collect and
  retrieve values associated with a graph. For example, the
  `tf.Optimizer` subclasses default to optimizing the variables
  collected under `tf.GraphKeys.TRAINABLE_VARIABLES` if none is
  specified, but it is also possible to pass an explicit list of
  variables.

  The following standard keys are defined:

  * `VARIABLES`: the `Variable` objects that comprise a model, and
    must be saved and restored together. See
    [`tf.all_variables()`](state_ops.md#all_variables) for more details.
  * `TRAINABLE_VARIABLES`: the subset of `Variable` objects that will
    be trained by an optimizer. See
    [`tf.trainable_variables()`](state_ops.md#trainable_variables)
    for more details.
  * `SUMMARIES`: the summary `Tensor` objects that have been created
    in the graph. See [`tf.merge_all_summaries()`](train.md#merge_all_summaries)
    for more details.
  * `QUEUE_RUNNERS`: the `QueueRunner` objects that are used to
    produce input for a computation. See
    [`tf.start_queue_runners()`](train.md#start_queue_runners) for more details.
  * `TABLE_INITIALIZERS`: the table initializers that have been collected.
  """

  # Collection of variables.Variable objects that must be saved and restored
  # by the model.
  VARIABLES = "variables"
  # Collection of variables.Variable objects that will be trained by the
  # optimizers.
  TRAINABLE_VARIABLES = "trainable_variables"
  # Collection of summaries.
  SUMMARIES = "summaries"
  # Collection of QueueRunners.
  QUEUE_RUNNERS = "queue_runners"
  # Collection of table initializers.
  TABLE_INITIALIZERS = "table_initializer"


def add_to_collection(name, value):
  """Wrapper for `Graph.add_to_collection()` using the default graph.

  See [`Graph.add_to_collection()`](framework.md#Graph.add_to_collection)
  for more details.

  Args:
    name: The key for the collection. For example, the `GraphKeys` class
      contains many standard names for collections.
    value: The value to add to the collection.
  """
  default_graph = get_default_graph()
  default_graph.add_to_collection(name, value)
def get_collection(key, scope=None):
  """Wrapper for `Graph.get_collection()` using the default graph.

  See [`Graph.get_collection()`](framework.md#Graph.get_collection)
  for more details.

  Args:
    key: The key for the collection. For example, the `GraphKeys` class
      contains many standard names for collections.
    scope: (Optional.) If supplied, the resulting list is filtered to include
      only items whose name begins with this string.

  Returns:
    The list of values in the collection with the given `key`, or
    an empty list if no value has been added to that collection. The
    list contains the values in the order under which they were
    collected.
  """
  default_graph = get_default_graph()
  return default_graph.get_collection(key, scope)


# pylint: disable=g-doc-return-or-yield
@contextlib.contextmanager
def op_scope(values, name, default_name):
  """Returns a context manager for use when defining a Python op.

  This context manager validates that the given `values` are from the
  same graph, ensures that that graph is the default graph, and pushes a
  name scope.

  For example, to define a new Python op called `my_op`:

  ```python
  def my_op(a, b, c, name=None):
    with tf.op_scope([a, b, c], name, "MyOp") as scope:
      a = tf.convert_to_tensor(a, name="a")
      b = tf.convert_to_tensor(b, name="b")
      c = tf.convert_to_tensor(c, name="c")
      # Define some computation that uses `a`, `b`, and `c`.
      return foo_op(..., name=scope)
  ```

  Args:
    values: The list of `Tensor` arguments that are passed to the op function.
    name: The name argument that is passed to the op function.
    default_name: The default name to use if the `name` argument is `None`.

  Returns:
    A context manager for use in defining a Python op.
  """
  graph = _get_graph_from_inputs(values)
  # Only `None` falls back to the default; any other value is passed through.
  scope_name = name if name is not None else default_name
  with graph.as_default():
    with graph.name_scope(scope_name) as scope:
      yield scope
# pylint: enable=g-doc-return-or-yield