From 4fe49431eb892184c03e4af57d9a0b36b1af6989 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 6 Oct 2016 13:34:23 -0800 Subject: Adding specs to the open source version of TensorFlow. Change: 135406004 --- tensorflow/contrib/specs/BUILD | 60 +++++ tensorflow/contrib/specs/README.md | 263 +++++++++++++++++++ tensorflow/contrib/specs/python/__init__.py | 0 tensorflow/contrib/specs/python/params_ops.py | 87 +++++++ tensorflow/contrib/specs/python/specs.py | 157 +++++++++++ tensorflow/contrib/specs/python/specs_lib.py | 289 +++++++++++++++++++++ tensorflow/contrib/specs/python/specs_ops.py | 245 ++++++++++++++++++ tensorflow/contrib/specs/python/specs_test.py | 231 +++++++++++++++++ tensorflow/contrib/specs/python/summaries.py | 301 ++++++++++++++++++++++ tensorflow/contrib/specs/python/summaries_test.py | 80 ++++++ 10 files changed, 1713 insertions(+) create mode 100644 tensorflow/contrib/specs/BUILD create mode 100644 tensorflow/contrib/specs/README.md create mode 100644 tensorflow/contrib/specs/python/__init__.py create mode 100644 tensorflow/contrib/specs/python/params_ops.py create mode 100644 tensorflow/contrib/specs/python/specs.py create mode 100644 tensorflow/contrib/specs/python/specs_lib.py create mode 100644 tensorflow/contrib/specs/python/specs_ops.py create mode 100644 tensorflow/contrib/specs/python/specs_test.py create mode 100644 tensorflow/contrib/specs/python/summaries.py create mode 100644 tensorflow/contrib/specs/python/summaries_test.py (limited to 'tensorflow/contrib') diff --git a/tensorflow/contrib/specs/BUILD b/tensorflow/contrib/specs/BUILD new file mode 100644 index 0000000000..517fe4784c --- /dev/null +++ b/tensorflow/contrib/specs/BUILD @@ -0,0 +1,60 @@ +# Description: +# A small domain-specific language (DSL) for defining deep learning networks. 
+ +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +package(default_visibility = ["//tensorflow:__subpackages__"]) + +load("//tensorflow:tensorflow.bzl", "tf_py_test") + +py_library( + name = "specs", + srcs = [ + "python/__init__.py", + "python/params_ops.py", + "python/specs.py", + "python/specs_lib.py", + "python/specs_ops.py", + "python/summaries.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/ndlstm", + "//tensorflow/python:framework", + "//tensorflow/python:ops", + "//tensorflow/python:platform", + "//tensorflow/python:training", + ], +) + +tf_py_test( + name = "specs_test", + srcs = ["python/specs_test.py"], + additional_deps = [ + ":specs", + "//tensorflow:tensorflow_py", + ], +) + +tf_py_test( + name = "summaries_test", + srcs = ["python/summaries_test.py"], + additional_deps = [ + ":specs", + "//tensorflow:tensorflow_py", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/specs/README.md b/tensorflow/contrib/specs/README.md new file mode 100644 index 0000000000..fcd008e81b --- /dev/null +++ b/tensorflow/contrib/specs/README.md @@ -0,0 +1,263 @@ +# specs -- simple specifications for TensorFlow networks + +This library implements a simple domain-specific language for specifying +deep neural networks in TensorFlow. + +From a high level, there are a set of standard operators and ways of +combining them: + + - operator `|` takes the output from one layer and "pipes" it into the next + - operator `**` repeats a layer multiple times + +Naming conventions: + + - single character names are reserved to users + - built-in layers are capitalized, not CamelCase (Relu, Fs, etc.) + - built-in layers that are common are usually two letters (Cr, Fs, etc.) + - less common operations are longer (Relu, Conc, etc.) 
+ - temporary names should end in _ + +Common layers: + +Common layers are defined by short, capitalized abbreviations. For layers +that take an activation function (fully_connected, conv2d), the acronym +is a conjunction of a base layer and the activation. For example, `Fs` +represents a fully connected layer followed by a sigmoid, whereas `Ft` +represents a fully connected layer followed by a Tanh. + + - `Fx` = slim.fully_connected; x = activation function, one of s/t/r/l/m + - `Cx` = slim.conv2d; x = activation function, one of s/t/r/l/m + - `Mp` = slim.max_pool2d + - `Ap` = slim.avg_pool2d + - `Bn` = slim.batch_norm + +Nonlinearities (suffixes for C/F, so Cs = convolutional layer + sigmoid): + + - `s` = sigmoid + - `t` = tanh + - `r` = relu + - `l` = linear (i.e., None) + - `m` = softmax + +Positional and keyword arguments are the same as for the underlying +slim and TensorFlow functions. Therefore, common usage patterns are: + + Cr(64, [5, 5]) # conv2d with a 5x5 footprint and 64 outputs + Mp([2, 2]) # max pooling using [2, 2] steps + +Explicit nonlinearities: + + - `Relu` = tf.nn.relu + - `Sig` = tf.nn.sigmoid + - `Tanh` = tf.nn.tanh + - `Smax` = tf.nn.softmax + +Reshaping: + + - `Flat` = slim.flatten + - `Reshape` = tf.reshape + - `Squeeze` = tf.squeeze + - `Expand` = tf.expand_dims + +Multidimensional LSTM: + +These are intended as alternatives to 2D convolutions. For sequence models, +there will be other modeling primitives. 
+ + - `Lstm2` = Fun(lstm2d.separable_lstm) # 2D-to-2D + - `Lstm2to1` = Fun(lstm2d.reduce_to_sequence) # 2D-to-1D + - `Lstm2to0` = Fun(lstm2d.reduce_to_final) # 2D-to-vector + - `Clstm2(n, m)` is a `Cl(n, [3,3])` followed by `Lstm2(m)` + - `Dws(n)` is a depthwise convolution `Cs(n, [1, 1])` + +Other: + + - `Id` = identity + - `Do` = slim.dropout + - `Lrn` = tf.nn.local_response_normalization + - `Unit` = slim.unit_norm + - `Conc` is roughly tf.nn.concat + +Binding external functions: + + - `External` - import an external function using module path + - `Import` - import an external function using statements + +A network specification is a sequence of `name = expression` Python statements, +with the `net` variable holding the network that is being defined. That is, +your specification must have a statement of the form `net = ...` as its +last statement. + +So, a simple MNIST network might look like: + + net = Cr(64, [5, 5]) | Fs(10) + +More complicated: + + net = (Cr(64, [5, 5]) | Mp([2, 2])) ** 3 | Fs(10) + +With temporary names: + + cmp_ = Cr(64, [5, 5]) | Mp([2, 2]) + net = cmp_ ** 3 | Fs(10) + +(You can also separate statements with `;` instead of `\n`) + +General model structure: + + - Models are sequences of `name = expression` statements + in Python syntax. + - Other kinds of statements are not allowed (with a few + exceptions, like calling `debug()`) + - Names should be assigned only once. + +These constraints are only partially enforced by the library right +now, but may be strictly enforced in the future. + +# More Details + +The spec language is intended for rapid experimentation with common +layer types; it's not a replacement for the standard TensorFlow or +slim APIs. If you have some complex layer type or construct that's +difficult to represent in `spec`, you can implement it directly in +Python and then easily make it available as a `spec` operator. 
+ +Since partial application with positional arguments can be a little +confusing, you can also specify positional arguments with keywords like +`_1`: + + cr5_ = Cr(_1=[5, 5]); net = cr5_(64) ** 3 | Fs(10) + +You can enable debugging by putting `debug()` at the beginning of your network +definition: + + debug(); net = Cr(64, [5, 5]) | Fs(10) + +The module is a "combinator library". To make the syntax work nicely +with Python, the `__call__` operator is overloaded to perform partial +application. + +To create a network from Python, you just call the following: + + inputs = tf.placeholder(...) + spec = "net = (Cr(64, [5, 5]) | Mp([2, 2])) ** 3 | Fs(10)" + outputs = specs.create_net(spec, inputs) + +You can pass variable bindings into `create_net`: + + inputs = tf.placeholder(...) + spec = "net = (Cr(64, [5, 5]) | Mp([2, 2])) ** depth | Fs(10)" + outputs = specs.create_net(spec, inputs, dict(depth=3)) + +# Using `specs` in Code + +The specs operators are defined in the module `specs_ops`. To facilitate +using the `specs` DSL in your code without namespace pollution, you can +use the `specs.ops` context manager, which will temporarily make the +`specs` operators available in your code: + + import tensorflow as tf + import numpy.random as npr + specs = tf.contrib.specs.python + + with specs.ops: + net = (Cr(64, [2, 2]) | Mp([2, 2])) ** 3 | Flat | Fs(10) + inputs = tf.placeholder(tf.float32, [None, 28, 28, 1]) + outputs = net.funcall(inputs) + + sess = tf.InteractiveSession() + tf.initialize_all_variables().run() + sess.run([outputs], feed_dict={inputs: npr.uniform(size=(17, 28, 28, 1))}) + +# Sharing and Variables + +You can share variables among subnets by wrapping them with `Shared`: + + f = Shared(Fr(100)) + g = f | f | f | f + +This will stack four fully connected ReLU layers, sharing the same +weights and biases. 
+ +You can also create variables explicitly: + + v = Var("v") + +You can use this to write expressions like this: + + net = Cl(100, 3) + Var("b", shape=[128, 128, 100]) + +Note that, under the covers, both the `Cl` operator and the `Var` operator +generate functions that are eventually applied via `funcall` to an input +tensor; the function generated by the `Var` operator ignores its argument +and calls `tf.get_variable` with the supplied arguments. + +# Pulling in New Primitives + +If you need some special function in your spec language, you can make +it available using `External` or `Import`. The following two statements +are equivalent: + + Sig = External("some_module", "some_op") + Sig = Import("import tensorflow as tf; f = tf.nn.sigmoid") + +You probably will want to use `Import` because TensorFlow contains a +number of imports that look like they are in modules, but they are +actually just values placed in the namespace somehow. The `Import` +function takes an arbitrary Python statement that eventually needs to +assign a value to the variable `f` that is then wrapped up as a function. + +# AutoFunction + +Not all functions available in TensorFlow have corresponding short names +in specs; in order to access other functions conveniently, you can refer +to any function in a number of existing modules using the full function +name in that module. Module shortcuts are defined for: + + - `TF` = `tf` + - `NN` = `tf.nn` + - `SL` = `slim` + +You can, of course, introduce more abbreviations in your own code. + +These are defined as: + + TF = specs_lib.AutoFunction(tf) + +Using these definitions, `SL.conv2d(64, 5)` is equivalent to `Cr(64, 5)`. + +# Summaries + +There are a number of functions that give you information about the structure +of specs (and other, similarly structured, TensorFlow graphs); the first +number is the number of parameters, followed by the op, and the shape. 
+ + >>> summaries.tf_spec_summary("net = Cr(100, [3,3]) | Flat | Fs(10)", + input_shape=(17, 28, 28, 1)) + 0 Placeholder [17, 28, 28, 1] + 1000 Conv [17, 28, 28, 100] + 0 Flatten [17, 78400] + 784010 fully_connected [17, 10] + >>> + +# ToDo + +More documentation, comments. + +The following features are intended to be added soon (names subject to change): + + - add sequence processing layers + - add named point cuts + - Seq(a, b, c).add(name=layer).add(name=layer) for explicit seq. structures + - S2d, D2s (space/depth operators) + - `Shared(...)` -- variable sharing + - `Mix(...)` -- weighted convex combination of layer outputs + - `Lincom(...)` -- weighted linear combination of layer outputs + - `SameDepth(A)` -- makes output depth same as input + - summary ops + - slim's `arg_scope` + - automatic wrapping of long-name slim layers + - depth-to-space, etc. + +Eventually, there may be a similar spec language for +input layers and pipelines. diff --git a/tensorflow/contrib/specs/python/__init__.py b/tensorflow/contrib/specs/python/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tensorflow/contrib/specs/python/params_ops.py b/tensorflow/contrib/specs/python/params_ops.py new file mode 100644 index 0000000000..7cabdfd9b8 --- /dev/null +++ b/tensorflow/contrib/specs/python/params_ops.py @@ -0,0 +1,87 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Operators for concise TensorFlow parameter specifications. + +This module is used as an environment for evaluating expressions +in the "params" DSL. + +Specifications are intended to assign simple numerical +values. Examples: + + --params "n=64; d=5" --spec "(Cr(n) | Mp([2, 2])) ** d | Fm" + +The random parameter primitives are useful for running large numbers +of experiments with randomly distributed parameters: + + --params "n=Li(5,500); d=Ui(1,5)" --spec "(Cr(n) | Mp([2, 2])) ** d | Fm" + +Internally, this might be implemented as follows: + + params = specs.create_params(FLAGS.params, {}) + logging.info(repr(params)) + net = specs.create_net(FLAGS.spec, inputs, params) + +Note that separating the specifications into parameters and network +creation allows us to log the random parameter values easily. + +The implementation of this will change soon in order to support +hyperparameter tuning with steering. Instead of returning a number, +the primitives below will return a class instance that is then +used to generate a random number by the framework. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# Lint disabled because these are operators in the DSL, not regular +# Python functions. 
+# pylint: disable=invalid-name +# pylint: disable=wildcard-import,unused-wildcard-import,redefining-builtin +# pylint: disable=redefined-builtin,g-importing-member,no-member +# make available all math expressions +import math +from math import * +import random +# pylint: enable=wildcard-import,unused-wildcard-import,redefining-builtin +# pylint: enable=redefined-builtin,g-importing-member,no-member + + +def Uf(lo=0.0, hi=1.0): + """Uniformly distributed floating point number.""" + return random.uniform(lo, hi) + + +def Ui(lo, hi): + """Uniformly distributed integer, inclusive limits.""" + return random.randint(lo, hi) + + +def Lf(lo, hi): + """Log-uniform distributed floating point number.""" + return math.exp(random.uniform(math.log(lo), math.log(hi))) + + +def Li(lo, hi): + """Log-uniform distributed integer, inclusive limits.""" + return int(math.floor(math.exp(random.uniform(math.log(lo), + math.log(hi+1-1e-5))))) + + +def Nt(mu, sigma, limit=3.0): + """Normally distributed float, clipped to mu +/- limit*sigma.""" + return min(max(random.gauss(mu, sigma), mu-limit*sigma), mu+limit*sigma) + + +# pylint: enable=invalid-name diff --git a/tensorflow/contrib/specs/python/specs.py b/tensorflow/contrib/specs/python/specs.py new file mode 100644 index 0000000000..a9fba442db --- /dev/null +++ b/tensorflow/contrib/specs/python/specs.py @@ -0,0 +1,157 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Builder for TensorFlow models specified using specs_ops. + +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +import inspect + +from six import exec_ +from tensorflow.contrib.specs.python import params_ops +from tensorflow.contrib.specs.python import specs_lib +from tensorflow.contrib.specs.python import specs_ops + + +def eval_params(params, environment=None): + """Evaluates a parameter specification and returns the environment. + + Args: + params: parameter assignments as a string + environment: a dictionary of input bindings + + Returns: + Environment with additional bindings created by + executing `params` + + Raises: + Exception: other exceptions raised during execution of `params` + """ + specs_lib.check_keywords(params) + bindings = {} + if environment: bindings.update(environment) + exec_(params, vars(params_ops), bindings) # pylint: disable=exec-used + return bindings + + +def eval_spec(spec, environment=None): + """Evaluates a spec and returns the environment. + + This function allows you to use a spec to obtain multiple bindings + in an environment. That is useful if you use the spec language to + specify multiple components of a larger network, for example: "left + = Cr(64, [5,5]); right = Fc(64)" Usually, you will want to use + `create_net` or `create_net_fun` below. + + Args: + spec: specification as a string + environment: a dictionary of input bindings + + Returns: + Environment with additional bindings created by spec. + + Raises: + Exception: other exceptions raised during execution of `spec` + + """ + specs_lib.check_keywords(spec) + bindings = {} + if environment: bindings.update(environment) + exec_(spec, vars(specs_ops), bindings) # pylint: disable=exec-used + return bindings + + +def create_net_fun(spec, environment=None): + """Evaluates a spec and returns the binding of `net`. 
+ + Specs are written in a DSL based on function composition. A spec + like `net = Cr(64, [3, 3])` assigns an object that represents a + single argument function capable of creating a network to + the variable `net`. + + Args: + spec: specification as a string, ending with a `net = ...` statement + environment: a dictionary of input bindings + + Returns: + A callable that instantiates the `net` binding. + + Raises: + ValueError: spec failed to create a `net` + Exception: other exceptions raised during execution of `spec` + + """ + bindings = eval_spec(spec, environment) + net = bindings.get("net", None) + if net is None: + raise ValueError("spec failed to create 'net': %s" % (spec,)) + return net.funcall + + +def create_net(spec, inputs, environment=None): + """Evaluates a spec and creates a network instance given the inputs. + + Args: + spec: specification as a string, ending with a `net = ...` statement + inputs: input that `net` is applied to + environment: a dictionary of input bindings + + Returns: + The result of applying the `net` binding to `inputs`. + + Raises: + ValueError: spec failed to create a `net` + Exception: other exceptions raised during execution of `spec` + """ + return create_net_fun(spec, environment)(inputs) + + +class LocalImport(object): + """A class that allows us to temporarily import something. + + Attributes: + frame: the frame in which the context manager was invoked + names: a dictionary containing the new bindings + old: variable bindings that have been shadowed by the import + """ + + def __init__(self, names): + """Create a context manager that binds the names in `names`. + + Args: + names: A dictionary or module containing the bindings. 
+ """ + if not isinstance(names, dict): + names = vars(names) + self.names = names + + def __enter__(self): + self.frame = inspect.currentframe() + bindings = self.frame.f_back.f_globals + self.old = {k: bindings.get(k, None) for k in self.names.keys()} + bindings.update(self.names) + + def __exit__(self, some_type, value, traceback): + del some_type, value, traceback + bindings = self.frame.f_back.f_globals + bindings.update(self.old) + for k, v in self.old.items(): + if v is None: del bindings[k] + del self.frame + +ops = LocalImport(specs_ops) diff --git a/tensorflow/contrib/specs/python/specs_lib.py b/tensorflow/contrib/specs/python/specs_lib.py new file mode 100644 index 0000000000..e2ddd7567e --- /dev/null +++ b/tensorflow/contrib/specs/python/specs_lib.py @@ -0,0 +1,289 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Implement the "specs" DSL for describing deep networks.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +import importlib +import operator +import re + +from six import exec_ + +QUOTED = re.compile(r""" +"([^"\\]|\\.)*" | +'([^'\\]|\\.)*' +""", re.VERBOSE) +KEYWORDS = re.compile(r"""\b(import|while|def|exec)\b""") + + +debug_ = False + + +def check_keywords(spec): + """Check for common Python keywords in spec. 
+ + This function discourages the use of complex constructs + in TensorFlow specs; it doesn't completely prohibit them + (if necessary, we could check the AST). + + Args: + spec: spec string + + Raises: + ValueError: raised if spec contains a prohibited keyword. + """ + spec = re.sub(QUOTED, "", spec) + match = re.search(KEYWORDS, spec) + if match: + raise ValueError("keyword '%s' found in spec" % match.group(1)) + + +def get_positional(args, kw, kw_overrides=False): + """Interpolates keyword arguments into argument lists. + + If `kw` contains keywords of the form "_0", "_1", etc., these + are positionally interpolated into the argument list. + + Args: + args: argument list + kw: keyword dictionary + kw_overrides: if true, "_N" values replace non-None positional values + + Returns: + (new_args, new_kw), new argument lists and keyword dictionaries + with values interpolated. + """ + new_kw = {k: v for k, v in kw.items() if k[0] != "_"} + if len(new_kw) == len(kw): + return args, kw + new_args = list(args) + for key, value in kw.items(): + if key[0] != "_": continue + index = int(key[1:]) + while len(new_args) <= index: + new_args += [None] + if kw_overrides or new_args[index] is None: + new_args[index] = value + return new_args, new_kw + + +class Composable(object): + """A composable function. + + This defines the operators common to all composable objects. + Currently defines composition (via "|") and repeated application + (via "**"), and maps addition ("+") and multiplication ("*") + as "(f + g)(x) = f(x) + g(x)". 
+ """ + + def __or__(self, f): + return Composition(self, f) + + def __add__(self, g): + return Operator(operator.add, self, g) + + def __mul__(self, g): + return Operator(operator.mul, self, g) + + def __pow__(self, n): + assert n >= 0 + if n == 0: + return Function(lambda x, *args, **kw: x) + result = self + for _ in range(n-1): + result = Composition(result, self) + return result + + +class Callable(Composable): + """A composable function that simply defers to a callable function. + """ + + def __init__(self, f): + self.f = f + + def funcall(self, x): + return self.f(x) + + +class Operator(Composable): + """A wrapper for an operator. + + This takes an operator and an argument list and returns + the result of applying the operator to the results of applying + the functions in the argument list. + """ + + def __init__(self, op, *args): + self.op = op + self.funs = args + + def funcall(self, x): + outputs = [f.funcall(x) for f in self.funs] + return self.op(*outputs) + + +class Function(Composable): + """A composable function wrapper for a regular Python function. + + This overloads the regular __call__ operator for currying, i.e., + arguments passed to __call__ are remembered for the eventual + function application. + + The final function application happens via the `funcall` method. + """ + + def __init__(self, f, *args, **kw): + if not callable(f): + raise ValueError("%s: is not callable" % (f,)) + self.f = f + self.args = list(args) + self.kw = kw + + def __call__(self, *args, **kw): + new_args = list(args) + self.args + new_kw = self.kw.copy() + new_kw.update(kw) + return Function(self.f, *new_args, **new_kw) + + # TODO(tmb) The `funcall` method (formerly `of`) may be renamed to `function`. + def funcall(self, x): + args, kw = get_positional(self.args, self.kw) + if debug_: + print("DEBUG:", self.f, x, args, kw) + return self.f(x, *args, **kw) + + +class AutoFunction(object): + """Automatically curry functions when accessed as attributes. 
+ + This class wraps a dictionary mapping keys to values. When an attribute + is accessed, the class looks up the attribute in the dictionary and + wraps it (curries it) using Function(...). When wrapped around + existing modules implementing TensorFlow functions or layers, this + turns those functions or layers automatically into specs-compatible + layers. + + For example, `net` and `net2` are equivalent: + TF = AutoFunction(tf) + with specs.ops: + net = TF.conv2d(64, 5) ** 3 | Flat + net2 = Cr(64, 5) ** 3 | Flat + + Attributes: + source: A dictionary holding the underlying key-value mappings. + """ + + def __init__(self, source): + """Creates an AutoFunction wrapper for a module. + + Args: + source: A dictionary or a module. + """ + if not isinstance(source, dict): + source = vars(source) + self.source = source + + def __getattr__(self, key): + """Looks up the key in the source dictionary and curries the result. + + Args: + key: The symbol name to look up. + + Returns: + The curried argument. + + Raises: + ValueError: The key does not exist, or it doesn't refer to a callable. + """ + result = self.source.get(key, None) + if result is None: + raise ValueError("%s: no such symbol" % (key,)) + if not callable(result): + raise ValueError("value of %s is not callable (type is %s)" % + (key, type(result))) + return Function(result) + + +class Composition(Composable): + """A function composition. + + This simply composes its two argument functions when + applied to a final argument via `funcall`. + """ + + def __init__(self, f, g): + self.f = f + self.g = g + + def funcall(self, x): + return self.g.funcall(self.f.funcall(x)) + + +# These are DSL names, not Python names +# pylint: disable=invalid-name, exec-used +def External(module_name, function_name): + """Import a function from an external module. + + Note that the `module_name` must be a module name + that works with the usual import mechanisms. Shorthands + like "tf.nn" will not work. 
+ + Args: + module_name: name of the module + function_name: name of the function within the module + + Returns: + Function-wrapped value of symbol. + """ + module = importlib.import_module(module_name) + return Function(vars(module)[function_name]) + + +def Import(statements): + """Import a function by exec. + + Args: + statements: Python statements + + Returns: + Function-wrapped value of `f`. + + Raises: + ValueError: the statements didn't define a value for "f" + """ + environ = {} + exec_(statements, environ) + if "f" not in environ: + raise ValueError("failed to define \"f\": %s" % (statements,)) + f = environ["f"] + return Function(f) + + +# pylint: enable=invalid-name, exec-used +def debug(mode=True): + """Turn on/off debugging mode. + + Debugging mode prints more information about the construction + of a network. + + Args: + mode: True if turned on, False otherwise + """ + global debug_ + debug_ = mode diff --git a/tensorflow/contrib/specs/python/specs_ops.py b/tensorflow/contrib/specs/python/specs_ops.py new file mode 100644 index 0000000000..4811941fd3 --- /dev/null +++ b/tensorflow/contrib/specs/python/specs_ops.py @@ -0,0 +1,245 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Operators for concise TensorFlow network models. + +This module is used as an environment for evaluating expressions +in the "specs" DSL. 
+""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +import tensorflow as tf +from tensorflow.contrib.ndlstm.python import lstm1d +from tensorflow.contrib.ndlstm.python import lstm2d +from tensorflow.contrib.specs.python import specs_lib + + +slim = tf.contrib.slim + + +# The following assignments don't appear to follow Google naming +# conventions, but that's because these are functions defined by +# higher-order function application, not "constants" and because they +# are the commands of the DSL. +# pylint: disable=invalid-name + + +class Idx(specs_lib.Composable): + """Implements the identity function in network specifications.""" + + def funcall(self, x): + return x + + +class Conc(specs_lib.Composable): + """Implements tensor concatenation in network specifications.""" + + def __init__(self, dim, *args): + """Concatenates tensors along the given dimension. + + Args: + dim: dimension along which concatenation takes place + *args: argument tensor functions to be concatenated + """ + self.dim = dim + self.funs = args + + def funcall(self, x): + outputs = [f.funcall(x) for f in self.funs] + return tf.concat(self.dim, outputs) + + +External = specs_lib.External +Import = specs_lib.Import +Fun = specs_lib.Function +debug = specs_lib.debug +Print = Fun(tf.Print) +Id = Fun(tf.identity) + +# TODO(tmb) add Assert + +# Two letter names for the most common layers. + +# 2D Convolutional layers with nonlinearities (s/t/r/m/l) +# TODO(tmb) add Cbs, Fbs etc. 
# for batch norms

Cx = Fun(slim.conv2d)
Cs = Fun(slim.conv2d, activation_fn=tf.nn.sigmoid)
Ct = Fun(slim.conv2d, activation_fn=tf.nn.tanh)
Cr = Fun(slim.conv2d, activation_fn=tf.nn.relu)
Cm = Fun(slim.conv2d, activation_fn=tf.nn.softmax)
Cl = Fun(slim.conv2d, activation_fn=None)

# Fully connected slim with nonlinearities (s/t/r/m/l)

Fx = Fun(slim.fully_connected)
Fs = Fun(slim.fully_connected, activation_fn=tf.nn.sigmoid)
Ft = Fun(slim.fully_connected, activation_fn=tf.nn.tanh)
Fr = Fun(slim.fully_connected, activation_fn=tf.nn.relu)
Fm = Fun(slim.fully_connected, activation_fn=tf.nn.softmax)
Fl = Fun(slim.fully_connected, activation_fn=None)

# Pooling

Mp = Fun(slim.max_pool2d)
Ap = Fun(slim.avg_pool2d)

# Batch manipulations

Do = Fun(slim.dropout)
Bn = Fun(slim.batch_norm)
Lrn = Fun(tf.nn.local_response_normalization)
Unit = Fun(slim.unit_norm)

# Shape changes

Flat = Fun(slim.flatten)
Reshape = Fun(tf.reshape)
Transpose = Fun(tf.transpose)
Squeeze = Fun(tf.squeeze)
Expand = Fun(tf.expand_dims)

# Nonlinearities (rarely needed on their own)

Relu = Fun(tf.nn.relu)
Sig = Fun(tf.nn.sigmoid)
Tanh = Fun(tf.nn.tanh)
Smax = Fun(tf.nn.softmax)

# 2D LSTM

Lstm2 = Fun(lstm2d.separable_lstm)
Lstm2to1 = Fun(lstm2d.reduce_to_sequence)  # 2D to 1D
Lstm2to0 = Fun(lstm2d.reduce_to_final)  # 2D to depth-only


def Clstm2(n, *args, **kw):
  """2D LSTM with 3x3 pre-convolution.

  Args:
    n: number of outputs of the linear 3x3 convolution.
    *args: positional arguments forwarded to Lstm2.
    **kw: keyword arguments forwarded to Lstm2.

  Returns:
    A spec that pipes `Cl(n, [3, 3])` into `Lstm2(*args, **kw)`.
  """
  return Cl(n, [3, 3]) | Lstm2(*args, **kw)


def Dws(n):
  """1x1 convolution to `n` outputs + sigmoid (used after LSTM)."""
  return Cs(n, [1, 1])


def Dwm(n):
  """1x1 convolution to `n` outputs + softmax (used after LSTM)."""
  return Cm(n, [1, 1])

# 1D LSTM

Lstm1 = Fun(lstm1d.ndlstm_base)
Lstm1to0 = Fun(lstm1d.sequence_to_final)  # 1D to depth-only
Ssm = Fun(lstm1d.sequence_softmax)

# Sharing of Variables


def Var(name, *args, **kw):
  """Implements an operator that generates a variable.

  This function is still experimental. Use it only
  for generating a single variable instance for
  each name.

  Args:
    name: Name of the variable.
    *args: Other arguments to get_variable.
    **kw: Other keywords for get_variable.

  Returns:
    A specs object for generating a variable.
  """
  def var(_):
    # The spec's input is ignored; only the variable itself is produced.
    return tf.get_variable(name, *args, **kw)
  return specs_lib.Callable(var)


class Shared(specs_lib.Composable):
  """Wraps a scope with variable reuse around the subnetwork.

  This function is still experimental.

  Attributes:
    subnet: The shared subnetwork.
    name: Name of the variable scope opened on first use.
    scope: The VariableScope being shared; None until one is passed in
      or until the first funcall creates it.
    shared_number: Class-wide counter used to generate default names.
  """

  shared_number = 1

  def __init__(self, subnet, name=None, scope=None):
    """Create the Shared operator.

    Use this as:

        f = Shared(Cr(100, 3))
        g = f | f | f

    Ordinarily, you do not need to provide either a name or a scope.
    Providing a name is useful if you want a well-defined namespace
    for the variables (e.g., for saving a subnet).

    Args:
      subnet: Definition of the shared network.
      name: Optional name for the shared context.
      scope: Optional shared scope (must be a Scope, not a string).

    Raises:
      ValueError: Scope is not of type tf.VariableScope, name is not
        of type string, or both scope and name are given together.
    """
    if scope is not None and not isinstance(scope, tf.VariableScope):
      raise ValueError("scope must be None or a VariableScope")
    # Validate `name` itself. Checking `scope` here (as the code used to)
    # rejected every explicitly named Shared, because `scope` is either
    # None or a VariableScope at this point and never a string.
    if name is not None and not isinstance(name, str):
      raise ValueError("name must be None or a string")
    if scope is not None and name is not None:
      raise ValueError("cannot provide both a name and a scope")
    if name is None:
      name = "Shared_%d" % Shared.shared_number
      Shared.shared_number += 1
    self.subnet = subnet
    self.name = name
    self.scope = scope

  def funcall(self, x):
    """Apply the shared operator to an input.

    This wraps a variable scope around the creation of the subnet.
    The first call opens a new scope named `self.name` and remembers
    it; later calls re-enter the remembered scope with reuse=True.

    Args:
      x: The input argument on which the subnet is invoked.

    Returns:
      The output tensor from invoking the subnet constructor.
    """
    if self.scope is None:
      with tf.variable_scope(self.name, values=[x]) as scope:
        self.scope = scope
        return self.subnet.funcall(x)
    else:
      with tf.variable_scope(self.scope, values=[x], reuse=True):
        return self.subnet.funcall(x)

# AutoFunction bindings of some existing modules

TF = specs_lib.AutoFunction(tf)
NN = specs_lib.AutoFunction(tf.nn)
SL = specs_lib.AutoFunction(slim)

# pylint: enable=invalid-name

# diff --git a/tensorflow/contrib/specs/python/specs_test.py b/tensorflow/contrib/specs/python/specs_test.py
# new file mode 100644
# index 0000000000..71e160f092
# --- /dev/null
# +++ b/tensorflow/contrib/specs/python/specs_test.py
# @@ -0,0 +1,231 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Testing specs specifications."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


import numpy as np
import tensorflow as tf
from tensorflow.contrib.specs.python import specs
from tensorflow.contrib.specs.python import summaries


def _rand(*size):
  # Uniform random samples of the requested shape, cast to float32 ("f").
  return np.random.uniform(size=size).astype("f")


class SpecsTest(tf.test.TestCase):
  # Each test builds a network from a spec string (or from DSL objects made
  # available by `specs.ops`), checks static and evaluated output shapes, and
  # in most cases compares the graph against a postfix structure string from
  # summaries.tf_spec_structure.

  def testSimpleConv(self):
    with self.test_session():
      inputs = tf.constant(_rand(1, 18, 19, 5))
      spec = "net = Cr(64, [5, 5])"
      outputs = specs.create_net(spec, inputs)
      self.assertEqual(outputs.get_shape().as_list(), [1, 18, 19, 64])
      tf.initialize_all_variables().run()
      result = outputs.eval()
      self.assertEqual(tuple(result.shape), (1, 18, 19, 64))
      self.assertEqual(summaries.tf_spec_structure(spec, inputs),
                       "_ var conv var biasadd relu")

  def testUnary(self):
    # This is just a quick and dirty check that these ops exist
    # and work as unary ops.
    with self.test_session():
      inputs = tf.constant(_rand(17, 55))
      spec = "net = Do(0.5) | Bn | Unit(1) | Relu | Sig | Tanh | Smax"
      outputs = specs.create_net(spec, inputs)
      self.assertEqual(outputs.get_shape().as_list(), [17, 55])
      tf.initialize_all_variables().run()
      result = outputs.eval()
      self.assertEqual(tuple(result.shape), (17, 55))

  def testAdd(self):
    # `+` combines two branches fed from the same input; the second branch
    # shows the shared input as "<>" in the structure string.
    with self.test_session():
      inputs = tf.constant(_rand(17, 55))
      spec = "net = Fs(10) + Fr(10)"
      outputs = specs.create_net(spec, inputs)
      self.assertEqual(outputs.get_shape().as_list(), [17, 10])
      tf.initialize_all_variables().run()
      result = outputs.eval()
      self.assertEqual(tuple(result.shape), (17, 10))
      self.assertEqual(summaries.tf_spec_structure(spec, inputs),
                       "_ var dot var biasadd sig "
                       "<> var dot var biasadd relu add")

  def testMpPower(self):
    # `**3` repeats the 2x2 max-pool three times: 64 -> 8 per spatial dim.
    with self.test_session():
      inputs = tf.constant(_rand(1, 64, 64, 5))
      spec = "M2 = Mp([2, 2]); net = M2**3"
      outputs = specs.create_net(spec, inputs)
      self.assertEqual(outputs.get_shape().as_list(), [1, 8, 8, 5])
      tf.initialize_all_variables().run()
      result = outputs.eval()
      self.assertEqual(tuple(result.shape), (1, 8, 8, 5))
      self.assertEqual(summaries.tf_spec_structure(spec, inputs),
                       "_ maxpool maxpool maxpool")

  def testAbbrevPower(self):
    # C3 is a partially applied Cr; the missing argument is supplied later.
    with self.test_session():
      inputs = tf.constant(_rand(1, 64, 64, 5))
      spec = "C3 = Cr([3, 3]); M2 = Mp([2, 2]); net = (C3(5) | M2)**3"
      outputs = specs.create_net(spec, inputs)
      self.assertEqual(outputs.get_shape().as_list(), [1, 8, 8, 5])
      tf.initialize_all_variables().run()
      result = outputs.eval()
      self.assertEqual(tuple(result.shape), (1, 8, 8, 5))
      self.assertEqual(summaries.tf_spec_structure(spec, inputs),
                       "_ var conv var biasadd relu maxpool var conv var"
                       " biasadd relu maxpool var conv var"
                       " biasadd relu maxpool")

  def testAbbrevPower2(self):
    # Same network as testAbbrevPower, using the `_0=` / `_1=` positional
    # keyword form for the partially applied arguments.
    with self.test_session():
      inputs = tf.constant(_rand(1, 64, 64, 5))
      spec = "C3 = Cr(_1=[3, 3]); M2 = Mp([2, 2]);"
      spec += "net = (C3(_0=5) | M2)**3"
      outputs = specs.create_net(spec, inputs)
      self.assertEqual(outputs.get_shape().as_list(), [1, 8, 8, 5])
      tf.initialize_all_variables().run()
      result = outputs.eval()
      self.assertEqual(tuple(result.shape), (1, 8, 8, 5))
      self.assertEqual(summaries.tf_spec_structure(spec, inputs),
                       "_ var conv var biasadd relu maxpool var conv"
                       " var biasadd relu"
                       " maxpool var conv var biasadd relu maxpool")

  def testConc(self):
    # Concatenating 20- and 10-wide branches along dim 1 gives 30 columns.
    with self.test_session():
      inputs = tf.constant(_rand(10, 20))
      spec = "net = Conc(1, Fs(20), Fs(10))"
      outputs = specs.create_net(spec, inputs)
      self.assertEqual(outputs.get_shape().as_list(), [10, 30])
      tf.initialize_all_variables().run()
      result = outputs.eval()
      self.assertEqual(tuple(result.shape), (10, 30))
      self.assertEqual(summaries.tf_spec_structure(spec, inputs),
                       "_ _ var dot var biasadd sig "
                       "<> var dot var biasadd sig concat")

  def testImport(self):
    # Import() runs the given statements and exposes `f` as a spec op.
    with self.test_session():
      inputs = tf.constant(_rand(10, 20))
      spec = "S = Import('import tensorflow as tf; f = tf.nn.sigmoid')"
      spec += "; net = S | S"
      outputs = specs.create_net(spec, inputs)
      self.assertEqual(outputs.get_shape().as_list(), [10, 20])
      tf.initialize_all_variables().run()
      result = outputs.eval()
      self.assertEqual(tuple(result.shape), (10, 20))
      self.assertEqual(summaries.tf_spec_structure(spec, inputs),
                       "_ sig sig")

  def testLstm2(self):
    with self.test_session():
      inputs = tf.constant(_rand(1, 64, 64, 5))
      spec = "net = Lstm2(15)"
      outputs = specs.create_net(spec, inputs)
      self.assertEqual(outputs.get_shape().as_list(), [1, 64, 64, 15])
      tf.initialize_all_variables().run()
      result = outputs.eval()
      self.assertEqual(tuple(result.shape), (1, 64, 64, 15))

  def testLstm2to1(self):
    with self.test_session():
      inputs = tf.constant(_rand(1, 64, 64, 5))
      spec = "net = Lstm2to1(15)"
      outputs = specs.create_net(spec, inputs)
      self.assertEqual(outputs.get_shape().as_list(), [1, 64, 15])
      tf.initialize_all_variables().run()
      result = outputs.eval()
      self.assertEqual(tuple(result.shape), (1, 64, 15))

  def testLstm2to0(self):
    with self.test_session():
      inputs = tf.constant(_rand(1, 64, 64, 5))
      spec = "net = Lstm2to0(15)"
      outputs = specs.create_net(spec, inputs)
      self.assertEqual(outputs.get_shape().as_list(), [1, 15])
      tf.initialize_all_variables().run()
      result = outputs.eval()
      self.assertEqual(tuple(result.shape), (1, 15))

  def testKeywordRestriction(self):
    # A spec containing an `import` statement must be rejected.
    with self.test_session():
      inputs = tf.constant(_rand(10, 20))
      spec = "import re; net = Conc(1, Fs(20), Fs(10))"
      self.assertRaises(ValueError, lambda: specs.create_net(spec, inputs))

  def testParams(self):
    # Only the plain constant `x` has a checkable value; y, z and q come
    # from the Ui / Lf / Nt parameter ops and are only checked for presence.
    params = "x = 3; y = Ui(-10, 10); z = Lf(1, 100); q = Nt(0.0, 1.0)"
    bindings = specs.eval_params(params, {})
    self.assertTrue("x" in bindings)
    self.assertEqual(bindings["x"], 3)
    self.assertTrue("y" in bindings)
    self.assertTrue("z" in bindings)
    self.assertTrue("q" in bindings)

  def testSpecsOps(self):
    # `Cr` must be resolvable only while the `specs.ops` context is active.
    # pylint: disable=undefined-variable
    with self.assertRaises(NameError):
      _ = Cr
    with specs.ops:
      self.assertIsNotNone(Cr)
      self.assertTrue(callable(Cr(64, [3, 3])))
    with self.assertRaises(NameError):
      _ = Cr

  def testVar(self):
    # Var() yields the variable itself, so `outputs` has an initializer.
    with self.test_session() as sess:
      with specs.ops:
        # pylint: disable=undefined-variable
        v = Var("test_var", shape=[2, 2],
                initializer=tf.constant_initializer(42.0))
      inputs = tf.constant(_rand(10, 100))
      outputs = v.funcall(inputs)
      self.assertEqual(len(tf.all_variables()), 1)
      sess.run([outputs.initializer])
      outputs_value = outputs.eval()
      self.assertEqual(outputs_value.shape, (2, 2))
      self.assertEqual(outputs_value[1, 1], 42.0)

  def testShared(self):
    # Four applications of one Shared layer must create only one set of
    # variables (two variables in total).
    with self.test_session():
      with specs.ops:
        # pylint: disable=undefined-variable
        f = Shared(Fr(100))
        g = f | f | f | f
      inputs = tf.constant(_rand(10, 100))
      _ = g.funcall(inputs)
      self.assertEqual(len(tf.all_variables()), 2)

  def testAutoFunction(self):
    # SL.conv2d is the AutoFunction binding of slim.conv2d.
    with self.test_session():
      inputs = tf.constant(_rand(1, 18, 19, 5))
      with specs.ops:
        # pylint: disable=undefined-variable
        net = SL.conv2d(64, 5)
      outputs = net.funcall(inputs)
      self.assertEqual(outputs.get_shape().as_list(), [1, 18, 19, 64])
      tf.initialize_all_variables().run()
      result = outputs.eval()
      self.assertEqual(tuple(result.shape), (1, 18, 19, 64))
      self.assertEqual(summaries.tf_spec_structure("net = Cr(64, 5)", inputs),
                       "_ var conv var biasadd relu")

if __name__ == "__main__":
  tf.test.main()

# diff --git a/tensorflow/contrib/specs/python/summaries.py b/tensorflow/contrib/specs/python/summaries.py
# new file mode 100644
# index 0000000000..27f3bb32d7
# --- /dev/null
# +++ b/tensorflow/contrib/specs/python/summaries.py
# @@ -0,0 +1,301 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
+""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +import re +import tensorflow as tf +from tensorflow.contrib.specs.python import specs + + +# These are short abbreviations for common TensorFlow operations used +# in test cases with tf_structure to verify that specs_lib generates a +# graph structure with the right operations. Operations outside the +# scope of specs (e.g., Const and Placeholder) are just assigned "_" +# since they are not relevant to testing. + +SHORT_NAMES_SRC = """ +BiasAdd biasadd +Const _ +Conv2D conv +MatMul dot +Placeholder _ +Sigmoid sig +Variable var +""".split() + + +SHORT_NAMES = {x: y for x, y in zip(SHORT_NAMES_SRC[::2], + SHORT_NAMES_SRC[1::2])} + + +def _truncate_structure(x): + """A helper function that disables recursion in tf_structure. + + Some constructs (e.g., HorizontalLstm) are complex unrolled + structures and don't need to be represented in the output + of tf_structure or tf_print. This helper function defines + which tree branches should be pruned. This is a very imperfect + way of dealing with unrolled LSTM's (since it truncates + useful information as well), but it's not worth doing something + better until the new fused and unrolled ops are ready. + + Args: + x: a Tensor or Op + + Returns: + A bool indicating whether the subtree should be pruned. + """ + if "/HorizontalLstm/" in x.name: return True + return False + + +def tf_structure(x, include_shapes=False, finished=None): + """A postfix expression summarizing the TF graph. + + This is intended to be used as part of test cases to + check for gross differences in the structure of the graph. + The resulting string is not invertible or unabiguous + and cannot be used to reconstruct the graph accurately. 
+ + Args: + x: a tf.Tensor or tf.Operation + include_shapes: include shapes in the output string + finished: a set of ops that have already been output + + Returns: + A string representing the structure as a string of + postfix operations. + """ + if finished is None: + finished = set() + if isinstance(x, tf.Tensor): + shape = x.get_shape().as_list() + x = x.op + else: + shape = [] + if x in finished: + return " <>" + finished |= {x} + result = "" + if not _truncate_structure(x): + for y in x.inputs: + result += tf_structure(y, include_shapes, finished) + if include_shapes: + result += " %s" % (shape,) + if x.type != "Identity": + name = SHORT_NAMES.get(x.type, x.type.lower()) + result += " " + name + return result + + +def tf_print(x, depth=0, finished=None, printer=print): + """A simple print function for a TensorFlow graph. + + Args: + x: a tf.Tensor or tf.Operation + depth: current printing depth + finished: set of nodes already output + printer: print function to use + + Returns: + Total number of parameters found in the + subtree. + """ + + if finished is None: + finished = set() + if isinstance(x, tf.Tensor): + shape = x.get_shape().as_list() + x = x.op + else: + shape = "" + if x.type == "Identity": + x = x.inputs[0].op + if x in finished: + printer("%s<%s> %s %s" % (" "*depth, x.name, x.type, shape)) + return + finished |= {x} + printer("%s%s %s %s" % (" "*depth, x.name, x.type, shape)) + if not _truncate_structure(x): + for y in x.inputs: + tf_print(y, depth+1, finished, printer=printer) + + +def tf_num_params(x): + """Number of parameters in a TensorFlow subgraph. + + Args: + x: root of the subgraph (Tensor, Operation) + + Returns: + Total number of elements found in all Variables + in the subgraph. 
+ """ + + if isinstance(x, tf.Tensor): + shape = x.get_shape() + x = x.op + if x.type == "Variable": + return shape.num_elements() + totals = [tf_num_params(y) for y in x.inputs] + return sum(totals) + + +def tf_left_split(op): + """Split the parameters of op for left recursion. + + Args: + op: tf.Operation + + Returns: + A tuple of the leftmost input tensor and a list of the + remaining arguments. + """ + + if len(op.inputs) < 1: + return None, [] + if op.type == "Concat": + return op.inputs[1], op.inputs[2:] + return op.inputs[0], op.inputs[1:] + + +def tf_parameter_iter(x): + """Iterate over the left branches of a graph and yield sizes. + + Args: + x: root of the subgraph (Tensor, Operation) + + Yields: + A triple of name, number of params, and shape. + """ + + while 1: + if isinstance(x, tf.Tensor): + shape = x.get_shape().as_list() + x = x.op + else: + shape = "" + left, right = tf_left_split(x) + totals = [tf_num_params(y) for y in right] + total = sum(totals) + yield x.name, total, shape + if left is None: break + x = left + + +def _combine_filter(x): + """A filter for combining successive layers with similar names.""" + last_name = None + last_total = 0 + last_shape = None + for name, total, shape in x: + name = re.sub("/.*", "", name) + if name == last_name: + last_total += total + continue + if last_name is not None: + yield last_name, last_total, last_shape + last_name = name + last_total = total + last_shape = shape + if last_name is not None: + yield last_name, last_total, last_shape + + +def tf_parameter_summary(x, printer=print, combine=True): + """Summarize parameters by depth. 
+ + Args: + x: root of the subgraph (Tensor, Operation) + printer: print function for output + combine: combine layers by top-level scope + """ + seq = tf_parameter_iter(x) + if combine: seq = _combine_filter(seq) + seq = reversed(list(seq)) + for name, total, shape in seq: + printer("%10d %-20s %s" % (total, name, shape)) + + +def tf_spec_structure(spec, inputs=None, input_shape=None, + input_type=tf.float32): + """Return a postfix representation of the specification. + + This is intended to be used as part of test cases to + check for gross differences in the structure of the graph. + The resulting string is not invertible or unabiguous + and cannot be used to reconstruct the graph accurately. + + Args: + spec: specification + inputs: input to the spec construction (usually a Tensor) + input_shape: tensor shape (in lieu of inputs) + input_type: type of the input tensor + + Returns: + A string with a postfix representation of the + specification. + """ + + if inputs is None: + inputs = tf.placeholder(input_type, input_shape) + outputs = specs.create_net(spec, inputs) + return str(tf_structure(outputs).strip()) + + +def tf_spec_summary(spec, inputs=None, input_shape=None, input_type=tf.float32): + """Output a summary of the specification. + + This prints a list of left-most tensor operations and summarized the + variables found in the right branches. This kind of representation + is particularly useful for networks that are generally structured + like pipelines. + + Args: + spec: specification + inputs: input to the spec construction (usually a Tensor) + input_shape: optional shape of input + input_type: type of the input tensor + """ + + if inputs is None: + inputs = tf.placeholder(input_type, input_shape) + outputs = specs.create_net(spec, inputs) + tf_parameter_summary(outputs) + + +def tf_spec_print(spec, inputs=None, input_shape=None, input_type=tf.float32): + """Print a tree representing the spec. 
+ + Args: + spec: specification + inputs: input to the spec construction (usually a Tensor) + input_shape: optional shape of input + input_type: type of the input tensor + """ + + if inputs is None: + inputs = tf.placeholder(input_type, input_shape) + outputs = specs.create_net(spec, inputs) + tf_print(outputs) diff --git a/tensorflow/contrib/specs/python/summaries_test.py b/tensorflow/contrib/specs/python/summaries_test.py new file mode 100644 index 0000000000..77f01b6549 --- /dev/null +++ b/tensorflow/contrib/specs/python/summaries_test.py @@ -0,0 +1,80 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for specs-related summarization functions.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +import numpy as np +import tensorflow as tf + +from tensorflow.contrib.specs.python import specs +from tensorflow.contrib.specs.python import summaries + + +def _rand(*size): + return np.random.uniform(size=size).astype("f") + + +class SummariesTest(tf.test.TestCase): + + def testStructure(self): + with self.test_session(): + inputs_shape = (1, 18, 19, 5) + inputs = tf.constant(_rand(*inputs_shape)) + spec = "net = Cr(64, [5, 5])" + outputs = specs.create_net(spec, inputs) + tf.initialize_all_variables().run() + result = outputs.eval() + self.assertEqual(tuple(result.shape), (1, 18, 19, 64)) + self.assertEqual(summaries.tf_spec_structure(spec, + input_shape=inputs_shape), + "_ var conv var biasadd relu") + + def testStructureFromTensor(self): + with self.test_session(): + inputs = tf.constant(_rand(1, 18, 19, 5)) + spec = "net = Cr(64, [5, 5])" + outputs = specs.create_net(spec, inputs) + tf.initialize_all_variables().run() + result = outputs.eval() + self.assertEqual(tuple(result.shape), (1, 18, 19, 64)) + self.assertEqual(summaries.tf_spec_structure(spec, inputs), + "_ var conv var biasadd relu") + + def testPrint(self): + with self.test_session(): + inputs = tf.constant(_rand(1, 18, 19, 5)) + spec = "net = Cr(64, [5, 5])" + outputs = specs.create_net(spec, inputs) + tf.initialize_all_variables().run() + result = outputs.eval() + self.assertEqual(tuple(result.shape), (1, 18, 19, 64)) + summaries.tf_spec_print(spec, inputs) + + def testSummary(self): + with self.test_session(): + inputs = tf.constant(_rand(1, 18, 19, 5)) + spec = "net = Cr(64, [5, 5])" + outputs = specs.create_net(spec, inputs) + tf.initialize_all_variables().run() + result = outputs.eval() + self.assertEqual(tuple(result.shape), (1, 18, 19, 64)) 
+ summaries.tf_spec_summary(spec, inputs) + + +if __name__ == "__main__": + tf.test.main() -- cgit v1.2.3