aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/client/session.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/client/session.py')
-rw-r--r--tensorflow/python/client/session.py567
1 files changed, 567 insertions, 0 deletions
diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py
new file mode 100644
index 0000000000..7da9b41cf4
--- /dev/null
+++ b/tensorflow/python/client/session.py
@@ -0,0 +1,567 @@
+"""A client interface for TensorFlow."""
+
+import re
+import sys
+import threading
+
+import tensorflow.python.platform
+
+import numpy as np
+
+from tensorflow.python import pywrap_tensorflow as tf_session
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.platform import logging
+
+
+class SessionInterface(object):
+ """Base class for implementations of TensorFlow client sessions."""
+
+ @property
+ def graph(self):
+ """The underlying TensorFlow graph, to be used in building Operations."""
+ raise NotImplementedError('graph')
+
+ @property
+ def sess_str(self):
+ """The TensorFlow process to which this session will connect."""
+ raise NotImplementedError('sess_str')
+
+ def run(self, fetches, feed_dict=None):
+ """Runs operations in the session. See `Session.run()` for details."""
+ raise NotImplementedError('Run')
+
+
+class BaseSession(SessionInterface):
+ """A class for interacting with a TensorFlow computation.
+
+ The BaseSession enables incremental graph building with inline
+ execution of Operations and evaluation of Tensors.
+ """
+
+ def __init__(self, target='', graph=None, config=None):
+ """Constructs a new TensorFlow session.
+
+ Args:
+ target: (Optional) The TensorFlow execution engine to connect to.
+ graph: (Optional) The graph to be used. If this argument is None,
+ the default graph will be used.
+ config: (Optional) ConfigProto proto used to configure the session.
+
+ Raises:
+ RuntimeError: If an error occurs while creating the TensorFlow
+ session.
+ """
+ if graph is None:
+ self._graph = ops.get_default_graph()
+ else:
+ self._graph = graph
+
+ self._opened = False
+ self._closed = False
+
+ self._current_version = 0
+ self._extend_lock = threading.Lock()
+ self._target = target
+
+ self._session = None
+
+ try:
+ opts = tf_session.TF_NewSessionOptions(target=target, config=config)
+ status = tf_session.TF_NewStatus()
+ self._session = tf_session.TF_NewSession(opts, status)
+ if tf_session.TF_GetCode(status) != 0:
+ message = tf_session.TF_Message(status)
+ raise RuntimeError(message)
+
+ finally:
+ tf_session.TF_DeleteSessionOptions(opts)
+ tf_session.TF_DeleteStatus(status)
+
+ def close(self):
+ """Closes this session.
+
+ Calling this method frees all resources associated with the session.
+
+ Raises:
+ RuntimeError: If an error occurs while closing the session.
+ """
+ with self._extend_lock:
+ if self._opened and not self._closed:
+ self._closed = True
+ try:
+ status = tf_session.TF_NewStatus()
+ tf_session.TF_CloseSession(self._session, status)
+ if tf_session.TF_GetCode(status) != 0:
+ raise RuntimeError(tf_session.TF_Message(status))
+ finally:
+ tf_session.TF_DeleteStatus(status)
+
+ def __del__(self):
+ self.close()
+ try:
+ status = tf_session.TF_NewStatus()
+ if self._session is not None:
+ tf_session.TF_DeleteSession(self._session, status)
+ if tf_session.TF_GetCode(status) != 0:
+ raise RuntimeError(tf_session.TF_Message(status))
+ self._session = None
+ finally:
+ tf_session.TF_DeleteStatus(status)
+
+ @property
+ def graph(self):
+ """The graph that was launched in this session."""
+ return self._graph
+
+ @property
+ def graph_def(self):
+ """A serializable version of the underlying TensorFlow graph.
+
+ Returns:
+ A graph_pb2.GraphDef proto containing nodes for all of the Operations in
+ the underlying TensorFlow graph.
+ """
+ return self._graph.as_graph_def()
+
+ @property
+ def sess_str(self):
+ return self._target
+
+ def as_default(self):
+ """Returns a context manager that makes this object the default session.
+
+ Use with the `with` keyword to specify that calls to
+ [`Operation.run()`](framework.md#Operation.run) or
+ [`Tensor.run()`](framework.md#Tensor.run) should be executed in
+ this session.
+
+ ```python
+ c = tf.constant(..)
+ sess = tf.Session()
+
+ with sess.as_default():
+ assert tf.get_default_session() is sess
+ print c.eval()
+ ```
+
+ To get the current default session, use
+ [`tf.get_default_session()`](#get_default_session).
+
+
+ *N.B.* The `as_default` context manager *does not* close the
+ session when you exit the context, and you must close the session
+ explicitly.
+
+ ```python
+ c = tf.constant(...)
+ sess = tf.Session()
+ with sess.as_default():
+ print c.eval()
+ # ...
+ with sess.as_default():
+ print c.eval()
+
+ sess.close()
+ ```
+
+ Alternatively, you can use `with tf.Session():` to create a
+ session that is automatically closed on exiting the context,
+ including when an uncaught exception is raised.
+
+ *N.B.* The default graph is a property of the current thread. If you
+ create a new thread, and wish to use the default session in that
+ thread, you must explicitly add a `with sess.as_default():` in that
+ thread's function.
+
+ Returns:
+ A context manager using this session as the default session.
+
+ """
+ return ops.default_session(self)
+
+ # Eventually, this registration could be opened up to support custom
+ # Tensor expansions. Expects tuples of (Type, fetch_fn, feed_fn),
+ # where the signatures are:
+ # fetch_fn : Type -> (list of Tensors,
+ # lambda: list of fetched np.ndarray -> TypeVal)
+ # feed_fn : Type, TypeVal -> list of (Tensor, value)
+ # Conceptually, fetch_fn describes how to expand fetch into its
+ # component Tensors and how to contracting the fetched results back into
+ # a single return value. feed_fn describes how to unpack a single fed
+ # value and map it to feeds of a Tensor and its corresponding value.
+ # pylint: disable=g-long-lambda
+ _REGISTERED_EXPANSIONS = [
+ # SparseTensors are fetched as SparseTensorValues. They can be fed
+ # SparseTensorValues or normal tuples.
+ (ops.SparseTensor,
+ lambda fetch: (
+ [fetch.indices, fetch.values, fetch.shape],
+ lambda fetched_vals: ops.SparseTensorValue(*fetched_vals)),
+ lambda feed, feed_val: list(zip(
+ [feed.indices, feed.values, feed.shape], feed_val))),
+ # The default catches all types and performs no expansions.
+ (object,
+ lambda fetch: ([fetch], lambda fetched_vals: fetched_vals[0]),
+ lambda feed, feed_val: [(feed, feed_val)])]
+ # pylint: enable=g-long-lambda
+
+ def run(self, fetches, feed_dict=None):
+ """Runs the operations and evaluates the tensors in `fetches`.
+
+ This method runs one "step" of TensorFlow computation, by
+ running the necessary graph fragment to execute every `Operation`
+ and evaluate every `Tensor` in `fetches`, substituting the values in
+ `feed_dict` for the corresponding input values.
+
+ The `fetches` argument may be a list of graph elements or a single
+ graph element, and these determine the return value of this
+ method. A graph element can be one of the following types:
+
+ * If the *i*th element of `fetches` is an
+ [`Operation`](framework.md#Operation), the *i*th return value
+ will be `None`.
+ * If the *i*th element of `fetches` is a
+ [`Tensor`](framework.md#Tensor), the *i*th return value will
+ be a numpy ndarray containing the value of that tensor.
+ * If the *i*th element of `fetches` is a
+ [`SparseTensor`](sparse_ops.md#SparseTensor), the *i*th
+ return value will be a
+ [`SparseTensorValue`](sparse_ops.md#SparseTensorValue)
+ containing the value of that sparse tensor.
+
+ The optional `feed_dict` argument allows the caller to override
+ the value of tensors in the graph. Each key in `feed_dict` can be
+ one of the following types:
+
+ * If the key is a [`Tensor`](framework.md#Tensor), the
+ value may be a Python scalar, string, list, or numpy ndarray
+ that can be converted to the same `dtype` as that
+ tensor. Additionally, if the key is a
+ [placeholder](io_ops.md#placeholder), the shape of the value
+ will be checked for compatibility with the placeholder.
+ * If the key is a [`SparseTensor`](sparse_ops.md#SparseTensor),
+ the value should be a
+ [`SparseTensorValue`](sparse_ops.md#SparseTensorValue).
+
+ Args:
+ fetches: A single graph element, or a list of graph elements
+ (described above).
+ feed_dict: A dictionary that maps graph elements to values
+ (described above).
+
+ Returns:
+ Either a single value if `fetches` is a single graph element, or
+ a list of values if `fetches` is a list (described above).
+
+ Raises:
+ RuntimeError: If this `Session` is in an invalid state (e.g. has been
+ closed).
+ TypeError: If `fetches` or `feed_dict` keys are of an inappropriate type.
+ ValueError: If `fetches` or `feed_dict` keys are invalid or refer to a
+ `Tensor` that doesn't exist.
+
+ """
+ def _fetch_fn(fetch):
+ for tensor_type, fetch_fn, _ in BaseSession._REGISTERED_EXPANSIONS:
+ if isinstance(fetch, tensor_type):
+ return fetch_fn(fetch)
+ raise TypeError('Fetch argument %r has invalid type %r'
+ % (fetch, type(fetch)))
+
+ def _feed_fn(feed, feed_val):
+ for tensor_type, _, feed_fn in BaseSession._REGISTERED_EXPANSIONS:
+ if isinstance(feed, tensor_type):
+ return feed_fn(feed, feed_val)
+ raise TypeError('Feed argument %r has invalid type %r'
+ % (feed, type(feed)))
+
+ # Check session.
+ if self._closed:
+ raise RuntimeError('Attempted to use a closed Session.')
+
+ # Validate and process fetches.
+ is_list_fetch = isinstance(fetches, (list, tuple))
+ if not is_list_fetch:
+ fetches = [fetches]
+
+ unique_fetch_targets = set()
+ target_list = []
+
+ fetch_info = []
+ for fetch in fetches:
+ subfetches, fetch_contraction_fn = _fetch_fn(fetch)
+ subfetch_names = []
+ for subfetch in subfetches:
+ try:
+ fetch_t = self.graph.as_graph_element(subfetch, allow_tensor=True,
+ allow_operation=True)
+ if isinstance(fetch_t, ops.Operation):
+ target_list.append(fetch_t.name)
+ else:
+ subfetch_names.append(fetch_t.name)
+ except TypeError as e:
+ raise TypeError('Fetch argument %r of %r has invalid type %r, '
+ 'must be a string or Tensor. (%s)'
+ % (subfetch, fetch, type(subfetch), e.message))
+ except ValueError as e:
+ raise ValueError('Fetch argument %r of %r cannot be interpreted as a '
+ 'Tensor. (%s)' % (subfetch, fetch, e.message))
+ except KeyError as e:
+ raise ValueError('Fetch argument %r of %r cannot be interpreted as a '
+ 'Tensor. (%s)' % (subfetch, fetch, e.message))
+ unique_fetch_targets.update(subfetch_names)
+ fetch_info.append((subfetch_names, fetch_contraction_fn))
+
+ unique_fetch_targets = list(unique_fetch_targets)
+
+ # Create request.
+ feed_dict_string = {}
+
+ # Validate and process feed_dict.
+ if feed_dict:
+ for feed, feed_val in feed_dict.iteritems():
+ for subfeed, subfeed_val in _feed_fn(feed, feed_val):
+ try:
+ subfeed_t = self.graph.as_graph_element(subfeed, allow_tensor=True,
+ allow_operation=False)
+ except Exception as e:
+ e.message = ('Cannot interpret feed_dict key as Tensor: '
+ + e.message)
+ e.args = (e.message,)
+ raise e
+ np_val = np.array(subfeed_val, dtype=subfeed_t.dtype.as_numpy_dtype)
+ if subfeed_t.op.type == 'Placeholder':
+ if not subfeed_t.get_shape().is_compatible_with(np_val.shape):
+ raise ValueError(
+ 'Cannot feed value of shape %r for Tensor %r, '
+ 'which has shape %r'
+ % (np_val.shape, subfeed_t.name,
+ tuple(subfeed_t.get_shape().dims)))
+ feed_dict_string[str(subfeed_t.name)] = np_val
+
+ # Run request and get response.
+ results = self._do_run(target_list, unique_fetch_targets, feed_dict_string)
+
+ # User may have fetched the same tensor multiple times, but we
+ # only fetch them from the runtime once. Furthermore, they may
+ # be wrapped as a tuple of tensors. Here we map the results back
+ # to what the client asked for.
+ fetched_results = dict(zip(unique_fetch_targets, results))
+ ret = []
+ for fetch_names, fetch_contraction_fn in fetch_info:
+ if fetch_names:
+ fetched_vals = [fetched_results[name] for name in fetch_names]
+ ret.append(fetch_contraction_fn(fetched_vals))
+ else:
+ ret.append(None)
+
+ if is_list_fetch:
+ return ret
+ else:
+ return ret[0]
+
+ # Captures the name of a node in an error status.
+ _NODEDEF_NAME_RE = re.compile(r'\[\[Node: ([^ ]*?) =')
+
+ def _do_run(self, target_list, fetch_list, feed_dict):
+ """Runs a step based on the given fetches and feeds.
+
+ Args:
+ target_list: A list of strings corresponding to names of tensors
+ or operations to be run to, but not fetched.
+ fetch_list: A list of strings corresponding to names of tensors to be
+ fetched and operations to be run.
+ feed_dict: A dictionary that maps tensor names to numpy ndarrays.
+
+ Returns:
+ A list of numpy ndarrays, corresponding to the elements of
+ `fetch_list`. If the ith element of `fetch_list` contains the
+ name of an operation, the first Tensor output of that operation
+ will be returned for that element.
+ """
+ try:
+ # Ensure any changes to the graph are reflected in the runtime.
+ with self._extend_lock:
+ if self._graph.version > self._current_version:
+ graph_def = self._graph.as_graph_def(
+ from_version=self._current_version)
+
+ try:
+ status = tf_session.TF_NewStatus()
+ tf_session.TF_ExtendGraph(
+ self._session, graph_def.SerializeToString(), status)
+ if tf_session.TF_GetCode(status) != 0:
+ raise RuntimeError(tf_session.TF_Message(status))
+ self._opened = True
+ finally:
+ tf_session.TF_DeleteStatus(status)
+
+ self._current_version = self._graph.version
+
+ return tf_session.TF_Run(self._session, feed_dict, fetch_list,
+ target_list)
+
+ except tf_session.StatusNotOK as e:
+ e_type, e_value, e_traceback = sys.exc_info()
+ m = BaseSession._NODEDEF_NAME_RE.search(e.error_message)
+ if m is not None:
+ node_name = m.group(1)
+ node_def = None
+ try:
+ op = self._graph.get_operation_by_name(node_name)
+ node_def = op.node_def
+ except KeyError:
+ op = None
+ # pylint: disable=protected-access
+ raise errors._make_specific_exception(node_def, op, e.error_message,
+ e.code)
+ # pylint: enable=protected-access
+ raise e_type, e_value, e_traceback
+
+
+class Session(BaseSession):
+ """A class for running TensorFlow operations.
+
+ A `Session` object encapsulates the environment in which `Operation`
+ objects are executed, and `Tensor` objects are evaluated. For
+ example:
+
+ ```python
+ # Build a graph.
+ a = tf.constant(5.0)
+ b = tf.constant(6.0)
+ c = a * b
+
+ # Launch the graph in a session.
+ sess = tf.Session()
+
+ # Evaluate the tensor `c`.
+ print sess.run(c)
+ ```
+
+ A session may own resources, such as
+ [variables](state_ops.md#Variable), [queues](io_ops.md#QueueBase),
+ and [readers](io_ops.md#ReaderBase). It is important to release
+ these resources when they are no longer required. To do this, either
+ invoke the [`close()`](#Session.close) method on the session, or use
+ the session as a context manager. The following two examples are
+ equivalent:
+
+ ```python
+ # Using the `close()` method.
+ sess = tf.Session()
+ sess.run(...)
+ sess.close()
+
+ # Using the context manager.
+ with tf.Session() as sess:
+ sess.run(...)
+ ```
+
+ The [`ConfigProto`]
+ (https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/config.proto)
+ protocol buffer exposes various configuration options for a
+ session. For example, to create a session that uses soft constraints
+ for device placement, and log the resulting placement decisions,
+ create a session as follows:
+
+ ```python
+ # Launch the graph in a session that allows soft device placement and
+ # logs the placement decisions.
+ sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
+ log_device_placement=True))
+ ```
+
+ @@__init__
+ @@run
+ @@close
+
+ @@graph
+
+ @@as_default
+
+ """
+
+ def __init__(self, target='', graph=None, config=None):
+ """Creates a new TensorFlow session.
+
+ If no `graph` argument is specified when constructing the session,
+ the default graph will be launched in the session. If you are
+ using more than one graph (created with `tf.Graph()` in the same
+ process, you will have to use different sessions for each graph,
+ but each graph can be used in multiple sessions. In this case, it
+ is often clearer to pass the graph to be launched explicitly to
+ the session constructor.
+
+ Args:
+ target: (Optional.) The execution engine to connect to.
+ Defaults to using an in-process engine. At present, no value
+ other than the empty string is supported.
+ graph: (Optional.) The `Graph` to be launched (described above).
+ config: (Optional.) A [`ConfigProto`](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/config.proto)
+ protocol buffer with configuration options for the session.
+
+ """
+ super(Session, self).__init__(target, graph, config=config)
+ self._context_managers = [self.graph.as_default(), self.as_default()]
+
+ def __enter__(self):
+ for context_manager in self._context_managers:
+ context_manager.__enter__()
+ return self
+
+ def __exit__(self, exec_type, exec_value, exec_tb):
+ if exec_type is errors.OpError:
+ logging.error('Session closing due to OpError: %s', (exec_value,))
+
+ for context_manager in reversed(self._context_managers):
+ context_manager.__exit__(exec_type, exec_value, exec_tb)
+
+ self.close()
+
+
+class InteractiveSession(BaseSession):
+ """A TensorFlow `Session` for use in interactive contexts, such as a shell.
+
+ In some cases, such as interactive shells and IPython notebooks, it is
+ useful to be able to define a `Session` without using a with block: this
+ style enables statements to be executed immediately, rather than at the
+ termination of the block. In that case, it must be closed using
+ `Session.close()`. For example:
+
+ ```python
+ sess = InteractiveSession()
+ a = tf.constant(5.0)
+ b = tf.constant(6.0)
+ c = a * b
+ print c.run()
+ sess.close()
+ ```
+
+ @@__init__
+ @@close
+ """
+
+ def __init__(self, target='', graph=None):
+ """Initializes an `InteractiveSession` object similar to `Session`.
+
+ Args:
+ target: Optional. The TensorFlow execution engine to connect to.
+ graph: Optional. The `Graph` object to be used. If this argument is None,
+ the default graph will be used.
+ """
+ super(InteractiveSession, self).__init__(target, graph)
+ self._default_session = self.as_default()
+ self._default_session.__enter__()
+ self._explicit_graph = graph
+ if self._explicit_graph is not None:
+ self._default_graph = graph.as_default()
+ self._default_graph.__enter__()
+
+ def close(self):
+ """Closes an `InteractiveSession`."""
+ super(InteractiveSession, self).close()
+ if self._explicit_graph is not None:
+ self._default_graph.__exit__(None, None, None)
+ self._default_session.__exit__(None, None, None)