aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/client/session.py
blob: 7da9b41cf42df09f9a9bca5e01d01c009e78d515 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
"""A client interface for TensorFlow."""

import re
import sys
import threading

import tensorflow.python.platform

import numpy as np

from tensorflow.python import pywrap_tensorflow as tf_session
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.platform import logging


class SessionInterface(object):
  """Abstract contract for TensorFlow client session implementations.

  Every member defined here raises `NotImplementedError`; subclasses
  must override each of them.
  """

  @property
  def graph(self):
    """The TensorFlow graph in which Operations for this session are built."""
    member = 'graph'
    raise NotImplementedError(member)

  @property
  def sess_str(self):
    """The TensorFlow process that this session connects to."""
    member = 'sess_str'
    raise NotImplementedError(member)

  def run(self, fetches, feed_dict=None):
    """Runs operations in the session. See `Session.run()` for details.

    Args:
      fetches: The graph element(s) to run or evaluate.
      feed_dict: (Optional) Mapping from graph elements to values.

    Raises:
      NotImplementedError: Always; subclasses must override this method.
    """
    member = 'Run'
    raise NotImplementedError(member)


class BaseSession(SessionInterface):
  """A class for interacting with a TensorFlow computation.

  The BaseSession enables incremental graph building with inline
  execution of Operations and evaluation of Tensors.
  """

  def __init__(self, target='', graph=None, config=None):
    """Constructs a new TensorFlow session.

    Args:
      target: (Optional) The TensorFlow execution engine to connect to.
      graph: (Optional) The graph to be used. If this argument is None,
        the default graph will be used.
      config: (Optional) ConfigProto proto used to configure the session.

    Raises:
      RuntimeError: If an error occurs while creating the TensorFlow
        session.
    """
    if graph is None:
      self._graph = ops.get_default_graph()
    else:
      self._graph = graph

    # `_opened` becomes True once the graph has been sent to the runtime
    # (see `_do_run`); `_closed` becomes True once `close()` has run on an
    # opened session.
    self._opened = False
    self._closed = False

    # Version of `self._graph` that the runtime has already been given.
    self._current_version = 0
    self._extend_lock = threading.Lock()
    self._target = target

    self._session = None

    # Each native handle is allocated *before* the `try` whose `finally`
    # releases it, so a failed allocation can never reach a `finally` that
    # references a name which was not bound yet.
    opts = tf_session.TF_NewSessionOptions(target=target, config=config)
    try:
      status = tf_session.TF_NewStatus()
      try:
        self._session = tf_session.TF_NewSession(opts, status)
        if tf_session.TF_GetCode(status) != 0:
          raise RuntimeError(tf_session.TF_Message(status))
      finally:
        tf_session.TF_DeleteStatus(status)
    finally:
      tf_session.TF_DeleteSessionOptions(opts)

  def close(self):
    """Closes this session.

    Calling this method frees all resources associated with the session.

    The underlying session is only closed if it has been opened (i.e. the
    graph has been sent to the runtime at least once) and has not already
    been closed; otherwise this method does nothing.

    Raises:
      RuntimeError: If an error occurs while closing the session.
    """
    with self._extend_lock:
      if self._opened and not self._closed:
        # Mark as closed first so that a failing close is not retried.
        self._closed = True
        status = tf_session.TF_NewStatus()
        try:
          tf_session.TF_CloseSession(self._session, status)
          if tf_session.TF_GetCode(status) != 0:
            raise RuntimeError(tf_session.TF_Message(status))
        finally:
          tf_session.TF_DeleteStatus(status)

  def __del__(self):
    """Closes the session and deletes the underlying native session.

    Raises:
      RuntimeError: If closing or deleting the native session reports an
        error status.
    """
    try:
      self.close()
    finally:
      # Delete the native session even if `close()` raised, so that a
      # failed close does not also leak the session object.
      if self._session is not None:
        status = tf_session.TF_NewStatus()
        try:
          tf_session.TF_DeleteSession(self._session, status)
          if tf_session.TF_GetCode(status) != 0:
            raise RuntimeError(tf_session.TF_Message(status))
          self._session = None
        finally:
          tf_session.TF_DeleteStatus(status)

  @property
  def graph(self):
    """The graph that was launched in this session."""
    return self._graph

  @property
  def graph_def(self):
    """A serializable version of the underlying TensorFlow graph.

    Returns:
      A graph_pb2.GraphDef proto containing nodes for all of the Operations in
      the underlying TensorFlow graph.
    """
    return self._graph.as_graph_def()

  @property
  def sess_str(self):
    """The `target` string that this session was constructed with."""
    return self._target

  def as_default(self):
    """Returns a context manager that makes this object the default session.

    Use with the `with` keyword to specify that calls to
    [`Operation.run()`](framework.md#Operation.run) or
    [`Tensor.run()`](framework.md#Tensor.run) should be executed in
    this session.

    ```python
    c = tf.constant(..)
    sess = tf.Session()

    with sess.as_default():
      assert tf.get_default_session() is sess
      print(c.eval())
    ```

    To get the current default session, use
    [`tf.get_default_session()`](#get_default_session).

    *N.B.* The `as_default` context manager *does not* close the
    session when you exit the context, and you must close the session
    explicitly.

    ```python
    c = tf.constant(...)
    sess = tf.Session()
    with sess.as_default():
      print(c.eval())
    # ...
    with sess.as_default():
      print(c.eval())

    sess.close()
    ```

    Alternatively, you can use `with tf.Session():` to create a
    session that is automatically closed on exiting the context,
    including when an uncaught exception is raised.

    *N.B.* The default session is a property of the current thread. If you
    create a new thread, and wish to use the default session in that
    thread, you must explicitly add a `with sess.as_default():` in that
    thread's function.

    Returns:
      A context manager using this session as the default session.

    """
    return ops.default_session(self)

  # Eventually, this registration could be opened up to support custom
  # Tensor expansions. Expects tuples of (Type, fetch_fn, feed_fn),
  # where the signatures are:
  #   fetch_fn : Type -> (list of Tensors,
  #                       lambda: list of fetched np.ndarray -> TypeVal)
  #   feed_fn  : Type, TypeVal -> list of (Tensor, value)
  # Conceptually, fetch_fn describes how to expand a fetch into its
  # component Tensors and how to contract the fetched results back into
  # a single return value. feed_fn describes how to unpack a single fed
  # value and map it to feeds of a Tensor and its corresponding value.
  # Entries are tried in order, so the catch-all `object` entry must be last.
  # pylint: disable=g-long-lambda
  _REGISTERED_EXPANSIONS = [
      # SparseTensors are fetched as SparseTensorValues. They can be fed
      # SparseTensorValues or normal tuples.
      (ops.SparseTensor,
       lambda fetch: (
           [fetch.indices, fetch.values, fetch.shape],
           lambda fetched_vals: ops.SparseTensorValue(*fetched_vals)),
       lambda feed, feed_val: list(zip(
           [feed.indices, feed.values, feed.shape], feed_val))),
      # The default catches all types and performs no expansions.
      (object,
       lambda fetch: ([fetch], lambda fetched_vals: fetched_vals[0]),
       lambda feed, feed_val: [(feed, feed_val)])]
  # pylint: enable=g-long-lambda

  def run(self, fetches, feed_dict=None):
    """Runs the operations and evaluates the tensors in `fetches`.

    This method runs one "step" of TensorFlow computation, by
    running the necessary graph fragment to execute every `Operation`
    and evaluate every `Tensor` in `fetches`, substituting the values in
    `feed_dict` for the corresponding input values.

    The `fetches` argument may be a list of graph elements or a single
    graph element, and these determine the return value of this
    method. A graph element can be one of the following types:

    * If the *i*th element of `fetches` is an
      [`Operation`](framework.md#Operation), the *i*th return value
      will be `None`.
    * If the *i*th element of `fetches` is a
      [`Tensor`](framework.md#Tensor), the *i*th return value will
      be a numpy ndarray containing the value of that tensor.
    * If the *i*th element of `fetches` is a
      [`SparseTensor`](sparse_ops.md#SparseTensor), the *i*th
      return value will be a
      [`SparseTensorValue`](sparse_ops.md#SparseTensorValue)
      containing the value of that sparse tensor.

    The optional `feed_dict` argument allows the caller to override
    the value of tensors in the graph. Each key in `feed_dict` can be
    one of the following types:

    * If the key is a [`Tensor`](framework.md#Tensor), the
      value may be a Python scalar, string, list, or numpy ndarray
      that can be converted to the same `dtype` as that
      tensor. Additionally, if the key is a
      [placeholder](io_ops.md#placeholder), the shape of the value
      will be checked for compatibility with the placeholder.
    * If the key is a [`SparseTensor`](sparse_ops.md#SparseTensor),
      the value should be a
      [`SparseTensorValue`](sparse_ops.md#SparseTensorValue).

    Args:
      fetches: A single graph element, or a list of graph elements
        (described above).
      feed_dict: A dictionary that maps graph elements to values
        (described above).

    Returns:
      Either a single value if `fetches` is a single graph element, or
      a list of values if `fetches` is a list (described above).

    Raises:
      RuntimeError: If this `Session` is in an invalid state (e.g. has been
        closed).
      TypeError: If `fetches` or `feed_dict` keys are of an inappropriate type.
      ValueError: If `fetches` or `feed_dict` keys are invalid or refer to a
        `Tensor` that doesn't exist.

    """
    def _fetch_fn(fetch):
      for tensor_type, fetch_fn, _ in BaseSession._REGISTERED_EXPANSIONS:
        if isinstance(fetch, tensor_type):
          return fetch_fn(fetch)
      raise TypeError('Fetch argument %r has invalid type %r'
                      % (fetch, type(fetch)))

    def _feed_fn(feed, feed_val):
      for tensor_type, _, feed_fn in BaseSession._REGISTERED_EXPANSIONS:
        if isinstance(feed, tensor_type):
          return feed_fn(feed, feed_val)
      raise TypeError('Feed argument %r has invalid type %r'
                      % (feed, type(feed)))

    def _error_text(exc):
      # Portable replacement for `exc.message`, which is deprecated since
      # Python 2.6 and does not exist in Python 3. For the common
      # single-argument case this yields the same value `exc.message` did.
      if len(exc.args) == 1:
        return exc.args[0]
      return str(exc)

    # Check session.
    if self._closed:
      raise RuntimeError('Attempted to use a closed Session.')

    # Validate and process fetches.
    is_list_fetch = isinstance(fetches, (list, tuple))
    if not is_list_fetch:
      fetches = [fetches]

    unique_fetch_targets = set()
    target_list = []

    # One (tensor names, contraction fn) pair per user-supplied fetch, in
    # the order the fetches were given.
    fetch_info = []
    for fetch in fetches:
      subfetches, fetch_contraction_fn = _fetch_fn(fetch)
      subfetch_names = []
      for subfetch in subfetches:
        try:
          fetch_t = self.graph.as_graph_element(subfetch, allow_tensor=True,
                                                allow_operation=True)
          if isinstance(fetch_t, ops.Operation):
            # Operations are run but produce no fetched value.
            target_list.append(fetch_t.name)
          else:
            subfetch_names.append(fetch_t.name)
        except TypeError as e:
          raise TypeError('Fetch argument %r of %r has invalid type %r, '
                          'must be a string or Tensor. (%s)'
                          % (subfetch, fetch, type(subfetch), _error_text(e)))
        except ValueError as e:
          raise ValueError('Fetch argument %r of %r cannot be interpreted as a '
                           'Tensor. (%s)' % (subfetch, fetch, _error_text(e)))
        except KeyError as e:
          raise ValueError('Fetch argument %r of %r cannot be interpreted as a '
                           'Tensor. (%s)' % (subfetch, fetch, _error_text(e)))
      unique_fetch_targets.update(subfetch_names)
      fetch_info.append((subfetch_names, fetch_contraction_fn))

    unique_fetch_targets = list(unique_fetch_targets)

    # Create request.
    feed_dict_string = {}

    # Validate and process feed_dict.
    if feed_dict:
      # `items()` rather than `iteritems()`: available on Python 2 and 3.
      for feed, feed_val in feed_dict.items():
        for subfeed, subfeed_val in _feed_fn(feed, feed_val):
          try:
            subfeed_t = self.graph.as_graph_element(subfeed, allow_tensor=True,
                                                    allow_operation=False)
          except Exception as e:
            message = ('Cannot interpret feed_dict key as Tensor: %s'
                       % (_error_text(e),))
            # Keep the exception type but prefix its text. `message` is
            # still assigned for callers that read the legacy attribute.
            e.message = message
            e.args = (message,)
            # Bare `raise` keeps the original traceback.
            raise
          np_val = np.array(subfeed_val, dtype=subfeed_t.dtype.as_numpy_dtype)
          if subfeed_t.op.type == 'Placeholder':
            if not subfeed_t.get_shape().is_compatible_with(np_val.shape):
              raise ValueError(
                  'Cannot feed value of shape %r for Tensor %r, '
                  'which has shape %r'
                  % (np_val.shape, subfeed_t.name,
                     tuple(subfeed_t.get_shape().dims)))
          feed_dict_string[str(subfeed_t.name)] = np_val

    # Run request and get response.
    results = self._do_run(target_list, unique_fetch_targets, feed_dict_string)

    # User may have fetched the same tensor multiple times, but we
    # only fetch them from the runtime once.  Furthermore, they may
    # be wrapped as a tuple of tensors.  Here we map the results back
    # to what the client asked for.
    fetched_results = dict(zip(unique_fetch_targets, results))
    ret = []
    for fetch_names, fetch_contraction_fn in fetch_info:
      if fetch_names:
        fetched_vals = [fetched_results[name] for name in fetch_names]
        ret.append(fetch_contraction_fn(fetched_vals))
      else:
        # No tensor names were recorded for this fetch (it was an Operation).
        ret.append(None)

    if is_list_fetch:
      return ret
    else:
      return ret[0]

  # Captures the name of a node in an error status.
  _NODEDEF_NAME_RE = re.compile(r'\[\[Node: ([^ ]*?) =')

  def _do_run(self, target_list, fetch_list, feed_dict):
    """Runs a step based on the given fetches and feeds.

    Before running, any part of the graph that the runtime has not yet
    seen is sent to it.

    Args:
      target_list: A list of strings corresponding to names of tensors
        or operations to be run to, but not fetched.
      fetch_list: A list of strings corresponding to names of tensors to be
        fetched and operations to be run.
      feed_dict: A dictionary that maps tensor names to numpy ndarrays.

    Returns:
      A list of numpy ndarrays, corresponding to the elements of
      `fetch_list`.  If the ith element of `fetch_list` contains the
      name of an operation, the first Tensor output of that operation
      will be returned for that element.

    Raises:
      RuntimeError: If extending the graph in the runtime fails.
      Exception: If the run reports a non-OK status. When the status
        message names a node, the exception built by
        `errors._make_specific_exception` is raised; otherwise the original
        `StatusNotOK` is re-raised unchanged.
    """
    try:
      # Ensure any changes to the graph are reflected in the runtime.
      with self._extend_lock:
        if self._graph.version > self._current_version:
          graph_def = self._graph.as_graph_def(
              from_version=self._current_version)

          # Allocate before the `try` so that `finally` never sees an
          # unbound `status`.
          status = tf_session.TF_NewStatus()
          try:
            tf_session.TF_ExtendGraph(
                self._session, graph_def.SerializeToString(), status)
            if tf_session.TF_GetCode(status) != 0:
              raise RuntimeError(tf_session.TF_Message(status))
            self._opened = True
          finally:
            tf_session.TF_DeleteStatus(status)

          self._current_version = self._graph.version

      return tf_session.TF_Run(self._session, feed_dict, fetch_list,
                               target_list)

    except tf_session.StatusNotOK as e:
      m = BaseSession._NODEDEF_NAME_RE.search(e.error_message)
      if m is not None:
        node_name = m.group(1)
        node_def = None
        try:
          op = self._graph.get_operation_by_name(node_name)
          node_def = op.node_def
        except KeyError:
          op = None
        # pylint: disable=protected-access
        raise errors._make_specific_exception(node_def, op, e.error_message,
                                              e.code)
        # pylint: enable=protected-access
      # No node name in the message: re-raise the original exception with
      # its traceback. A bare `raise` is valid on both Python 2 and 3.
      raise


class Session(BaseSession):
  """A class for running TensorFlow operations.

  A `Session` object encapsulates the environment in which `Operation`
  objects are executed, and `Tensor` objects are evaluated. For
  example:

  ```python
  # Build a graph.
  a = tf.constant(5.0)
  b = tf.constant(6.0)
  c = a * b

  # Launch the graph in a session.
  sess = tf.Session()

  # Evaluate the tensor `c`.
  print(sess.run(c))
  ```

  A session may own resources, such as
  [variables](state_ops.md#Variable), [queues](io_ops.md#QueueBase),
  and [readers](io_ops.md#ReaderBase). It is important to release
  these resources when they are no longer required. To do this, either
  invoke the [`close()`](#Session.close) method on the session, or use
  the session as a context manager. The following two examples are
  equivalent:

  ```python
  # Using the `close()` method.
  sess = tf.Session()
  sess.run(...)
  sess.close()

  # Using the context manager.
  with tf.Session() as sess:
    sess.run(...)
  ```

  The [`ConfigProto`]
  (https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/config.proto)
  protocol buffer exposes various configuration options for a
  session. For example, to create a session that uses soft constraints
  for device placement, and log the resulting placement decisions,
  create a session as follows:

  ```python
  # Launch the graph in a session that allows soft device placement and
  # logs the placement decisions.
  sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                          log_device_placement=True))
  ```

  @@__init__
  @@run
  @@close

  @@graph

  @@as_default

  """

  def __init__(self, target='', graph=None, config=None):
    """Creates a new TensorFlow session.

    If no `graph` argument is specified when constructing the session,
    the default graph will be launched in the session. If you are
    using more than one graph (created with `tf.Graph()`) in the same
    process, you will have to use different sessions for each graph,
    but each graph can be used in multiple sessions. In this case, it
    is often clearer to pass the graph to be launched explicitly to
    the session constructor.

    Args:
      target: (Optional.) The execution engine to connect to.
        Defaults to using an in-process engine. At present, no value
        other than the empty string is supported.
      graph: (Optional.) The `Graph` to be launched (described above).
      config: (Optional.) A [`ConfigProto`](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/config.proto)
        protocol buffer with configuration options for the session.

    """
    super(Session, self).__init__(target, graph, config=config)
    # Entered in this order by `__enter__` and exited in reverse order by
    # `__exit__`: first the graph becomes the default graph, then this
    # session becomes the default session.
    self._context_managers = [self.graph.as_default(), self.as_default()]

  def __enter__(self):
    """Makes this session and its graph the defaults; returns the session."""
    for context_manager in self._context_managers:
      context_manager.__enter__()
    return self

  def __exit__(self, exec_type, exec_value, exec_tb):
    """Restores the previous defaults and closes the session.

    Args:
      exec_type: Type of the exception that ended the `with` block, or None.
      exec_value: The exception instance, or None.
      exec_tb: The traceback of the exception, or None.
    """
    # `issubclass` rather than `is`, so that the specific `OpError`
    # subclasses are logged too, not just `OpError` itself.
    if exec_type is not None and issubclass(exec_type, errors.OpError):
      logging.error('Session closing due to OpError: %s', exec_value)

    try:
      for context_manager in reversed(self._context_managers):
        context_manager.__exit__(exec_type, exec_value, exec_tb)
    finally:
      # Always close, even if unwinding a context manager failed.
      self.close()


class InteractiveSession(BaseSession):
  """A TensorFlow `Session` for use in interactive contexts, such as a shell.

  In some cases, such as interactive shells and IPython notebooks, it is
  useful to be able to define a `Session` without using a with block: this
  style enables statements to be executed immediately, rather than at the
  termination of the block. In that case, it must be closed using
  `Session.close()`. For example:

  ```python
  sess = InteractiveSession()
  a = tf.constant(5.0)
  b = tf.constant(6.0)
  c = a * b
  print(c.run())
  sess.close()
  ```

  @@__init__
  @@close
  """

  def __init__(self, target='', graph=None, config=None):
    """Initializes an `InteractiveSession` object similar to `Session`.

    The new session is entered as the default session on construction. If
    `graph` is given, it is entered as the default graph as well. Both stay
    entered until `close()` is called.

    Args:
      target: Optional. The TensorFlow execution engine to connect to.
      graph: Optional. The `Graph` object to be used. If this argument is None,
        the default graph will be used.
      config: Optional. ConfigProto proto used to configure the session.
    """
    super(InteractiveSession, self).__init__(target, graph, config=config)
    self._default_session = self.as_default()
    self._default_session.__enter__()
    # Only an explicitly supplied graph is installed as the default graph;
    # `close()` uses this to decide whether there is a graph context to exit.
    self._explicit_graph = graph
    if self._explicit_graph is not None:
      self._default_graph = graph.as_default()
      self._default_graph.__enter__()

  def close(self):
    """Closes an `InteractiveSession`."""
    super(InteractiveSession, self).close()
    # Exit in the reverse of the order used in `__init__`: graph first,
    # then session.
    if self._explicit_graph is not None:
      self._default_graph.__exit__(None, None, None)
    self._default_session.__exit__(None, None, None)