1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
|
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Builder for TensorFlow models specified using specs_ops.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import inspect
from six import exec_
from tensorflow.contrib.specs.python import params_ops
from tensorflow.contrib.specs.python import specs_lib
from tensorflow.contrib.specs.python import specs_ops
def eval_params(params, environment=None):
    """Executes a parameter specification and returns the resulting bindings.

    The names defined in `params_ops` are available to `params` as globals;
    the caller-supplied `environment` is copied into a fresh dictionary that
    serves as the local namespace, so the dictionary passed in is not
    modified.

    NOTE: `params` is run as Python source via `exec_` (after
    `specs_lib.check_keywords` has been applied to it), so it should only
    come from a trusted source.

    Args:
      params: parameter assignments as a string
      environment: optional dictionary of input bindings

    Returns:
      A new dictionary holding the entries of `environment` plus the
      bindings created by executing `params`.

    Raises:
      Exception: any exception raised while checking or executing `params`
    """
    specs_lib.check_keywords(params)
    # Copy so that the caller's dictionary is never mutated by the exec below.
    namespace = dict(environment or {})
    exec_(params, vars(params_ops), namespace)  # pylint: disable=exec-used
    return namespace
def eval_spec(spec, environment=None):
    """Executes a spec and returns the resulting bindings.

    Use this when a single spec defines several components of a larger
    network and you need more than one of the resulting names, for example:
    "left = Cr(64, [5,5]); right = Fc(64)". For the common case of a single
    `net` binding, `create_net` or `create_net_fun` below are usually more
    convenient.

    The names defined in `specs_ops` are available to `spec` as globals; the
    caller-supplied `environment` is copied into a fresh dictionary that
    serves as the local namespace, so the dictionary passed in is not
    modified.

    NOTE: `spec` is run as Python source via `exec_` (after
    `specs_lib.check_keywords` has been applied to it), so it should only
    come from a trusted source.

    Args:
      spec: specification as a string
      environment: optional dictionary of input bindings

    Returns:
      A new dictionary holding the entries of `environment` plus the
      bindings created by executing `spec`.

    Raises:
      Exception: any exception raised while checking or executing `spec`
    """
    specs_lib.check_keywords(spec)
    # Copy so that the caller's dictionary is never mutated by the exec below.
    namespace = dict(environment or {})
    exec_(spec, vars(specs_ops), namespace)  # pylint: disable=exec-used
    return namespace
def create_net_fun(spec, environment=None):
    """Executes a spec and returns a callable built from its `net` binding.

    Specs are written in a DSL based on function composition. A spec such as
    `net = Cr(64, [3, 3])` binds the variable `net` to an object that
    represents a single-argument function capable of creating a network.
    This function executes the spec and hands back that object's `funcall`
    attribute.

    Args:
      spec: specification as a string, ending with a `net = ...` statement
      environment: optional dictionary of input bindings

    Returns:
      The `funcall` attribute of the object bound to `net`.

    Raises:
      ValueError: the spec did not bind `net` (or bound it to None)
      Exception: any exception raised while executing `spec`
    """
    env = eval_spec(spec, environment)
    net_obj = env.get("net")
    if net_obj is not None:
        return net_obj.funcall
    raise ValueError("spec failed to create 'net': %s" % (spec,))
def create_net(spec, inputs, environment=None):
    """Evaluates a spec and creates a network instance given the inputs.

    This is a convenience wrapper: it obtains the network-building callable
    from `create_net_fun` and immediately applies it to `inputs`.

    Args:
      spec: specification as a string, ending with a `net = ...` statement
      inputs: input that `net` is applied to
      environment: optional dictionary of input bindings

    Returns:
      The result of applying the `net` binding to `inputs` (not the
      callable itself; use `create_net_fun` to obtain the callable).

    Raises:
      ValueError: spec failed to create a `net`
      Exception: other exceptions raised during execution of `spec`
    """
    net_fun = create_net_fun(spec, environment)
    return net_fun(inputs)
class LocalImport(object):
    """A context manager that temporarily injects names into the caller's globals.

    On entry, every key of `names` is written into the global namespace of
    the code that executes the `with` statement. On exit the previous state
    is restored: globals that were shadowed get their old value back
    (including a value of None), and names that did not exist before are
    removed again.

    Attributes:
      frame: the frame of the `__enter__` call; its `f_back` is the frame in
        which the context manager was invoked. Only set while the context
        is active.
      names: a dictionary containing the new bindings
      old: the pre-existing global bindings that have been shadowed by the
        import; names that were not bound before are absent
    """

    def __init__(self, names):
        """Create a context manager that binds the names in values.

        Args:
          names: A dictionary or module containing the bindings. Anything
            that is not a dict is converted with `vars()`.
        """
        if not isinstance(names, dict):
            names = vars(names)
        self.names = names

    def __enter__(self):
        """Installs `names` into the globals of the calling frame."""
        self.frame = inspect.currentframe()
        bindings = self.frame.f_back.f_globals
        # Record only names that really exist, so that a global bound to None
        # can be told apart from a name that was not defined at all.
        self.old = {k: bindings[k] for k in self.names if k in bindings}
        bindings.update(self.names)

    def __exit__(self, some_type, value, traceback):
        """Restores the caller's globals; never suppresses exceptions."""
        del some_type, value, traceback
        bindings = self.frame.f_back.f_globals
        for k in self.names:
            if k in self.old:
                bindings[k] = self.old[k]
            else:
                # pop() instead of del: the body of the `with` block may have
                # removed the temporary name already.
                bindings.pop(k, None)
        # Drop the frame reference so the frame is not kept alive after exit.
        del self.frame
# Module-level context manager: `with ops:` temporarily binds the names
# defined in `specs_ops` as globals of the calling code.
ops = LocalImport(specs_ops)
|