1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
|
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Imperative mode for TensorFlow."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.imperative import imperative_graph
from tensorflow.python.client import session
from tensorflow.python.framework import errors
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
class ImperativeMode(object):
  """Imperative mode execution of TensorFlow graphs.

  This class is a container for an ImperativeGraph, a session, and other
  context managers that enable imperative mode execution. The following is
  the common usage pattern:

  ```python
  server = tf.train.Server.create_local_server()
  with ImperativeMode(server.target):
    a = tf.random_normal([])
    b = tf.random_normal([])
    c = a + b
    c_val = c.value
    d = c + 1.0
    d_val = d.value
    # Expect d_val == c_val + 1.0
  ```

  ImperativeMode provides the illusion of immediate execution. It still
  constructs a graph and defers op execution. But when an op executes for
  the first time, its results are cached and the cached value is returned for
  future executions. The __exit__ method clears this graph and cached values.
  To use ImperativeMode inside a loop, the `new_step` method can be used to
  create a temporary context around the loop body to clear the cache at loop
  exit as follows:

  ```python
  server = tf.train.Server.create_local_server()
  with ImperativeMode(server.target) as mode:
    w = tf.Variable(tf.zeros([]))
    for i in range(10):
      with mode.new_step():
        x = tf.random_uniform([])
        y = tf.random_uniform([])
        z = w.assign_add(x + y)
        print(z.value)
  ```

  ImperativeMode graph does not support all TensorFlow operations and
  features. Here are the current known limitations of ImperativeMode:

  * Stateful operations returning ref-typed tensors are limited to
    TensorFlow Variables and the associated operations. Data structures such
    as queues, barriers, etc. are not supported in ImperativeMode.
  * Variables created and managed via `tf.variable_scope` and the associated
    `tf.get_variable` are not supported. (These use auxiliary data structures
    in addition to the graph, which are not aware of the imperative mode
    execution.)

  TODO(keveman): Remove the above restrictions on ImperativeMode.
  """

  def __init__(self, target, parent_graph=None):
    """Initializes an ImperativeMode.

    Args:
      target: The TensorFlow execution engine to connect to.
      parent_graph: (Optional) An ImperativeGraph.

    Raises:
      UnimplementedError: if non-None parent_graph is not an ImperativeGraph.
    """
    # Validate first, so that an unsupported `parent_graph` fails with the
    # documented error before any graph or session has been created from it.
    if (parent_graph is not None and
        not isinstance(parent_graph, imperative_graph.ImperativeGraph)):
      raise errors.UnimplementedError(None, None, 'ImperativeMode needs an '
                                      'ImperativeGraph')

    self._target = target
    self._parent_graph = parent_graph
    # Create a new graph
    self._graph = imperative_graph.ImperativeGraph(
        parent_graph=self._parent_graph)
    self._default_graph = self._graph.as_default()
    # Context manager to record variable inits
    self._record_variable_inits = self._graph.record_variable_inits()
    if self._parent_graph is not None:
      # Clone the `_parent_graph` in to the current graph. This is so that
      # operations used from the enclosing ImperativeMode context are
      # available in the current context.
      with self._graph.as_default(), self._graph.return_as_is():
        importer.import_graph_def(self._parent_graph.as_graph_def(), name='')
    self._session = session.Session(graph=self._graph, target=self._target)
    # Override the `_session`'s run, so that variable inits can be
    # called before the actual run.
    self._old_run = self._session.run
    self._session.run = self.run
    # Entered in this order by `__enter__`, exited in reverse by `__exit__`.
    self._context_managers = [
        self._session.as_default(),
        self._default_graph,
        self._record_variable_inits,
        imperative_graph.add_session_attr(ops.Tensor, self._session)]

  def run(self, *args, **kwargs):
    """Runs the variable init ops before calling the original run method.

    Args:
      *args: Positional arguments forwarded to the session's original `run`.
      **kwargs: Keyword arguments forwarded to the session's original `run`.

    Returns:
      Whatever the session's original `run` method returns.
    """
    self._graph.run_pending_inits(self._session)
    return self._old_run(*args, **kwargs)

  def __enter__(self):
    """Enters the runtime contexts of the `_context_managers`.

    Returns:
      This ImperativeMode object.
    """
    for c in self._context_managers:
      c.__enter__()
    return self

  def __exit__(self, exec_type, exec_value, exec_tb):
    """Cleans up resources, exits the runtime contexts in reverse order.

    The variable cleanup ops (if any) are run and the session is closed; the
    context managers entered by `__enter__` are then exited even when either
    of those steps raises. Exceptions are never suppressed.

    Args:
      exec_type: Type of the exception raised in the `with` body, or None.
      exec_value: The exception instance raised in the `with` body, or None.
      exec_tb: Traceback of that exception, or None.
    """
    try:
      # pylint: disable=protected-access
      cleanup_ops = self._graph._variable_cleanup_ops
      # pylint: enable=protected-access
      if cleanup_ops:
        self._session.run(cleanup_ops)
    finally:
      # A failing cleanup or close must not leave the contexts entered.
      try:
        self._session.close()
      finally:
        for c in reversed(self._context_managers):
          c.__exit__(exec_type, exec_value, exec_tb)

  def new_step(self):
    """Returns a new 'child' ImperativeMode.

    `new_step` enables running the imperative mode inside a Python loop. The
    ImperativeGraph object and the tensors created and cached during the
    execution of that graph are destroyed when the context entered with the
    object returned from this function is 'exited'. However, the operations
    in `self._graph` and any of its ancestors can be freely used as
    operands to operations in the graph contained in the object returned
    by this function.

    Returns:
      A new ImperativeMode object.
    """
    # Run this graph's pending variable inits before the child is created
    # with this graph as its parent.
    self._graph.run_pending_inits(self._session)
    return ImperativeMode(self._target, parent_graph=self._graph)
|