From f41959ccb2d9d4c722fe8fc3351401d53bcf4900 Mon Sep 17 00:00:00 2001 From: Manjunath Kudlur Date: Fri, 6 Nov 2015 16:27:58 -0800 Subject: TensorFlow: Initial commit of TensorFlow library. TensorFlow is an open source software library for numerical computation using data flow graphs. Base CL: 107276108 --- tensorflow/python/client/session.py | 567 ++++++++++++++++++++++++++++++++++++ 1 file changed, 567 insertions(+) create mode 100644 tensorflow/python/client/session.py (limited to 'tensorflow/python/client/session.py') diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py new file mode 100644 index 0000000000..7da9b41cf4 --- /dev/null +++ b/tensorflow/python/client/session.py @@ -0,0 +1,567 @@ +"""A client interface for TensorFlow.""" + +import re +import sys +import threading + +import tensorflow.python.platform + +import numpy as np + +from tensorflow.python import pywrap_tensorflow as tf_session +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.platform import logging + + +class SessionInterface(object): + """Base class for implementations of TensorFlow client sessions.""" + + @property + def graph(self): + """The underlying TensorFlow graph, to be used in building Operations.""" + raise NotImplementedError('graph') + + @property + def sess_str(self): + """The TensorFlow process to which this session will connect.""" + raise NotImplementedError('sess_str') + + def run(self, fetches, feed_dict=None): + """Runs operations in the session. See `Session.run()` for details.""" + raise NotImplementedError('Run') + + +class BaseSession(SessionInterface): + """A class for interacting with a TensorFlow computation. + + The BaseSession enables incremental graph building with inline + execution of Operations and evaluation of Tensors. + """ + + def __init__(self, target='', graph=None, config=None): + """Constructs a new TensorFlow session. + + Args: + target: (Optional) The TensorFlow execution engine to connect to. + graph: (Optional) The graph to be used. If this argument is None, + the default graph will be used. + config: (Optional) ConfigProto proto used to configure the session. + + Raises: + RuntimeError: If an error occurs while creating the TensorFlow + session. + """ + if graph is None: + self._graph = ops.get_default_graph() + else: + self._graph = graph + + self._opened = False + self._closed = False + + self._current_version = 0 + self._extend_lock = threading.Lock() + self._target = target + + self._session = None + + try: + opts = tf_session.TF_NewSessionOptions(target=target, config=config) + status = tf_session.TF_NewStatus() + self._session = tf_session.TF_NewSession(opts, status) + if tf_session.TF_GetCode(status) != 0: + message = tf_session.TF_Message(status) + raise RuntimeError(message) + + finally: + tf_session.TF_DeleteSessionOptions(opts) + tf_session.TF_DeleteStatus(status) + + def close(self): + """Closes this session. + + Calling this method frees all resources associated with the session. + + Raises: + RuntimeError: If an error occurs while closing the session. + """ + with self._extend_lock: + if self._opened and not self._closed: + self._closed = True + try: + status = tf_session.TF_NewStatus() + tf_session.TF_CloseSession(self._session, status) + if tf_session.TF_GetCode(status) != 0: + raise RuntimeError(tf_session.TF_Message(status)) + finally: + tf_session.TF_DeleteStatus(status) + + def __del__(self): + self.close() + try: + status = tf_session.TF_NewStatus() + if self._session is not None: + tf_session.TF_DeleteSession(self._session, status) + if tf_session.TF_GetCode(status) != 0: + raise RuntimeError(tf_session.TF_Message(status)) + self._session = None + finally: + tf_session.TF_DeleteStatus(status) + + @property + def graph(self): + """The graph that was launched in this session.""" + return self._graph + + @property + def graph_def(self): + """A serializable version of the underlying TensorFlow graph. + + Returns: + A graph_pb2.GraphDef proto containing nodes for all of the Operations in + the underlying TensorFlow graph. + """ + return self._graph.as_graph_def() + + @property + def sess_str(self): + return self._target + + def as_default(self): + """Returns a context manager that makes this object the default session. + + Use with the `with` keyword to specify that calls to + [`Operation.run()`](framework.md#Operation.run) or + [`Tensor.run()`](framework.md#Tensor.run) should be executed in + this session. + + ```python + c = tf.constant(..) + sess = tf.Session() + + with sess.as_default(): + assert tf.get_default_session() is sess + print c.eval() + ``` + + To get the current default session, use + [`tf.get_default_session()`](#get_default_session). + + + *N.B.* The `as_default` context manager *does not* close the + session when you exit the context, and you must close the session + explicitly. + + ```python + c = tf.constant(...) + sess = tf.Session() + with sess.as_default(): + print c.eval() + # ... + with sess.as_default(): + print c.eval() + + sess.close() + ``` + + Alternatively, you can use `with tf.Session():` to create a + session that is automatically closed on exiting the context, + including when an uncaught exception is raised. + + *N.B.* The default graph is a property of the current thread. If you + create a new thread, and wish to use the default session in that + thread, you must explicitly add a `with sess.as_default():` in that + thread's function. + + Returns: + A context manager using this session as the default session. + + """ + return ops.default_session(self) + + # Eventually, this registration could be opened up to support custom + # Tensor expansions. Expects tuples of (Type, fetch_fn, feed_fn), + # where the signatures are: + # fetch_fn : Type -> (list of Tensors, + # lambda: list of fetched np.ndarray -> TypeVal) + # feed_fn : Type, TypeVal -> list of (Tensor, value) + # Conceptually, fetch_fn describes how to expand fetch into its + # component Tensors and how to contracting the fetched results back into + # a single return value. feed_fn describes how to unpack a single fed + # value and map it to feeds of a Tensor and its corresponding value. + # pylint: disable=g-long-lambda + _REGISTERED_EXPANSIONS = [ + # SparseTensors are fetched as SparseTensorValues. They can be fed + # SparseTensorValues or normal tuples. + (ops.SparseTensor, + lambda fetch: ( + [fetch.indices, fetch.values, fetch.shape], + lambda fetched_vals: ops.SparseTensorValue(*fetched_vals)), + lambda feed, feed_val: list(zip( + [feed.indices, feed.values, feed.shape], feed_val))), + # The default catches all types and performs no expansions. + (object, + lambda fetch: ([fetch], lambda fetched_vals: fetched_vals[0]), + lambda feed, feed_val: [(feed, feed_val)])] + # pylint: enable=g-long-lambda + + def run(self, fetches, feed_dict=None): + """Runs the operations and evaluates the tensors in `fetches`. + + This method runs one "step" of TensorFlow computation, by + running the necessary graph fragment to execute every `Operation` + and evaluate every `Tensor` in `fetches`, substituting the values in + `feed_dict` for the corresponding input values. + + The `fetches` argument may be a list of graph elements or a single + graph element, and these determine the return value of this + method. A graph element can be one of the following types: + + * If the *i*th element of `fetches` is an + [`Operation`](framework.md#Operation), the *i*th return value + will be `None`. + * If the *i*th element of `fetches` is a + [`Tensor`](framework.md#Tensor), the *i*th return value will + be a numpy ndarray containing the value of that tensor. + * If the *i*th element of `fetches` is a + [`SparseTensor`](sparse_ops.md#SparseTensor), the *i*th + return value will be a + [`SparseTensorValue`](sparse_ops.md#SparseTensorValue) + containing the value of that sparse tensor. + + The optional `feed_dict` argument allows the caller to override + the value of tensors in the graph. Each key in `feed_dict` can be + one of the following types: + + * If the key is a [`Tensor`](framework.md#Tensor), the + value may be a Python scalar, string, list, or numpy ndarray + that can be converted to the same `dtype` as that + tensor. Additionally, if the key is a + [placeholder](io_ops.md#placeholder), the shape of the value + will be checked for compatibility with the placeholder. + * If the key is a [`SparseTensor`](sparse_ops.md#SparseTensor), + the value should be a + [`SparseTensorValue`](sparse_ops.md#SparseTensorValue). + + Args: + fetches: A single graph element, or a list of graph elements + (described above). + feed_dict: A dictionary that maps graph elements to values + (described above). + + Returns: + Either a single value if `fetches` is a single graph element, or + a list of values if `fetches` is a list (described above). + + Raises: + RuntimeError: If this `Session` is in an invalid state (e.g. has been + closed). + TypeError: If `fetches` or `feed_dict` keys are of an inappropriate type. + ValueError: If `fetches` or `feed_dict` keys are invalid or refer to a + `Tensor` that doesn't exist. + + """ + def _fetch_fn(fetch): + for tensor_type, fetch_fn, _ in BaseSession._REGISTERED_EXPANSIONS: + if isinstance(fetch, tensor_type): + return fetch_fn(fetch) + raise TypeError('Fetch argument %r has invalid type %r' + % (fetch, type(fetch))) + + def _feed_fn(feed, feed_val): + for tensor_type, _, feed_fn in BaseSession._REGISTERED_EXPANSIONS: + if isinstance(feed, tensor_type): + return feed_fn(feed, feed_val) + raise TypeError('Feed argument %r has invalid type %r' + % (feed, type(feed))) + + # Check session. + if self._closed: + raise RuntimeError('Attempted to use a closed Session.') + + # Validate and process fetches. + is_list_fetch = isinstance(fetches, (list, tuple)) + if not is_list_fetch: + fetches = [fetches] + + unique_fetch_targets = set() + target_list = [] + + fetch_info = [] + for fetch in fetches: + subfetches, fetch_contraction_fn = _fetch_fn(fetch) + subfetch_names = [] + for subfetch in subfetches: + try: + fetch_t = self.graph.as_graph_element(subfetch, allow_tensor=True, + allow_operation=True) + if isinstance(fetch_t, ops.Operation): + target_list.append(fetch_t.name) + else: + subfetch_names.append(fetch_t.name) + except TypeError as e: + raise TypeError('Fetch argument %r of %r has invalid type %r, ' + 'must be a string or Tensor. (%s)' + % (subfetch, fetch, type(subfetch), e.message)) + except ValueError as e: + raise ValueError('Fetch argument %r of %r cannot be interpreted as a ' + 'Tensor. (%s)' % (subfetch, fetch, e.message)) + except KeyError as e: + raise ValueError('Fetch argument %r of %r cannot be interpreted as a ' + 'Tensor. (%s)' % (subfetch, fetch, e.message)) + unique_fetch_targets.update(subfetch_names) + fetch_info.append((subfetch_names, fetch_contraction_fn)) + + unique_fetch_targets = list(unique_fetch_targets) + + # Create request. + feed_dict_string = {} + + # Validate and process feed_dict. + if feed_dict: + for feed, feed_val in feed_dict.iteritems(): + for subfeed, subfeed_val in _feed_fn(feed, feed_val): + try: + subfeed_t = self.graph.as_graph_element(subfeed, allow_tensor=True, + allow_operation=False) + except Exception as e: + e.message = ('Cannot interpret feed_dict key as Tensor: ' + + e.message) + e.args = (e.message,) + raise e + np_val = np.array(subfeed_val, dtype=subfeed_t.dtype.as_numpy_dtype) + if subfeed_t.op.type == 'Placeholder': + if not subfeed_t.get_shape().is_compatible_with(np_val.shape): + raise ValueError( + 'Cannot feed value of shape %r for Tensor %r, ' + 'which has shape %r' + % (np_val.shape, subfeed_t.name, + tuple(subfeed_t.get_shape().dims))) + feed_dict_string[str(subfeed_t.name)] = np_val + + # Run request and get response. + results = self._do_run(target_list, unique_fetch_targets, feed_dict_string) + + # User may have fetched the same tensor multiple times, but we + # only fetch them from the runtime once. Furthermore, they may + # be wrapped as a tuple of tensors. Here we map the results back + # to what the client asked for. + fetched_results = dict(zip(unique_fetch_targets, results)) + ret = [] + for fetch_names, fetch_contraction_fn in fetch_info: + if fetch_names: + fetched_vals = [fetched_results[name] for name in fetch_names] + ret.append(fetch_contraction_fn(fetched_vals)) + else: + ret.append(None) + + if is_list_fetch: + return ret + else: + return ret[0] + + # Captures the name of a node in an error status. + _NODEDEF_NAME_RE = re.compile(r'\[\[Node: ([^ ]*?) =') + + def _do_run(self, target_list, fetch_list, feed_dict): + """Runs a step based on the given fetches and feeds. + + Args: + target_list: A list of strings corresponding to names of tensors + or operations to be run to, but not fetched. + fetch_list: A list of strings corresponding to names of tensors to be + fetched and operations to be run. + feed_dict: A dictionary that maps tensor names to numpy ndarrays. + + Returns: + A list of numpy ndarrays, corresponding to the elements of + `fetch_list`. If the ith element of `fetch_list` contains the + name of an operation, the first Tensor output of that operation + will be returned for that element. + """ + try: + # Ensure any changes to the graph are reflected in the runtime. + with self._extend_lock: + if self._graph.version > self._current_version: + graph_def = self._graph.as_graph_def( + from_version=self._current_version) + + try: + status = tf_session.TF_NewStatus() + tf_session.TF_ExtendGraph( + self._session, graph_def.SerializeToString(), status) + if tf_session.TF_GetCode(status) != 0: + raise RuntimeError(tf_session.TF_Message(status)) + self._opened = True + finally: + tf_session.TF_DeleteStatus(status) + + self._current_version = self._graph.version + + return tf_session.TF_Run(self._session, feed_dict, fetch_list, + target_list) + + except tf_session.StatusNotOK as e: + e_type, e_value, e_traceback = sys.exc_info() + m = BaseSession._NODEDEF_NAME_RE.search(e.error_message) + if m is not None: + node_name = m.group(1) + node_def = None + try: + op = self._graph.get_operation_by_name(node_name) + node_def = op.node_def + except KeyError: + op = None + # pylint: disable=protected-access + raise errors._make_specific_exception(node_def, op, e.error_message, + e.code) + # pylint: enable=protected-access + raise e_type, e_value, e_traceback + + +class Session(BaseSession): + """A class for running TensorFlow operations. + + A `Session` object encapsulates the environment in which `Operation` + objects are executed, and `Tensor` objects are evaluated. For + example: + + ```python + # Build a graph. + a = tf.constant(5.0) + b = tf.constant(6.0) + c = a * b + + # Launch the graph in a session. + sess = tf.Session() + + # Evaluate the tensor `c`. + print sess.run(c) + ``` + + A session may own resources, such as + [variables](state_ops.md#Variable), [queues](io_ops.md#QueueBase), + and [readers](io_ops.md#ReaderBase). It is important to release + these resources when they are no longer required. To do this, either + invoke the [`close()`](#Session.close) method on the session, or use + the session as a context manager. The following two examples are + equivalent: + + ```python + # Using the `close()` method. + sess = tf.Session() + sess.run(...) + sess.close() + + # Using the context manager. + with tf.Session() as sess: + sess.run(...) + ``` + + The [`ConfigProto`] + (https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/config.proto) + protocol buffer exposes various configuration options for a + session. For example, to create a session that uses soft constraints + for device placement, and log the resulting placement decisions, + create a session as follows: + + ```python + # Launch the graph in a session that allows soft device placement and + # logs the placement decisions. + sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True, + log_device_placement=True)) + ``` + + @@__init__ + @@run + @@close + + @@graph + + @@as_default + + """ + + def __init__(self, target='', graph=None, config=None): + """Creates a new TensorFlow session. + + If no `graph` argument is specified when constructing the session, + the default graph will be launched in the session. If you are + using more than one graph (created with `tf.Graph()` in the same + process, you will have to use different sessions for each graph, + but each graph can be used in multiple sessions. In this case, it + is often clearer to pass the graph to be launched explicitly to + the session constructor. + + Args: + target: (Optional.) The execution engine to connect to. + Defaults to using an in-process engine. At present, no value + other than the empty string is supported. + graph: (Optional.) The `Graph` to be launched (described above). + config: (Optional.) A [`ConfigProto`](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/config.proto) + protocol buffer with configuration options for the session. + + """ + super(Session, self).__init__(target, graph, config=config) + self._context_managers = [self.graph.as_default(), self.as_default()] + + def __enter__(self): + for context_manager in self._context_managers: + context_manager.__enter__() + return self + + def __exit__(self, exec_type, exec_value, exec_tb): + if exec_type is errors.OpError: + logging.error('Session closing due to OpError: %s', (exec_value,)) + + for context_manager in reversed(self._context_managers): + context_manager.__exit__(exec_type, exec_value, exec_tb) + + self.close() + + +class InteractiveSession(BaseSession): + """A TensorFlow `Session` for use in interactive contexts, such as a shell. + + In some cases, such as interactive shells and IPython notebooks, it is + useful to be able to define a `Session` without using a with block: this + style enables statements to be executed immediately, rather than at the + termination of the block. In that case, it must be closed using + `Session.close()`. For example: + + ```python + sess = InteractiveSession() + a = tf.constant(5.0) + b = tf.constant(6.0) + c = a * b + print c.run() + sess.close() + ``` + + @@__init__ + @@close + """ + + def __init__(self, target='', graph=None): + """Initializes an `InteractiveSession` object similar to `Session`. + + Args: + target: Optional. The TensorFlow execution engine to connect to. + graph: Optional. The `Graph` object to be used. If this argument is None, + the default graph will be used. + """ + super(InteractiveSession, self).__init__(target, graph) + self._default_session = self.as_default() + self._default_session.__enter__() + self._explicit_graph = graph + if self._explicit_graph is not None: + self._default_graph = graph.as_default() + self._default_graph.__enter__() + + def close(self): + """Closes an `InteractiveSession`.""" + super(InteractiveSession, self).close() + if self._explicit_graph is not None: + self._default_graph.__exit__(None, None, None) + self._default_session.__exit__(None, None, None) -- cgit v1.2.3