aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/specs/python/specs.py
blob: d5223b9b551bc5cb1e007c97820b9425ea13334b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
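"""Builder for TensorFlow models specified using specs_ops.

Spec and parameter strings are executed as Python code with the names from
`specs_ops` (respectively `params_ops`) available as globals; see `eval_spec`,
`eval_params`, `create_net_fun`, and `create_net`.
"""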
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from six import exec_
from tensorflow.contrib.specs.python import params_ops
from tensorflow.contrib.specs.python import specs_lib
from tensorflow.contrib.specs.python import specs_ops
from tensorflow.python.util import tf_inspect


def eval_params(params, environment=None):
  """Executes a parameter specification and returns the resulting bindings.

  The string is first passed to `specs_lib.check_keywords` and then executed
  with the attributes of `params_ops` as globals.

  Args:
      params: parameter assignments as a string
      environment: optional dictionary of input bindings; it is copied, not
        modified

  Returns:
      A new dictionary holding the input bindings plus every binding
      created by executing `params`

  Raises:
      Exception: other exceptions raised during execution of `params`
  """
  specs_lib.check_keywords(params)
  # Work on a copy so the caller's environment is never mutated.
  local_env = dict(environment) if environment else {}
  # NOTE(review): `params` is executed as Python code; callers should only
  # pass strings they trust.
  exec_(params, vars(params_ops), local_env)  # pylint: disable=exec-used
  return local_env


def eval_spec(spec, environment=None):
  """Executes a spec and returns the resulting bindings.

  Use this when a single spec defines several components of a larger
  network, for example `"left = Cr(64, [5,5]); right = Fc(64)"`, and you
  need each of the resulting bindings. In the common case of a single
  `net = ...` binding, `create_net` or `create_net_fun` below are more
  convenient.

  The string is first passed to `specs_lib.check_keywords` and then executed
  with the attributes of `specs_ops` as globals.

  Args:
      spec: specification as a string
      environment: optional dictionary of input bindings; it is copied, not
        modified

  Returns:
      A new dictionary holding the input bindings plus every binding
      created by the spec.

  Raises:
      Exception: other exceptions raised during execution of `spec`

  """
  specs_lib.check_keywords(spec)
  # Work on a copy so the caller's environment is never mutated.
  local_env = dict(environment) if environment else {}
  # NOTE(review): `spec` is executed as Python code; callers should only
  # pass strings they trust.
  exec_(spec, vars(specs_ops), local_env)  # pylint: disable=exec-used
  return local_env


def create_net_fun(spec, environment=None):
  """Evaluates a spec and returns the `funcall` of its `net` binding.

  Specs are written in a DSL based on function composition.  A spec
  such as `net = Cr(64, [3, 3])` binds the variable `net` to an object
  representing a single-argument function that creates a network.

  Args:
      spec: specification as a string, ending with a `net = ...` statement
      environment: a dictionary of input bindings

  Returns:
      A callable that instantiates the `net` binding.

  Raises:
      ValueError: spec failed to create a `net`
      Exception: other exceptions raised during execution of `spec`

  """
  env = eval_spec(spec, environment)
  net = env.get("net")
  if net is not None:
    return net.funcall
  # Either the spec never assigned `net`, or it assigned None to it.
  raise ValueError("spec failed to create 'net': %s" % (spec,))


def create_net(spec, inputs, environment=None):
  """Evaluates a spec and applies the resulting network function to `inputs`.

  Args:
      spec: specification as a string, ending with a `net = ...` statement
      inputs: input that `net` is applied to
      environment: a dictionary of input bindings

  Returns:
      The result of calling the function obtained from `create_net_fun`
      on `inputs`.

  Raises:
      ValueError: spec failed to create a `net`
      Exception: other exceptions raised during execution of `spec`
  """
  net_fun = create_net_fun(spec, environment)
  return net_fun(inputs)


class LocalImport(object):
  """A context manager that temporarily binds names in the caller's globals.

  On entry, the given bindings are written into the module globals of the
  code executing the `with` statement. On exit, names that already existed
  get their previous value back and names that were newly introduced are
  removed again.

  Attributes:
      frame: the frame captured on entry via `tf_inspect.currentframe()`;
        the bindings are installed in `frame.f_back.f_globals`. Only set
        while the context is active.
      names: a dictionary containing the new bindings
      old: previous values of the bindings that have been shadowed by the
        import
  """

  def __init__(self, names):
    """Create a context manager that binds the names in values.

    Args:
        names: A dictionary or module containing the bindings.
    """
    # NOTE(review): for a module, vars() also contains dunder entries such as
    # `__name__`; those are temporarily overwritten in the caller's globals
    # as well.
    if not isinstance(names, dict):
      names = vars(names)
    self.names = names

  def __enter__(self):
    self.frame = tf_inspect.currentframe()
    bindings = self.frame.f_back.f_globals
    # Track shadowed values and newly added names separately, so that a
    # global whose previous value was None is restored rather than deleted.
    self.old = {k: bindings[k] for k in self.names if k in bindings}
    self._added = [k for k in self.names if k not in bindings]
    bindings.update(self.names)

  def __exit__(self, some_type, value, traceback):
    del some_type, value, traceback
    bindings = self.frame.f_back.f_globals
    bindings.update(self.old)
    for k in self._added:
      # pop() tolerates the `with` body having already removed the name.
      bindings.pop(k, None)
    # Drop the frame reference so the frame is not kept alive.
    del self.frame


# Ready-made context manager: `with ops: ...` temporarily binds the names from
# `specs_ops` in the globals of the calling module.
ops = LocalImport(specs_ops)