# tensorflow/python/eager/context.py
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Experimental API for TensorFlow's "Eager" mode of execution."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import contextlib
import copy
import random
import threading

from tensorflow.core.protobuf import config_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import device as pydev
from tensorflow.python.util import compat
from tensorflow.python.util import is_in_graph_mode
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util.tf_export import tf_export

GRAPH_MODE = 0
EAGER_MODE = 1

# Default execution mode.
default_execution_mode = GRAPH_MODE

# Cache from (old_device_name, partial_new_device_name) -> (new_device_name,
# new_device_spec).
# Note that we do not protect this with a lock and instead rely on python's GIL
# and the idempotent nature of writes to provide thread safety.
_device_parsing_cache = {}

_MAXINT32 = 2**31 - 1

DEVICE_PLACEMENT_EXPLICIT = pywrap_tensorflow.TFE_DEVICE_PLACEMENT_EXPLICIT
DEVICE_PLACEMENT_WARN = pywrap_tensorflow.TFE_DEVICE_PLACEMENT_WARN
DEVICE_PLACEMENT_SILENT = pywrap_tensorflow.TFE_DEVICE_PLACEMENT_SILENT
DEVICE_PLACEMENT_SILENT_FOR_INT32 = (
    pywrap_tensorflow.TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32)
SYNC = 0
ASYNC = 1


class _EagerTensorCache(object):
  """Simple cache which evicts items based on length in a FIFO manner."""

  def __init__(self, max_items=256, max_tensor_size=10000):
    self._data = collections.OrderedDict()
    self._max_items = max_items
    self._max_tensor_size = max_tensor_size

  def put(self, key, value):
    if value._num_elements() > self._max_tensor_size:  # pylint: disable=protected-access
      return

    self._data[key] = value

    if len(self._data) > self._max_items:
      self._data.popitem(last=False)

  def get(self, key):
    return self._data.get(key, None)

  def flush(self):
    # Use an OrderedDict so that popitem(last=False) keeps working after a
    # flush (a plain dict would break the FIFO eviction in put()).
    self._data = collections.OrderedDict()
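
# A minimal usage sketch for `_EagerTensorCache` (illustrative only; this is
# an internal helper, and `make_eager_tensor` below is a hypothetical stand-in
# for any EagerTensor factory):
#
#   cache = _EagerTensorCache(max_items=2)
#   cache.put("a", make_eager_tensor(...))  # cached
#   cache.put("b", make_eager_tensor(...))  # cached
#   cache.put("c", make_eager_tensor(...))  # evicts "a" (FIFO)
#   cache.get("a")                          # -> None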


# TODO(agarwal): better name ?
class _EagerContext(threading.local):
  """Thread local eager context."""

  def __init__(self):
    super(_EagerContext, self).__init__()
    self.device_spec = pydev.DeviceSpec.from_string("")
    self.device_name = self.device_spec.to_string()
    self.mode = default_execution_mode
    self.is_eager = default_execution_mode == EAGER_MODE
    self.scope_name = ""
    self.recording_summaries = False
    self.summary_writer_resource = None
    self.scalar_cache = {}
    self.ones_rank_cache = _EagerTensorCache()
    self.zeros_cache = _EagerTensorCache()
    self.execution_mode = None


ContextSwitch = collections.namedtuple(
    "ContextSwitch", ["is_building_function", "enter_context_fn"])


# `_ContextSwitchStack` is a `threading.local` to match the semantics of
# `DefaultGraphStack`, which is also a `threading.local`.
class _ContextSwitchStack(threading.local):
  """A thread-local stack of context switches."""

  def __init__(self, eager):
    super(_ContextSwitchStack, self).__init__()
    self.stack = []
    if eager:
      # Initialize the stack with a pointer to enter the eager context; this
      # ensures that the fact that eager execution was enabled is propagated
      # across threads, since (1) `enable_eager_execution` modifies a
      # process-level flag (`default_execution_mode`) and (2) `__init__` is
      # called each time a threading.local object is used in a separate thread.
      self.push(is_building_function=False, enter_context_fn=eager_mode)

  def push(self, is_building_function, enter_context_fn):
    """Push metadata about a context switch onto the stack.

    A context switch can take one of two forms: installing a graph as the
    default graph, or entering the eager context. For each context switch,
    we record whether or not the entered context is building a function.

    Args:
      is_building_function: (bool.) Whether the context is building a function.
      enter_context_fn: (function.) A callable that executes the context switch.
        For example, `graph.as_default` or `eager_mode`.
    """

    self.stack.append(
        ContextSwitch(is_building_function, enter_context_fn))

  def pop(self):
    """Pop the stack."""

    self.stack.pop()
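
# Illustrative sketch of how a context switch is recorded (internal API;
# `graph` here stands for any tf.Graph instance):
#
#   switches = context().context_switches
#   switches.push(is_building_function=False,
#                 enter_context_fn=graph.as_default)
#   ...  # code that runs under the switched context
#   switches.pop()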


# TODO(agarwal): rename to EagerContext / EagerRuntime ?
# TODO(agarwal): consider keeping the corresponding Graph here.
class Context(object):
  """Environment in which eager operations execute."""

  # TODO(agarwal): create and link in some documentation for `execution_mode`.
  # pylint: disable=redefined-outer-name
  def __init__(self,
               config=None,
               device_policy=None,
               execution_mode=None,
               server_def=None):
    """Creates a new Context.

    Args:
      config: (Optional.) A `ConfigProto` protocol buffer with configuration
        options for the Context. Note that a lot of these options may be
        currently unimplemented or irrelevant when eager execution is enabled.
      device_policy: (Optional.) What policy to use when trying to run an
         operation on a device with inputs which are not on that device.
         When set to None, an appropriate value will be picked automatically.
         The value picked may change between TensorFlow releases.

         Defaults to tf.contrib.eager.DEVICE_PLACEMENT_SILENT_FOR_INT32.
         Valid values:
         - tfe.DEVICE_PLACEMENT_EXPLICIT: raises an error if the placement is
           not correct.
         - tfe.DEVICE_PLACEMENT_WARN: copies the tensors which are not on the
           right device but raises a warning.
         - tfe.DEVICE_PLACEMENT_SILENT: silently copies the tensors. This might
           hide performance problems.
         - tfe.DEVICE_PLACEMENT_SILENT_FOR_INT32: silently copies int32 tensors,
           raising errors on the other ones.
      execution_mode: (Optional.) Policy controlling how operations dispatched
        are actually executed. When set to None, an appropriate value will be
        picked automatically. The value picked may change between TensorFlow
        releases.
        Valid values:
        - tf.contrib.eager.SYNC: executes each operation synchronously.
        - tf.contrib.eager.ASYNC: executes each operation asynchronously. These
          operations may return "non-ready" handles.
      server_def: (Optional.) A tensorflow::ServerDef proto.
        Enables execution on remote devices. GrpcServers need to be started
        with an identical server_def (with the appropriate task_index set)
        so that the servers can communicate. It will then be possible to
        execute operations on remote devices.

    Raises:
     ValueError: If execution_mode is not valid.
    """
    self._eager_context = _EagerContext()
    self._context_switches = _ContextSwitchStack(self.executing_eagerly())
    self._context_handle = None
    self._context_devices = None
    self._post_execution_callbacks = []
    self._config = config
    self._seed = None
    self._initialize_lock = threading.Lock()
    self._device_policy = device_policy
    if execution_mode not in (None, SYNC, ASYNC):
      raise ValueError(
          "execution_mode should be None/SYNC/ASYNC. Got %s" % execution_mode)
    if execution_mode is None:
      execution_mode = SYNC
    self._execution_mode = execution_mode
    self._server_def = server_def

  # pylint: enable=redefined-outer-name
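
  # Illustrative construction sketch (the argument values below are
  # assumptions chosen for the example, not defaults):
  #
  #   ctx = Context(
  #       config=config_pb2.ConfigProto(log_device_placement=True),
  #       device_policy=DEVICE_PLACEMENT_WARN,
  #       execution_mode=ASYNC)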

  def _set_global_seed(self, seed):
    """Set a global eager mode seed for random ops."""
    self._seed = seed
    self._rng = random.Random(self._seed)
    # Also clear the kernel cache, to reset any existing seeds
    if self._context_handle is not None:
      pywrap_tensorflow.TFE_ContextClearCaches(self._context_handle)

  def _internal_operation_seed(self):
    """Returns a fake operation seed.

      In eager mode, user shouldn't set or depend on operation seed.
      Here, we generate a random seed based on global seed to make
      operation's randomness different and depend on the global seed.

    Returns:
      A fake operation seed based on global seed.
    """
    return self._rng.randint(0, _MAXINT32)

  def _initialize_devices(self):
    """Helper to initialize devices."""
    # Store list of devices
    self._context_devices = []
    device_list = pywrap_tensorflow.TFE_ContextListDevices(
        self._context_handle)
    try:
      self._num_gpus = 0
      for i in range(pywrap_tensorflow.TF_DeviceListCount(device_list)):
        dev_name = pywrap_tensorflow.TF_DeviceListName(device_list, i)
        self._context_devices.append(pydev.canonical_name(dev_name))
        dev_type = pywrap_tensorflow.TF_DeviceListType(device_list, i)
        if dev_type == "GPU":
          self._num_gpus += 1

    finally:
      pywrap_tensorflow.TF_DeleteDeviceList(device_list)

  def _initialize_handle_and_devices(self):
    """Initialize handle and devices."""
    with self._initialize_lock:
      if self._context_handle is not None:
        return
      assert self._context_devices is None
      opts = pywrap_tensorflow.TFE_NewContextOptions()
      try:
        if self._config is not None:
          config_str = self._config.SerializeToString()
          pywrap_tensorflow.TFE_ContextOptionsSetConfig(opts, config_str)
        if self._device_policy is not None:
          pywrap_tensorflow.TFE_ContextOptionsSetDevicePlacementPolicy(
              opts, self._device_policy)
        if self._execution_mode == ASYNC:
          pywrap_tensorflow.TFE_ContextOptionsSetAsync(opts, True)
        self._context_handle = pywrap_tensorflow.TFE_NewContext(opts)
      finally:
        pywrap_tensorflow.TFE_DeleteContextOptions(opts)
      if self._server_def is not None:
        server_def_str = self._server_def.SerializeToString()
        pywrap_tensorflow.TFE_ContextSetServerDef(self._context_handle, 600,
                                                  server_def_str)

      self._initialize_devices()

  def _clear_caches(self):
    self.scalar_cache().clear()
    self.ones_rank_cache().flush()
    self.zeros_cache().flush()

  def set_server_def(self, server_def, keep_alive_secs=600):
    """Allow setting a server_def on the context.

    When a server_def is replaced, it effectively clears all the caches
    within the context. Any tensor object that was pointing to a tensor on
    the remote device will raise an error if used afterwards.

    Args:
      server_def: A tensorflow::ServerDef proto.
        Enables execution on remote devices.
      keep_alive_secs: Number of seconds after which the remote end will
        hang up. As long as the client is still alive, the server state for
        the context will be kept alive. If the client is killed (or there is
        some failure), the server will clean up its context keep_alive_secs
        after the final RPC it receives.

    Raises:
      ValueError: if server_def is None.
    """
    if not server_def:
      raise ValueError("server_def is None.")
    if not self._context_handle:
      self._server_def = server_def
    else:
      server_def_str = server_def.SerializeToString()
      pywrap_tensorflow.TFE_ContextSetServerDef(self._context_handle,
                                                keep_alive_secs, server_def_str)

      # Clear all the caches in case there are remote tensors in them.
      self._clear_caches()

      self._initialize_devices()

  @property
  def _handle(self):
    ctx = self._context_handle
    if ctx is None:
      self._initialize_handle_and_devices()
      return self._context_handle
    else:
      return ctx

  @property
  def _devices(self):
    devices = self._context_devices
    if devices is None:
      self._initialize_handle_and_devices()
      return self._context_devices
    else:
      return devices

  def __str__(self):
    if self._context_handle is None:
      return "Eager TensorFlow Context. Devices currently uninitialized."
    else:
      devices = self._devices
      lines = ["Eager TensorFlow Context with %d devices" % (len(devices))]
      for i, d in enumerate(devices):
        lines.append("   Device %d: %s" % (i, d))
      return "\n".join(lines)

  @tf_contextlib.contextmanager
  def _mode(self, mode):
    """A context manager to allow setting the mode to EAGER/GRAPH."""
    ctx = self._eager_context
    old_mode = ctx.mode
    old_is_eager = ctx.is_eager
    ctx.mode = mode
    ctx.is_eager = mode == EAGER_MODE
    if mode == EAGER_MODE:
      # Entering graph mode does not provide us with sufficient information to
      # record a context switch; graph-based context switches are only logged
      # when a graph is registered as the default graph.
      self.context_switches.push(False, eager_mode)
    try:
      yield
    finally:
      ctx.is_eager = old_is_eager
      ctx.mode = old_mode
      if mode == EAGER_MODE:
        self.context_switches.pop()

  def executing_eagerly(self):
    """Returns True if current thread has eager executing enabled."""
    return self._eager_context.is_eager

  def scalar_cache(self):
    """Per-device cache for scalars."""
    return self._eager_context.scalar_cache

  def ones_rank_cache(self):
    """Per-device cache for scalars."""
    return self._eager_context.ones_rank_cache

  def zeros_cache(self):
    """Per-device cache for scalars."""
    return self._eager_context.zeros_cache

  @property
  def scope_name(self):
    """Returns scope name for the current thread."""
    return self._eager_context.scope_name

  @scope_name.setter
  def scope_name(self, s):
    """Sets scope name for the current thread."""
    self._eager_context.scope_name = s

  @property
  def summary_writer_resource(self):
    """Returns summary writer resource."""
    return self._eager_context.summary_writer_resource

  @summary_writer_resource.setter
  def summary_writer_resource(self, resource):
    """Sets summary writer resource."""
    self._eager_context.summary_writer_resource = resource

  @property
  def device_name(self):
    """Returns the device name for the current thread."""
    return self._eager_context.device_name

  @property
  def device_spec(self):
    """Returns the device spec for the current thread."""
    return self._eager_context.device_spec

  @tf_contextlib.contextmanager
  def device(self, name):
    """Context-manager to force placement of operations and Tensors on a device.

    Args:
      name: Name of the device or None to get default placement.

    Yields:
      Nothing.

    Raises:
      ValueError: If name is not a string or is an invalid device name.
    """
    eager_context = self._eager_context
    old_device_name = eager_context.device_name
    old_device_spec = eager_context.device_spec
    cache_key = (old_device_name, name)
    try:
      new_device_name, new_device_spec = _device_parsing_cache[cache_key]
    except TypeError:
      # Error while trying to compute the cache key.
      raise ValueError("Expecting a string device name. Got %s(%s)" %
                       (type(name), name))
    except KeyError:
      # Handle a cache miss.
      if name is not None:
        if not isinstance(name, str):
          raise ValueError("Expecting a string device name. Got %s(%s)" %
                           (type(name), name))
        device_spec = pydev.DeviceSpec.from_string(name)
        if old_device_name:
          new_device_spec = copy.copy(old_device_spec)
        else:
          new_device_spec = pydev.DeviceSpec.from_string(
              "/job:localhost/replica:0/task:0/device:CPU:0")
        new_device_spec.merge_from(device_spec)
      else:
        new_device_spec = pydev.DeviceSpec.from_string("")
      new_device_name = new_device_spec.to_string()
      _device_parsing_cache[cache_key] = (new_device_name, new_device_spec)

    try:
      eager_context.device_name = new_device_name
      eager_context.device_spec = new_device_spec
      yield
    finally:
      eager_context.device_name = old_device_name
      eager_context.device_spec = old_device_spec

  def devices(self):
    """List of the names of devices available to execute operations."""
    return self._devices

  def get_execution_mode(self):
    mode = self._eager_context.execution_mode
    if mode is None:
      mode = self._execution_mode
    return mode

  def set_execution_mode(self, mode):
    """Sets execution mode for current thread."""
    if mode not in (None, SYNC, ASYNC):
      raise ValueError(
          "Execution mode should be None/SYNC/ASYNC. Got %s" % mode)
    if mode is None:
      mode = SYNC
    self._eager_context.execution_mode = mode
    pywrap_tensorflow.TFE_ContextSetAsyncForThread(self._handle, mode == ASYNC)

  @tf_contextlib.contextmanager
  def execution_mode(self, mode):
    """Context manager for setting execution mode for current thread."""
    old_mode = self.get_execution_mode()
    try:
      self.set_execution_mode(mode)
      yield
    finally:
      self.set_execution_mode(old_mode)

  def async_wait(self):
    """Waits for ops dispatched in ASYNC mode to finish."""
    pywrap_tensorflow.TFE_ContextAsyncWait(self._handle)

  def async_clear_error(self):
    """Clears errors raised during ASYNC execution."""
    pywrap_tensorflow.TFE_ContextAsyncClearError(self._handle)

  def num_gpus(self):
    """The number of GPUs available to execute operations."""
    self._initialize_handle_and_devices()
    return self._num_gpus

  def add_function(self, fn):
    """Add a function definition to the context.

    Once added, the function (identified by its name) can be executed like any
    other operation.

    Args:
      fn: A wrapped TF_Function (returned from TF_GraphToFunction_wrapper).
    """
    pywrap_tensorflow.TFE_ContextAddFunction(self._handle, fn)

  def add_function_def(self, fdef):
    """Add a function definition to the context.

    Once added, the function (identified by its name) can be executed like any
    other operation.

    Args:
      fdef: A FunctionDef protocol buffer message.
    """
    fdef_string = fdef.SerializeToString()
    pywrap_tensorflow.TFE_ContextAddFunctionDef(
        self._handle, fdef_string, len(fdef_string))

  def add_post_execution_callback(self, callback):
    """Add a post-execution callback to the context.

    A post-execution callback is invoked immediately after an eager operation
    or function has finished execution. It provides access to the op's type,
    name, attributes, and input and output tensors. Multiple callbacks can be
    added, in which case they are invoked in the order in which they were
    added.

    Args:
      callback: a callable of the signature
      `f(op_type, op_name, attrs, inputs, outputs)`.
      `op_type` is the type of the operation that was just executed (e.g.,
        `MatMul`).
      `op_name` is the name of the operation that was just executed. This
        name is set by the client who created the operation and can be `None` if
        it is unset.
      `attrs` contains the attributes of the operation as a `tuple` of
        alternating attribute names and attribute values.
      `inputs` is the `list` of input `Tensor`(s) to the op.
      `outputs` is the `list` of output `Tensor`(s) from the op.
       Return value(s) from the callback are ignored.
    """
    # TODO(cais): (b/64674139) Allow access to function-internal operations.
    self._post_execution_callbacks.append(callback)
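
  # Sketch of a callback with the documented signature (illustrative; the
  # `print` call stands in for any logging mechanism):
  #
  #   def _log_op(op_type, op_name, attrs, inputs, outputs):
  #     print("executed %s with %d input(s)" % (op_type, len(inputs)))
  #
  #   context().add_post_execution_callback(_log_op)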

  def clear_post_execution_callbacks(self):
    """Clear all post-execution callbacks added to the context."""
    del self._post_execution_callbacks[:]

  @property
  def post_execution_callbacks(self):
    """Get the list of post-execution callbacks added to the context."""
    return self._post_execution_callbacks

  def enable_run_metadata(self):
    """Enables tracing of op execution via RunMetadata.

    To retrieve the accumulated metadata, call context.export_run_metadata();
    to stop tracing, call context.disable_run_metadata().
    """
    pywrap_tensorflow.TFE_ContextEnableRunMetadata(self._handle)

  @tf_contextlib.contextmanager
  def device_policy(self, policy):
    handle = self._handle
    old = pywrap_tensorflow.TFE_ContextGetDevicePlacementPolicy(handle)
    pywrap_tensorflow.TFE_ContextSetThreadLocalDevicePlacementPolicy(
        handle, policy)
    try:
      yield
    finally:
      pywrap_tensorflow.TFE_ContextSetThreadLocalDevicePlacementPolicy(
          handle, old)
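
  # Sketch: temporarily allow silent cross-device copies within a block
  # (illustrative):
  #
  #   with context().device_policy(DEVICE_PLACEMENT_SILENT):
  #     ...  # inputs living on other devices are copied silently here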

  def disable_run_metadata(self):
    """Disables tracing of op execution via RunMetadata."""
    if not self._context_handle:
      return
    pywrap_tensorflow.TFE_ContextDisableRunMetadata(self._context_handle)

  def export_run_metadata(self):
    """Returns a RunMetadata proto with accumulated information.

    The returned protocol buffer contains information collected since the
    most recent call to either enable_run_metadata or export_run_metadata.

    Returns:
      A RunMetadata protocol buffer. Or None if not enabled.
    """
    if not self._context_handle:
      return None
    with c_api_util.tf_buffer() as buffer_:
      pywrap_tensorflow.TFE_ContextExportRunMetadata(
          self._context_handle, buffer_)
      proto_data = pywrap_tensorflow.TF_GetBuffer(buffer_)
    run_metadata = config_pb2.RunMetadata()
    run_metadata.ParseFromString(compat.as_bytes(proto_data))
    return run_metadata

  @property
  def context_switches(self):
    """Returns a stack of context switches."""
    return self._context_switches

  def start_step(self):
    pywrap_tensorflow.TFE_ContextStartStep(self._handle)

  def end_step(self):
    pywrap_tensorflow.TFE_ContextEndStep(self._handle)

_context = None
_context_lock = threading.Lock()


def _initialize_context():
  global _context
  with _context_lock:
    if _context is None:
      _context = Context()


def context():
  """Returns a singleton context object."""
  if _context is None:
    _initialize_context()
  return _context


def context_safe():
  """Returns current context (or None if one hasn't been initialized)."""
  return _context


def set_global_seed(seed):
  """Sets the eager mode seed."""
  context()._set_global_seed(seed)  # pylint: disable=protected-access


def global_seed():
  """Returns the eager mode seed."""
  return context()._seed  # pylint: disable=protected-access


def internal_operation_seed():
  """Returns the operation seed generated based on global seed."""
  return context()._internal_operation_seed()  # pylint: disable=protected-access
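
# Sketch of the seeding relationship (illustrative):
#
#   set_global_seed(7)
#   s = internal_operation_seed()  # drawn from an RNG seeded with 7
#   assert 0 <= s <= 2**31 - 1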


@tf_export("executing_eagerly")
def executing_eagerly():
  """Returns True if the current thread has eager execution enabled.

  Eager execution is typically enabled via `tf.enable_eager_execution`,
  but may also be enabled within the context of a Python function via
  `tf.contrib.eager.py_func`.
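
  For example, a minimal check (a sketch that assumes eager execution was
  enabled at program startup):

  ```python
  tf.enable_eager_execution()
  assert tf.executing_eagerly()
  ```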
  """
  return context().executing_eagerly()


def in_eager_mode():
  """Use executing_eagerly() instead. This function will be removed."""
  return executing_eagerly()


def graph_mode():
  """Context-manager to disable eager execution for the current thread."""
  return context()._mode(GRAPH_MODE)  # pylint: disable=protected-access


def eager_mode():
  """Context-manager to enable eager execution for the current thread."""
  return context()._mode(EAGER_MODE)  # pylint: disable=protected-access
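
# Sketch: temporarily switch modes on the current thread (uses this module's
# own helpers):
#
#   with graph_mode():
#     assert not executing_eagerly()
#   with eager_mode():
#     assert executing_eagerly()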


# TODO(agarwal): get rid of this and use ops.name_scope instead.
@contextlib.contextmanager
def namescope(name):
  """ContextManager for creating hierarchical name scopes."""
  ctx = context()
  old_name = ctx.scope_name
  ctx.scope_name = "%s/%s" % (old_name, name) if old_name else name
  try:
    yield
  finally:
    ctx.scope_name = old_name
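
# Sketch of the hierarchical names namescope produces (illustrative):
#
#   with namescope("outer"):
#     with namescope("inner"):
#       scope_name()  # -> "outer/inner"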


def scope_name():
  """Name of the current scope."""
  return context().scope_name


def device(name):
  """Context-manager to force placement of operations and Tensors on a device.

  Example:
  ```python
  with tfe.device('gpu:0'):
    with tfe.device('cpu:0'):
      shape = tf.constant([], dtype=tf.int32)
    x = tf.truncated_normal(shape, dtype=tf.float32)
  ```
  will ensure that the `shape` Tensor is on CPU but the `truncated_normal`
  operation runs on GPU 0.

  Args:
    name: Name of the device (see context().devices()), or None to
      perform automatic placement.

  Returns:
    Context manager for setting the device.
  """
  return context().device(name)


def list_devices():
  """List the names of the available devices.

  Returns:
    Names of the available devices, as a `list`.
  """
  return context().devices()
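
# Illustrative sketch (the device list varies by machine):
#
#   for name in list_devices():
#     print(name)  # e.g. "/job:localhost/replica:0/task:0/device:CPU:0"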


def set_execution_mode(mode):
  """Sets execution mode for the current thread."""
  context().set_execution_mode(mode)


def execution_mode(mode):
  """Context manager for setting execution mode for current thread."""
  return context().execution_mode(mode)
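
# Sketch: dispatch ops asynchronously within a block, then synchronize
# (assumes eager execution is enabled; `tf`, `a` and `b` are assumptions:
# the top-level tensorflow module and two matmul-compatible tensors):
#
#   with execution_mode(ASYNC):
#     c = tf.matmul(a, b)  # returns quickly; may be a "non-ready" handle
#   async_wait()           # block until dispatched ops have finished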


def async_wait():
  """Waits for ops dispatched in ASYNC mode to finish."""
  return context().async_wait()


def async_clear_error():
  """Clears errors raised during ASYNC execution mode."""
  return context().async_clear_error()


def num_gpus():
  """Get the number of available GPU devices.

  Returns:
    The number of available GPU devices.
  """
  return context().num_gpus()


def enable_run_metadata():
  """Enables tracing of op execution via RunMetadata.

  To retrieve the accumulated metadata, call context.export_run_metadata();
  to stop tracing, call context.disable_run_metadata().
  """
  context().enable_run_metadata()


def disable_run_metadata():
  """Disables tracing of op execution via RunMetadata."""
  context().disable_run_metadata()


def export_run_metadata():
  """Returns a RunMetadata proto with accumulated information.

  The returned protocol buffer contains information collected since the most
  recent call to either enable_run_metadata or export_run_metadata.

  Returns:
    A RunMetadata protocol buffer, or None if tracing was not enabled.
  """
  return context().export_run_metadata()
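
# Typical RunMetadata tracing cycle (sketch; assumes eager execution):
#
#   enable_run_metadata()
#   ...  # run some eager ops
#   metadata = export_run_metadata()  # a config_pb2.RunMetadata proto
#   disable_run_metadata()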


def set_server_def(server_def):
  """Sets a server_def on the current eager context."""
  context().set_server_def(server_def)
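
# A minimal remote-execution sketch (the cluster address below is
# hypothetical, and a matching GrpcServer must already be running with the
# same cluster definition; `tf` is the top-level tensorflow module):
#
#   cluster = tf.train.ClusterSpec({"worker": ["worker0:2222"]})
#   server_def = tf.train.ServerDef(cluster=cluster.as_cluster_def(),
#                                   job_name="worker", task_index=0,
#                                   protocol="grpc")
#   set_server_def(server_def)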


# Not every user creates a Context via context.context()
# (for example, enable_eager_execution in python/framework/ops.py),
# but they do all import this file.  Note that IS_IN_GRAPH_MODE and
# in_graph_mode are both parameterless functions.
def _tmp_in_graph_mode():
  return not executing_eagerly()


is_in_graph_mode.IS_IN_GRAPH_MODE = _tmp_in_graph_mode