author    Dustin Tran <trandustin@google.com>    2018-02-19 21:39:03 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>    2018-02-19 21:42:41 -0800
commit 1ad338200e2643387efe6bebd1fcd59ddd87fdf1 (patch)
tree   d5b332adde3bef939f63bb498a32e1dd95ee5309 /tensorflow/contrib/bayesflow
parent ff6c4de87cbb23be97c4a10e9cb37fe13d2cb3a4 (diff)
Reduce tfp.layers boilerplate via programmable docstrings.
PiperOrigin-RevId: 186260342
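The pattern, in brief: argument descriptions shared by the variational layers live once in a module-level `doc_args` string, and the new `docstring_util.expand_docstring` decorator splices that string into each constructor docstring wherever an `@{args}` placeholder appears. A minimal sketch of the idea (the layer class and arguments below are illustrative only, not taken from this change):

# Assumes docstring_util is importable as introduced by this change.
from tensorflow.contrib.bayesflow.python.ops import docstring_util

doc_args = """  activation: Activation function. Set it to None to maintain a
    linear activation.
  name: A string, the name of the layer."""


class MyVariationalLayer(object):
  """Hypothetical layer illustrating the shared-docstring pattern."""

  @docstring_util.expand_docstring(args=doc_args)
  def __init__(self, filters, activation=None, name=None):
    """Construct layer.

    Args:
      filters: Integer, the dimensionality of the output space.
      @{args}
    """
    self.filters = filters
    self.activation = activation
    self.name = name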
Diffstat (limited to 'tensorflow/contrib/bayesflow')
-rw-r--r--  tensorflow/contrib/bayesflow/BUILD                                       |   10
-rw-r--r--  tensorflow/contrib/bayesflow/python/kernel_tests/docstring_util_test.py |   83
-rw-r--r--  tensorflow/contrib/bayesflow/python/ops/docstring_util.py               |   86
-rw-r--r--  tensorflow/contrib/bayesflow/python/ops/layers_conv_variational.py      | 1127
-rw-r--r--  tensorflow/contrib/bayesflow/python/ops/layers_dense_variational.py     |  391
5 files changed, 577 insertions(+), 1120 deletions(-)
diff --git a/tensorflow/contrib/bayesflow/BUILD b/tensorflow/contrib/bayesflow/BUILD
index 74712aeb67..fc04933ba0 100644
--- a/tensorflow/contrib/bayesflow/BUILD
+++ b/tensorflow/contrib/bayesflow/BUILD
@@ -119,6 +119,16 @@ cuda_py_test(
)
cuda_py_test(
+ name = "docstring_util_test",
+ size = "small",
+ srcs = ["python/kernel_tests/docstring_util_test.py"],
+ additional_deps = [
+ ":bayesflow_py",
+ "//tensorflow/python:client_testlib",
+ ],
+)
+
+cuda_py_test(
name = "layers_dense_variational_test",
size = "small",
srcs = ["python/kernel_tests/layers_dense_variational_test.py"],
diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/docstring_util_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/docstring_util_test.py
new file mode 100644
index 0000000000..09ae6f3952
--- /dev/null
+++ b/tensorflow/contrib/bayesflow/python/kernel_tests/docstring_util_test.py
@@ -0,0 +1,83 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for docstring utilities."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.bayesflow.python.ops import docstring_util
+from tensorflow.python.platform import test
+
+
+class DocstringUtil(test.TestCase):
+
+ def _testFunction(self):
+ doc_args = """ x: Input to return as output.
+ y: Baz."""
+ @docstring_util.expand_docstring(args=doc_args)
+ def foo(x):
+ """Hello world.
+
+ Args:
+ @{args}
+
+ Returns:
+ x.
+ """
+ return x
+
+ true_docstring = """Hello world.
+
+ Args:
+ x: Input to return as output.
+ y: Baz.
+
+ Returns:
+ x.
+ """
+ self.assertEqual(foo.__doc__, true_docstring)
+
+ def _testClassInit(self):
+ doc_args = """ x: Input to return as output.
+ y: Baz."""
+
+ class Foo(object):
+
+ @docstring_util.expand_docstring(args=doc_args)
+ def __init__(self, x, y):
+ """Hello world.
+
+ Args:
+ @{args}
+
+ Bar.
+ """
+ pass
+
+ true_docstring = """Hello world.
+
+ Args:
+ x: Input to return as output.
+ y: Baz.
+
+ Bar.
+ """
+ self.assertEqual(Foo.__doc__, true_docstring)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/bayesflow/python/ops/docstring_util.py b/tensorflow/contrib/bayesflow/python/ops/docstring_util.py
new file mode 100644
index 0000000000..44a1ea2f2a
--- /dev/null
+++ b/tensorflow/contrib/bayesflow/python/ops/docstring_util.py
@@ -0,0 +1,86 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities for programmable docstrings.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import re
+import sys
+import six
+
+
+def expand_docstring(**kwargs):
+ """Decorator to programmatically expand the docstring.
+
+ Args:
+ **kwargs: Keyword arguments to set. For each key-value pair `k` and `v`,
+ the key is found as `@{k}` in the docstring and replaced with `v`.
+
+ Returns:
+ Decorated function.
+ """
+ def _fn_wrapped(fn):
+ """Original function with modified `__doc__` attribute."""
+ doc = _trim(fn.__doc__)
+ for k, v in six.iteritems(kwargs):
+ # Capture each @{k} reference to replace with v.
+ # We wrap the replacement in a function so no backslash escapes
+ # are processed.
+ pattern = r'@\{' + str(k) + r'\}'
+ doc = re.sub(pattern, lambda match: v, doc) # pylint: disable=cell-var-from-loop
+ fn.__doc__ = doc
+ return fn
+ return _fn_wrapped
+
+
+def _trim(docstring):
+ """Trims docstring indentation.
+
+ In general, multi-line docstrings carry their level of indentation when
+ defined under a function or class method. This function standardizes
+ indentation levels by removing them. Taken from PEP 257 docs.
+
+ Args:
+ docstring: Python string to trim indentation.
+
+ Returns:
+ Trimmed docstring.
+ """
+ if not docstring:
+ return ''
+ # Convert tabs to spaces (following the normal Python rules)
+ # and split into a list of lines:
+ lines = docstring.expandtabs().splitlines()
+ # Determine minimum indentation (first line doesn't count):
+  indent = sys.maxsize
+ for line in lines[1:]:
+ stripped = line.lstrip()
+ if stripped:
+ indent = min(indent, len(line) - len(stripped))
+ # Remove indentation (first line is special):
+ trimmed = [lines[0].strip()]
+  if indent < sys.maxsize:
+ for line in lines[1:]:
+ trimmed.append(line[indent:].rstrip())
+ # Strip off trailing and leading blank lines:
+ while trimmed and not trimmed[-1]:
+ trimmed.pop()
+ while trimmed and not trimmed[0]:
+ trimmed.pop(0)
+ # Return a single string:
+ return '\n'.join(trimmed)
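For reference, a short usage sketch of the decorator defined above (the `scale` function and its arguments are hypothetical; the import path matches the accompanying test file):

from tensorflow.contrib.bayesflow.python.ops import docstring_util

doc_args = """x: Input tensor.
    y: Scale factor."""


@docstring_util.expand_docstring(args=doc_args)
def scale(x, y):
  """Scales `x` by `y`.

  Args:
    @{args}

  Returns:
    `x * y`.
  """
  return x * y

# After decoration, scale.__doc__ contains the two argument lines in place
# of the @{args} placeholder, with indentation normalized by _trim().
print(scale.__doc__)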
diff --git a/tensorflow/contrib/bayesflow/python/ops/layers_conv_variational.py b/tensorflow/contrib/bayesflow/python/ops/layers_conv_variational.py
index 7723cfb442..90219fdfef 100644
--- a/tensorflow/contrib/bayesflow/python/ops/layers_conv_variational.py
+++ b/tensorflow/contrib/bayesflow/python/ops/layers_conv_variational.py
@@ -19,6 +19,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.contrib.bayesflow.python.ops import docstring_util
from tensorflow.contrib.bayesflow.python.ops import layers_util
from tensorflow.contrib.distributions.python.ops import independent as independent_lib
from tensorflow.python.framework import dtypes
@@ -34,6 +35,45 @@ from tensorflow.python.ops.distributions import kullback_leibler as kl_lib
from tensorflow.python.ops.distributions import normal as normal_lib
from tensorflow.python.ops.distributions import util as distribution_util
+doc_args = """ activation: Activation function. Set it to None to maintain a
+ linear activation.
+ activity_regularizer: Optional regularizer function for the output.
+ trainable: Boolean, if `True` also add variables to the graph collection
+ `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
+ kernel_posterior_fn: Python `callable` which creates
+ `tf.distributions.Distribution` instance representing the surrogate
+ posterior of the `kernel` parameter. Default value:
+ `default_mean_field_normal_fn()`.
+ kernel_posterior_tensor_fn: Python `callable` which takes a
+ `tf.distributions.Distribution` instance and returns a representative
+ value. Default value: `lambda d: d.sample()`.
+ kernel_prior_fn: Python `callable` which creates `tf.distributions`
+ instance. See `default_mean_field_normal_fn` docstring for required
+ parameter signature.
+ Default value: `tf.distributions.Normal(loc=0., scale=1.)`.
+ kernel_divergence_fn: Python `callable` which takes the surrogate posterior
+ distribution, prior distribution and random variate sample(s) from the
+ surrogate posterior and computes or approximates the KL divergence. The
+ distributions are `tf.distributions.Distribution`-like instances and the
+ sample is a `Tensor`.
+ bias_posterior_fn: Python `callable` which creates
+ `tf.distributions.Distribution` instance representing the surrogate
+ posterior of the `bias` parameter. Default value:
+ `default_mean_field_normal_fn(is_singular=True)` (which creates an
+ instance of `tf.distributions.Deterministic`).
+ bias_posterior_tensor_fn: Python `callable` which takes a
+ `tf.distributions.Distribution` instance and returns a representative
+ value. Default value: `lambda d: d.sample()`.
+ bias_prior_fn: Python `callable` which creates `tf.distributions` instance.
+ See `default_mean_field_normal_fn` docstring for required parameter
+ signature. Default value: `None` (no prior, no variational inference)
+ bias_divergence_fn: Python `callable` which takes the surrogate posterior
+ distribution, prior distribution and random variate sample(s) from the
+ surrogate posterior and computes or approximates the KL divergence. The
+ distributions are `tf.distributions.Distribution`-like instances and the
+ sample is a `Tensor`.
+ name: A string, the name of the layer."""
+
class _ConvVariational(layers_lib.Layer):
"""Abstract nD convolution layer (private, used as implementation base).
@@ -55,65 +95,6 @@ class _ConvVariational(layers_lib.Layer):
(`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias`
distributions.
- Arguments:
- rank: An integer, the rank of the convolution, e.g. "2" for 2D convolution.
- filters: Integer, the dimensionality of the output space (i.e. the number
- of filters in the convolution).
- kernel_size: An integer or tuple/list of n integers, specifying the
- length of the convolution window.
- strides: An integer or tuple/list of n integers,
- specifying the stride length of the convolution.
- Specifying any stride value != 1 is incompatible with specifying
- any `dilation_rate` value != 1.
- padding: One of `"valid"` or `"same"` (case-insensitive).
- data_format: A string, one of `channels_last` (default) or `channels_first`.
- The ordering of the dimensions in the inputs.
- `channels_last` corresponds to inputs with shape
- `(batch, ..., channels)` while `channels_first` corresponds to
- inputs with shape `(batch, channels, ...)`.
- dilation_rate: An integer or tuple/list of n integers, specifying
- the dilation rate to use for dilated convolution.
- Currently, specifying any `dilation_rate` value != 1 is
- incompatible with specifying any `strides` value != 1.
- activation: Activation function. Set it to None to maintain a
- linear activation.
- activity_regularizer: Optional regularizer function for the output.
- trainable: Boolean, if `True` also add variables to the graph collection
- `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
- kernel_posterior_fn: Python `callable` which creates
- `tf.distributions.Distribution` instance representing the surrogate
- posterior of the `kernel` parameter. Default value:
- `default_mean_field_normal_fn()`.
- kernel_posterior_tensor_fn: Python `callable` which takes a
- `tf.distributions.Distribution` instance and returns a representative
- value. Default value: `lambda d: d.sample()`.
- kernel_prior_fn: Python `callable` which creates `tf.distributions`
- instance. See `default_mean_field_normal_fn` docstring for required
- parameter signature.
- Default value: `tf.distributions.Normal(loc=0., scale=1.)`.
- kernel_divergence_fn: Python `callable` which takes the surrogate posterior
- distribution, prior distribution and random variate sample(s) from the
- surrogate posterior and computes or approximates the KL divergence. The
- distributions are `tf.distributions.Distribution`-like instances and the
- sample is a `Tensor`.
- bias_posterior_fn: Python `callable` which creates
- `tf.distributions.Distribution` instance representing the surrogate
- posterior of the `bias` parameter. Default value:
- `default_mean_field_normal_fn(is_singular=True)` (which creates an
- instance of `tf.distributions.Deterministic`).
- bias_posterior_tensor_fn: Python `callable` which takes a
- `tf.distributions.Distribution` instance and returns a representative
- value. Default value: `lambda d: d.sample()`.
- bias_prior_fn: Python `callable` which creates `tf.distributions` instance.
- See `default_mean_field_normal_fn` docstring for required parameter
- signature. Default value: `None` (no prior, no variational inference)
- bias_divergence_fn: Python `callable` which takes the surrogate posterior
- distribution, prior distribution and random variate sample(s) from the
- surrogate posterior and computes or approximates the KL divergence. The
- distributions are `tf.distributions.Distribution`-like instances and the
- sample is a `Tensor`.
- name: A string, the name of the layer.
-
Properties:
rank: Python integer, dimensionality of convolution.
filters: Python integer, dimensionality of the output space.
@@ -134,6 +115,7 @@ class _ConvVariational(layers_lib.Layer):
bias_divergence_fn: `callable` returning divergence.
"""
+ @docstring_util.expand_docstring(args=doc_args)
def __init__(
self,
rank,
@@ -157,6 +139,31 @@ class _ConvVariational(layers_lib.Layer):
bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
name=None,
**kwargs):
+ """Construct layer.
+
+ Args:
+ rank: An integer, the rank of the convolution, e.g. "2" for 2D
+ convolution.
+ filters: Integer, the dimensionality of the output space (i.e. the number
+ of filters in the convolution).
+ kernel_size: An integer or tuple/list of n integers, specifying the
+ length of the convolution window.
+ strides: An integer or tuple/list of n integers,
+ specifying the stride length of the convolution.
+ Specifying any stride value != 1 is incompatible with specifying
+ any `dilation_rate` value != 1.
+ padding: One of `"valid"` or `"same"` (case-insensitive).
+ data_format: A string, one of `channels_last` (default) or
+ `channels_first`. The ordering of the dimensions in the inputs.
+ `channels_last` corresponds to inputs with shape `(batch, ...,
+ channels)` while `channels_first` corresponds to inputs with shape
+ `(batch, channels, ...)`.
+ dilation_rate: An integer or tuple/list of n integers, specifying
+ the dilation rate to use for dilated convolution.
+ Currently, specifying any `dilation_rate` value != 1 is
+ incompatible with specifying any `strides` value != 1.
+ @{args}
+ """
super(_ConvVariational, self).__init__(
trainable=trainable,
name=name,
@@ -371,65 +378,6 @@ class _ConvReparameterization(_ConvVariational):
(`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias`
distributions.
- Arguments:
- rank: An integer, the rank of the convolution, e.g. "2" for 2D convolution.
- filters: Integer, the dimensionality of the output space (i.e. the number
- of filters in the convolution).
- kernel_size: An integer or tuple/list of n integers, specifying the
- length of the convolution window.
- strides: An integer or tuple/list of n integers,
- specifying the stride length of the convolution.
- Specifying any stride value != 1 is incompatible with specifying
- any `dilation_rate` value != 1.
- padding: One of `"valid"` or `"same"` (case-insensitive).
- data_format: A string, one of `channels_last` (default) or `channels_first`.
- The ordering of the dimensions in the inputs.
- `channels_last` corresponds to inputs with shape
- `(batch, ..., channels)` while `channels_first` corresponds to
- inputs with shape `(batch, channels, ...)`.
- dilation_rate: An integer or tuple/list of n integers, specifying
- the dilation rate to use for dilated convolution.
- Currently, specifying any `dilation_rate` value != 1 is
- incompatible with specifying any `strides` value != 1.
- activation: Activation function. Set it to None to maintain a
- linear activation.
- activity_regularizer: Optional regularizer function for the output.
- trainable: Boolean, if `True` also add variables to the graph collection
- `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
- kernel_posterior_fn: Python `callable` which creates
- `tf.distributions.Distribution` instance representing the surrogate
- posterior of the `kernel` parameter. Default value:
- `default_mean_field_normal_fn()`.
- kernel_posterior_tensor_fn: Python `callable` which takes a
- `tf.distributions.Distribution` instance and returns a representative
- value. Default value: `lambda d: d.sample()`.
- kernel_prior_fn: Python `callable` which creates `tf.distributions`
- instance. See `default_mean_field_normal_fn` docstring for required
- parameter signature.
- Default value: `tf.distributions.Normal(loc=0., scale=1.)`.
- kernel_divergence_fn: Python `callable` which takes the surrogate posterior
- distribution, prior distribution and random variate sample(s) from the
- surrogate posterior and computes or approximates the KL divergence. The
- distributions are `tf.distributions.Distribution`-like instances and the
- sample is a `Tensor`.
- bias_posterior_fn: Python `callable` which creates
- `tf.distributions.Distribution` instance representing the surrogate
- posterior of the `bias` parameter. Default value:
- `default_mean_field_normal_fn(is_singular=True)` (which creates an
- instance of `tf.distributions.Deterministic`).
- bias_posterior_tensor_fn: Python `callable` which takes a
- `tf.distributions.Distribution` instance and returns a representative
- value. Default value: `lambda d: d.sample()`.
- bias_prior_fn: Python `callable` which creates `tf.distributions` instance.
- See `default_mean_field_normal_fn` docstring for required parameter
- signature. Default value: `None` (no prior, no variational inference)
- bias_divergence_fn: Python `callable` which takes the surrogate posterior
- distribution, prior distribution and random variate sample(s) from the
- surrogate posterior and computes or approximates the KL divergence. The
- distributions are `tf.distributions.Distribution`-like instances and the
- sample is a `Tensor`.
- name: A string, the name of the layer.
-
Properties:
rank: Python integer, dimensionality of convolution.
filters: Python integer, dimensionality of the output space.
@@ -454,6 +402,7 @@ class _ConvReparameterization(_ConvVariational):
International Conference on Learning Representations, 2014.
"""
+ @docstring_util.expand_docstring(args=doc_args)
def __init__(
self,
rank,
@@ -477,6 +426,31 @@ class _ConvReparameterization(_ConvVariational):
bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
name=None,
**kwargs):
+ """Construct layer.
+
+ Args:
+ rank: An integer, the rank of the convolution, e.g. "2" for 2D
+ convolution.
+ filters: Integer, the dimensionality of the output space (i.e. the number
+ of filters in the convolution).
+ kernel_size: An integer or tuple/list of n integers, specifying the
+ length of the convolution window.
+ strides: An integer or tuple/list of n integers,
+ specifying the stride length of the convolution.
+ Specifying any stride value != 1 is incompatible with specifying
+ any `dilation_rate` value != 1.
+ padding: One of `"valid"` or `"same"` (case-insensitive).
+ data_format: A string, one of `channels_last` (default) or
+ `channels_first`. The ordering of the dimensions in the inputs.
+ `channels_last` corresponds to inputs with shape `(batch, ...,
+ channels)` while `channels_first` corresponds to inputs with shape
+ `(batch, channels, ...)`.
+ dilation_rate: An integer or tuple/list of n integers, specifying
+ the dilation rate to use for dilated convolution.
+ Currently, specifying any `dilation_rate` value != 1 is
+ incompatible with specifying any `strides` value != 1.
+ @{args}
+ """
super(_ConvReparameterization, self).__init__(
rank=rank,
filters=filters,
@@ -529,63 +503,6 @@ class Conv1DReparameterization(_ConvReparameterization):
(`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias`
distributions.
- Arguments:
- filters: Integer, the dimensionality of the output space (i.e. the number
- of filters in the convolution).
- kernel_size: An integer or tuple/list of a single integer, specifying the
- length of the 1D convolution window.
- strides: An integer or tuple/list of a single integer,
- specifying the stride length of the convolution.
- Specifying any stride value != 1 is incompatible with specifying
- any `dilation_rate` value != 1.
- padding: One of `"valid"` or `"same"` (case-insensitive).
- data_format: A string, one of `channels_last` (default) or `channels_first`.
- The ordering of the dimensions in the inputs.
- `channels_last` corresponds to inputs with shape
- `(batch, length, channels)` while `channels_first` corresponds to
- inputs with shape `(batch, channels, length)`.
- dilation_rate: An integer or tuple/list of a single integer, specifying
- the dilation rate to use for dilated convolution.
- Currently, specifying any `dilation_rate` value != 1 is
- incompatible with specifying any `strides` value != 1.
- activation: Activation function. Set it to None to maintain a
- linear activation.
- activity_regularizer: Optional regularizer function for the output.
- trainable: Boolean, if `True` also add variables to the graph collection
- `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
- kernel_posterior_fn: Python `callable` which creates
- `tf.distributions.Distribution` instance representing the surrogate
- posterior of the `kernel` parameter. Default value:
- `default_mean_field_normal_fn()`.
- kernel_posterior_tensor_fn: Python `callable` which takes a
- `tf.distributions.Distribution` instance and returns a representative
- value. Default value: `lambda d: d.sample()`.
- kernel_prior_fn: Python `callable` which creates `tf.distributions`
- instance. See `default_mean_field_normal_fn` docstring for required
- parameter signature.
- Default value: `tf.distributions.Normal(loc=0., scale=1.)`.
- kernel_divergence_fn: Python `callable` which takes the surrogate posterior
- distribution, prior distribution and random variate sample(s) from the
- surrogate posterior and computes or approximates the KL divergence. The
- distributions are `tf.distributions.Distribution`-like instances and the
- sample is a `Tensor`.
- bias_posterior_fn: Python `callable` which creates
- `tf.distributions.Distribution` instance representing the surrogate
- posterior of the `bias` parameter. Default value:
- `default_mean_field_normal_fn(is_singular=True)` (which creates an
- instance of `tf.distributions.Deterministic`).
- bias_posterior_tensor_fn: Python `callable` which takes a
- `tf.distributions.Distribution` instance and returns a representative
- value. Default value: `lambda d: d.sample()`.
- bias_prior_fn: Python `callable` which creates `tf.distributions` instance.
- See `default_mean_field_normal_fn` docstring for required parameter
- signature. Default value: `None` (no prior, no variational inference)
- bias_divergence_fn: Python `callable` which takes the surrogate posterior
- distribution, prior distribution and random variate sample(s) from the
- surrogate posterior and computes or approximates the KL divergence. The
- distributions are `tf.distributions.Distribution`-like instances and the
- name: A string, the name of the layer.
-
Properties:
filters: Python integer, dimensionality of the output space.
kernel_size: Size of the convolution window.
@@ -639,6 +556,7 @@ class Conv1DReparameterization(_ConvReparameterization):
International Conference on Learning Representations, 2014.
"""
+ @docstring_util.expand_docstring(args=doc_args)
def __init__(
self,
filters,
@@ -661,6 +579,29 @@ class Conv1DReparameterization(_ConvReparameterization):
bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
name=None,
**kwargs):
+ """Construct layer.
+
+ Args:
+ filters: Integer, the dimensionality of the output space (i.e. the number
+ of filters in the convolution).
+ kernel_size: An integer or tuple/list of a single integer, specifying the
+ length of the 1D convolution window.
+ strides: An integer or tuple/list of a single integer,
+ specifying the stride length of the convolution.
+ Specifying any stride value != 1 is incompatible with specifying
+ any `dilation_rate` value != 1.
+ padding: One of `"valid"` or `"same"` (case-insensitive).
+ data_format: A string, one of `channels_last` (default) or
+ `channels_first`. The ordering of the dimensions in the inputs.
+ `channels_last` corresponds to inputs with shape `(batch, length,
+ channels)` while `channels_first` corresponds to inputs with shape
+ `(batch, channels, length)`.
+ dilation_rate: An integer or tuple/list of a single integer, specifying
+ the dilation rate to use for dilated convolution.
+ Currently, specifying any `dilation_rate` value != 1 is
+ incompatible with specifying any `strides` value != 1.
+ @{args}
+ """
super(Conv1DReparameterization, self).__init__(
rank=1,
filters=filters,
@@ -683,6 +624,7 @@ class Conv1DReparameterization(_ConvReparameterization):
name=name, **kwargs)
+@docstring_util.expand_docstring(args=doc_args)
def conv1d_reparameterization(
inputs,
filters,
@@ -726,7 +668,7 @@ def conv1d_reparameterization(
(`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias`
distributions.
- Arguments:
+ Args:
inputs: Tensor input.
filters: Integer, the dimensionality of the output space (i.e. the number
of filters in the convolution).
@@ -746,43 +688,7 @@ def conv1d_reparameterization(
the dilation rate to use for dilated convolution.
Currently, specifying any `dilation_rate` value != 1 is
incompatible with specifying any `strides` value != 1.
- activation: Activation function. Set it to None to maintain a
- linear activation.
- activity_regularizer: Optional regularizer function for the output.
- trainable: Boolean, if `True` also add variables to the graph collection
- `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
- kernel_posterior_fn: Python `callable` which creates
- `tf.distributions.Distribution` instance representing the surrogate
- posterior of the `kernel` parameter. Default value:
- `default_mean_field_normal_fn()`.
- kernel_posterior_tensor_fn: Python `callable` which takes a
- `tf.distributions.Distribution` instance and returns a representative
- value. Default value: `lambda d: d.sample()`.
- kernel_prior_fn: Python `callable` which creates `tf.distributions`
- instance. See `default_mean_field_normal_fn` docstring for required
- parameter signature.
- Default value: `tf.distributions.Normal(loc=0., scale=1.)`.
- kernel_divergence_fn: Python `callable` which takes the surrogate posterior
- distribution, prior distribution and random variate sample(s) from the
- surrogate posterior and computes or approximates the KL divergence. The
- distributions are `tf.distributions.Distribution`-like instances and the
- sample is a `Tensor`.
- bias_posterior_fn: Python `callable` which creates
- `tf.distributions.Distribution` instance representing the surrogate
- posterior of the `bias` parameter. Default value:
- `default_mean_field_normal_fn(is_singular=True)` (which creates an
- instance of `tf.distributions.Deterministic`).
- bias_posterior_tensor_fn: Python `callable` which takes a
- `tf.distributions.Distribution` instance and returns a representative
- value. Default value: `lambda d: d.sample()`.
- bias_prior_fn: Python `callable` which creates `tf.distributions` instance.
- See `default_mean_field_normal_fn` docstring for required parameter
- signature. Default value: `None` (no prior, no variational inference)
- bias_divergence_fn: Python `callable` which takes the surrogate posterior
- distribution, prior distribution and random variate sample(s) from the
- surrogate posterior and computes or approximates the KL divergence. The
- distributions are `tf.distributions.Distribution`-like instances and the
- name: A string, the name of the layer.
+ @{args}
reuse: Boolean, whether to reuse the weights of a previous layer
by the same name.
@@ -874,70 +780,6 @@ class Conv2DReparameterization(_ConvReparameterization):
(`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias`
distributions.
- Arguments:
- filters: Integer, the dimensionality of the output space (i.e. the number
- of filters in the convolution).
- kernel_size: An integer or tuple/list of 2 integers, specifying the
- height and width of the 2D convolution window.
- Can be a single integer to specify the same value for
- all spatial dimensions.
- strides: An integer or tuple/list of 2 integers,
- specifying the strides of the convolution along the height and width.
- Can be a single integer to specify the same value for
- all spatial dimensions.
- Specifying any stride value != 1 is incompatible with specifying
- any `dilation_rate` value != 1.
- padding: One of `"valid"` or `"same"` (case-insensitive).
- data_format: A string, one of `channels_last` (default) or `channels_first`.
- The ordering of the dimensions in the inputs.
- `channels_last` corresponds to inputs with shape
- `(batch, height, width, channels)` while `channels_first` corresponds to
- inputs with shape `(batch, channels, height, width)`.
-
- dilation_rate: An integer or tuple/list of 2 integers, specifying
- the dilation rate to use for dilated convolution.
- Can be a single integer to specify the same value for
- all spatial dimensions.
- Currently, specifying any `dilation_rate` value != 1 is
- incompatible with specifying any stride value != 1.
- activation: Activation function. Set it to None to maintain a
- linear activation.
- activity_regularizer: Optional regularizer function for the output.
- trainable: Boolean, if `True` also add variables to the graph collection
- `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
- kernel_posterior_fn: Python `callable` which creates
- `tf.distributions.Distribution` instance representing the surrogate
- posterior of the `kernel` parameter. Default value:
- `default_mean_field_normal_fn()`.
- kernel_posterior_tensor_fn: Python `callable` which takes a
- `tf.distributions.Distribution` instance and returns a representative
- value. Default value: `lambda d: d.sample()`.
- kernel_prior_fn: Python `callable` which creates `tf.distributions`
- instance. See `default_mean_field_normal_fn` docstring for required
- parameter signature.
- Default value: `tf.distributions.Normal(loc=0., scale=1.)`.
- kernel_divergence_fn: Python `callable` which takes the surrogate posterior
- distribution, prior distribution and random variate sample(s) from the
- surrogate posterior and computes or approximates the KL divergence. The
- distributions are `tf.distributions.Distribution`-like instances and the
- sample is a `Tensor`.
- bias_posterior_fn: Python `callable` which creates
- `tf.distributions.Distribution` instance representing the surrogate
- posterior of the `bias` parameter. Default value:
- `default_mean_field_normal_fn(is_singular=True)` (which creates an
- instance of `tf.distributions.Deterministic`).
- bias_posterior_tensor_fn: Python `callable` which takes a
- `tf.distributions.Distribution` instance and returns a representative
- value. Default value: `lambda d: d.sample()`.
- bias_prior_fn: Python `callable` which creates `tf.distributions` instance.
- See `default_mean_field_normal_fn` docstring for required parameter
- signature. Default value: `None` (no prior, no variational inference)
- bias_divergence_fn: Python `callable` which takes the surrogate posterior
- distribution, prior distribution and random variate sample(s) from the
- surrogate posterior and computes or approximates the KL divergence. The
- distributions are `tf.distributions.Distribution`-like instances and the
- name: A string, the name of the layer.
-
Properties:
filters: Python integer, dimensionality of the output space.
kernel_size: Size of the convolution window.
@@ -994,6 +836,7 @@ class Conv2DReparameterization(_ConvReparameterization):
International Conference on Learning Representations, 2014.
"""
+ @docstring_util.expand_docstring(args=doc_args)
def __init__(
self,
filters,
@@ -1016,6 +859,35 @@ class Conv2DReparameterization(_ConvReparameterization):
bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
name=None,
**kwargs):
+ """Construct layer.
+
+ Args:
+ filters: Integer, the dimensionality of the output space (i.e. the number
+ of filters in the convolution).
+ kernel_size: An integer or tuple/list of 2 integers, specifying the
+ height and width of the 2D convolution window.
+ Can be a single integer to specify the same value for
+ all spatial dimensions.
+ strides: An integer or tuple/list of 2 integers,
+ specifying the strides of the convolution along the height and width.
+ Can be a single integer to specify the same value for
+ all spatial dimensions.
+ Specifying any stride value != 1 is incompatible with specifying
+ any `dilation_rate` value != 1.
+ padding: One of `"valid"` or `"same"` (case-insensitive).
+ data_format: A string, one of `channels_last` (default) or
+ `channels_first`. The ordering of the dimensions in the inputs.
+ `channels_last` corresponds to inputs with shape `(batch, height,
+ width, channels)` while `channels_first` corresponds to inputs with
+ shape `(batch, channels, height, width)`.
+ dilation_rate: An integer or tuple/list of 2 integers, specifying
+ the dilation rate to use for dilated convolution.
+ Can be a single integer to specify the same value for
+ all spatial dimensions.
+ Currently, specifying any `dilation_rate` value != 1 is
+ incompatible with specifying any stride value != 1.
+ @{args}
+ """
super(Conv2DReparameterization, self).__init__(
rank=2,
filters=filters,
@@ -1038,6 +910,7 @@ class Conv2DReparameterization(_ConvReparameterization):
name=name, **kwargs)
+@docstring_util.expand_docstring(args=doc_args)
def conv2d_reparameterization(
inputs,
filters,
@@ -1081,7 +954,7 @@ def conv2d_reparameterization(
(`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias`
distributions.
- Arguments:
+ Args:
inputs: Tensor input.
filters: Integer, the dimensionality of the output space (i.e. the number
of filters in the convolution).
@@ -1101,50 +974,13 @@ def conv2d_reparameterization(
`channels_last` corresponds to inputs with shape
`(batch, height, width, channels)` while `channels_first` corresponds to
inputs with shape `(batch, channels, height, width)`.
-
dilation_rate: An integer or tuple/list of 2 integers, specifying
the dilation rate to use for dilated convolution.
Can be a single integer to specify the same value for
all spatial dimensions.
Currently, specifying any `dilation_rate` value != 1 is
incompatible with specifying any stride value != 1.
- activation: Activation function. Set it to None to maintain a
- linear activation.
- activity_regularizer: Optional regularizer function for the output.
- trainable: Boolean, if `True` also add variables to the graph collection
- `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
- kernel_posterior_fn: Python `callable` which creates
- `tf.distributions.Distribution` instance representing the surrogate
- posterior of the `kernel` parameter. Default value:
- `default_mean_field_normal_fn()`.
- kernel_posterior_tensor_fn: Python `callable` which takes a
- `tf.distributions.Distribution` instance and returns a representative
- value. Default value: `lambda d: d.sample()`.
- kernel_prior_fn: Python `callable` which creates `tf.distributions`
- instance. See `default_mean_field_normal_fn` docstring for required
- parameter signature.
- Default value: `tf.distributions.Normal(loc=0., scale=1.)`.
- kernel_divergence_fn: Python `callable` which takes the surrogate posterior
- distribution, prior distribution and random variate sample(s) from the
- surrogate posterior and computes or approximates the KL divergence. The
- distributions are `tf.distributions.Distribution`-like instances and the
- sample is a `Tensor`.
- bias_posterior_fn: Python `callable` which creates
- `tf.distributions.Distribution` instance representing the surrogate
- posterior of the `bias` parameter. Default value:
- `default_mean_field_normal_fn(is_singular=True)` (which creates an
- instance of `tf.distributions.Deterministic`).
- bias_posterior_tensor_fn: Python `callable` which takes a
- `tf.distributions.Distribution` instance and returns a representative
- value. Default value: `lambda d: d.sample()`.
- bias_prior_fn: Python `callable` which creates `tf.distributions` instance.
- See `default_mean_field_normal_fn` docstring for required parameter
- signature. Default value: `None` (no prior, no variational inference)
- bias_divergence_fn: Python `callable` which takes the surrogate posterior
- distribution, prior distribution and random variate sample(s) from the
- surrogate posterior and computes or approximates the KL divergence. The
- distributions are `tf.distributions.Distribution`-like instances and the
- name: A string, the name of the layer.
+ @{args}
reuse: Boolean, whether to reuse the weights of a previous layer
by the same name.
@@ -1240,71 +1076,6 @@ class Conv3DReparameterization(_ConvReparameterization):
(`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias`
distributions.
- Arguments:
- filters: Integer, the dimensionality of the output space (i.e. the number
- of filters in the convolution).
- kernel_size: An integer or tuple/list of 3 integers, specifying the
- depth, height and width of the 3D convolution window.
- Can be a single integer to specify the same value for
- all spatial dimensions.
- strides: An integer or tuple/list of 3 integers,
- specifying the strides of the convolution along the depth,
- height and width.
- Can be a single integer to specify the same value for
- all spatial dimensions.
- Specifying any stride value != 1 is incompatible with specifying
- any `dilation_rate` value != 1.
- padding: One of `"valid"` or `"same"` (case-insensitive).
- data_format: A string, one of `channels_last` (default) or `channels_first`.
- The ordering of the dimensions in the inputs.
- `channels_last` corresponds to inputs with shape
- `(batch, depth, height, width, channels)` while `channels_first`
- corresponds to inputs with shape
- `(batch, channels, depth, height, width)`.
- dilation_rate: An integer or tuple/list of 3 integers, specifying
- the dilation rate to use for dilated convolution.
- Can be a single integer to specify the same value for
- all spatial dimensions.
- Currently, specifying any `dilation_rate` value != 1 is
- incompatible with specifying any stride value != 1.
- activation: Activation function. Set it to None to maintain a
- linear activation.
- activity_regularizer: Optional regularizer function for the output.
- trainable: Boolean, if `True` also add variables to the graph collection
- `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
- kernel_posterior_fn: Python `callable` which creates
- `tf.distributions.Distribution` instance representing the surrogate
- posterior of the `kernel` parameter. Default value:
- `default_mean_field_normal_fn()`.
- kernel_posterior_tensor_fn: Python `callable` which takes a
- `tf.distributions.Distribution` instance and returns a representative
- value. Default value: `lambda d: d.sample()`.
- kernel_prior_fn: Python `callable` which creates `tf.distributions`
- instance. See `default_mean_field_normal_fn` docstring for required
- parameter signature.
- Default value: `tf.distributions.Normal(loc=0., scale=1.)`.
- kernel_divergence_fn: Python `callable` which takes the surrogate posterior
- distribution, prior distribution and random variate sample(s) from the
- surrogate posterior and computes or approximates the KL divergence. The
- distributions are `tf.distributions.Distribution`-like instances and the
- sample is a `Tensor`.
- bias_posterior_fn: Python `callable` which creates
- `tf.distributions.Distribution` instance representing the surrogate
- posterior of the `bias` parameter. Default value:
- `default_mean_field_normal_fn(is_singular=True)` (which creates an
- instance of `tf.distributions.Deterministic`).
- bias_posterior_tensor_fn: Python `callable` which takes a
- `tf.distributions.Distribution` instance and returns a representative
- value. Default value: `lambda d: d.sample()`.
- bias_prior_fn: Python `callable` which creates `tf.distributions` instance.
- See `default_mean_field_normal_fn` docstring for required parameter
- signature. Default value: `None` (no prior, no variational inference)
- bias_divergence_fn: Python `callable` which takes the surrogate posterior
- distribution, prior distribution and random variate sample(s) from the
- surrogate posterior and computes or approximates the KL divergence. The
- distributions are `tf.distributions.Distribution`-like instances and the
- name: A string, the name of the layer.
-
Properties:
filters: Python integer, dimensionality of the output space.
kernel_size: Size of the convolution window.
@@ -1361,6 +1132,7 @@ class Conv3DReparameterization(_ConvReparameterization):
International Conference on Learning Representations, 2014.
"""
+ @docstring_util.expand_docstring(args=doc_args)
def __init__(
self,
filters,
@@ -1383,6 +1155,36 @@ class Conv3DReparameterization(_ConvReparameterization):
bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
name=None,
**kwargs):
+ """Construct layer.
+
+ Args:
+ filters: Integer, the dimensionality of the output space (i.e. the number
+ of filters in the convolution).
+ kernel_size: An integer or tuple/list of 3 integers, specifying the
+ depth, height and width of the 3D convolution window.
+ Can be a single integer to specify the same value for
+ all spatial dimensions.
+ strides: An integer or tuple/list of 3 integers,
+ specifying the strides of the convolution along the depth,
+ height and width.
+ Can be a single integer to specify the same value for
+ all spatial dimensions.
+ Specifying any stride value != 1 is incompatible with specifying
+ any `dilation_rate` value != 1.
+ padding: One of `"valid"` or `"same"` (case-insensitive).
+ data_format: A string, one of `channels_last` (default) or
+ `channels_first`. The ordering of the dimensions in the inputs.
+ `channels_last` corresponds to inputs with shape `(batch, depth,
+ height, width, channels)` while `channels_first` corresponds to inputs
+ with shape `(batch, channels, depth, height, width)`.
+ dilation_rate: An integer or tuple/list of 3 integers, specifying
+ the dilation rate to use for dilated convolution.
+ Can be a single integer to specify the same value for
+ all spatial dimensions.
+ Currently, specifying any `dilation_rate` value != 1 is
+ incompatible with specifying any stride value != 1.
+ @{args}
+ """
super(Conv3DReparameterization, self).__init__(
rank=3,
filters=filters,
@@ -1405,6 +1207,7 @@ class Conv3DReparameterization(_ConvReparameterization):
name=name, **kwargs)
+@docstring_util.expand_docstring(args=doc_args)
def conv3d_reparameterization(
inputs,
filters,
@@ -1448,7 +1251,7 @@ def conv3d_reparameterization(
(`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias`
distributions.
- Arguments:
+ Args:
inputs: Tensor input.
filters: Integer, the dimensionality of the output space (i.e. the number
of filters in the convolution).
@@ -1476,43 +1279,7 @@ def conv3d_reparameterization(
all spatial dimensions.
Currently, specifying any `dilation_rate` value != 1 is
incompatible with specifying any stride value != 1.
- activation: Activation function. Set it to None to maintain a
- linear activation.
- activity_regularizer: Optional regularizer function for the output.
- trainable: Boolean, if `True` also add variables to the graph collection
- `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
- kernel_posterior_fn: Python `callable` which creates
- `tf.distributions.Distribution` instance representing the surrogate
- posterior of the `kernel` parameter. Default value:
- `default_mean_field_normal_fn()`.
- kernel_posterior_tensor_fn: Python `callable` which takes a
- `tf.distributions.Distribution` instance and returns a representative
- value. Default value: `lambda d: d.sample()`.
- kernel_prior_fn: Python `callable` which creates `tf.distributions`
- instance. See `default_mean_field_normal_fn` docstring for required
- parameter signature.
- Default value: `tf.distributions.Normal(loc=0., scale=1.)`.
- kernel_divergence_fn: Python `callable` which takes the surrogate posterior
- distribution, prior distribution and random variate sample(s) from the
- surrogate posterior and computes or approximates the KL divergence. The
- distributions are `tf.distributions.Distribution`-like instances and the
- sample is a `Tensor`.
- bias_posterior_fn: Python `callable` which creates
- `tf.distributions.Distribution` instance representing the surrogate
- posterior of the `bias` parameter. Default value:
- `default_mean_field_normal_fn(is_singular=True)` (which creates an
- instance of `tf.distributions.Deterministic`).
- bias_posterior_tensor_fn: Python `callable` which takes a
- `tf.distributions.Distribution` instance and returns a representative
- value. Default value: `lambda d: d.sample()`.
- bias_prior_fn: Python `callable` which creates `tf.distributions` instance.
- See `default_mean_field_normal_fn` docstring for required parameter
- signature. Default value: `None` (no prior, no variational inference)
- bias_divergence_fn: Python `callable` which takes the surrogate posterior
- distribution, prior distribution and random variate sample(s) from the
- surrogate posterior and computes or approximates the KL divergence. The
- distributions are `tf.distributions.Distribution`-like instances and the
- name: A string, the name of the layer.
+ @{args}
reuse: Boolean, whether to reuse the weights of a previous layer
by the same name.
@@ -1611,67 +1378,6 @@ class _ConvFlipout(_ConvVariational):
(`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias`
distributions.
- Arguments:
- rank: An integer, the rank of the convolution, e.g. "2" for 2D convolution.
- filters: Integer, the dimensionality of the output space (i.e. the number
- of filters in the convolution).
- kernel_size: An integer or tuple/list of n integers, specifying the
- length of the convolution window.
- strides: An integer or tuple/list of n integers,
- specifying the stride length of the convolution.
- Specifying any stride value != 1 is incompatible with specifying
- any `dilation_rate` value != 1.
- padding: One of `"valid"` or `"same"` (case-insensitive).
- data_format: A string, one of `channels_last` (default) or `channels_first`.
- The ordering of the dimensions in the inputs.
- `channels_last` corresponds to inputs with shape
- `(batch, ..., channels)` while `channels_first` corresponds to
- inputs with shape `(batch, channels, ...)`.
- dilation_rate: An integer or tuple/list of n integers, specifying
- the dilation rate to use for dilated convolution.
- Currently, specifying any `dilation_rate` value != 1 is
- incompatible with specifying any `strides` value != 1.
- activation: Activation function. Set it to None to maintain a
- linear activation.
- activity_regularizer: Optional regularizer function for the output.
- trainable: Boolean, if `True` also add variables to the graph collection
- `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
- kernel_posterior_fn: Python `callable` which creates
- `tf.distributions.Distribution` instance representing the surrogate
- posterior of the `kernel` parameter. Default value:
- `default_mean_field_normal_fn()`.
- kernel_posterior_tensor_fn: Python `callable` which takes a
- `tf.distributions.Distribution` instance and returns a representative
- value. Default value: `lambda d: d.sample()`.
- kernel_prior_fn: Python `callable` which creates `tf.distributions`
- instance. See `default_mean_field_normal_fn` docstring for required
- parameter signature.
- Default value: `tf.distributions.Normal(loc=0., scale=1.)`.
- kernel_divergence_fn: Python `callable` which takes the surrogate posterior
- distribution, prior distribution and random variate sample(s) from the
- surrogate posterior and computes or approximates the KL divergence. The
- distributions are `tf.distributions.Distribution`-like instances and the
- sample is a `Tensor`.
- bias_posterior_fn: Python `callable` which creates
- `tf.distributions.Distribution` instance representing the surrogate
- posterior of the `bias` parameter. Default value:
- `default_mean_field_normal_fn(is_singular=True)` (which creates an
- instance of `tf.distributions.Deterministic`).
- bias_posterior_tensor_fn: Python `callable` which takes a
- `tf.distributions.Distribution` instance and returns a representative
- value. Default value: `lambda d: d.sample()`.
- bias_prior_fn: Python `callable` which creates `tf.distributions` instance.
- See `default_mean_field_normal_fn` docstring for required parameter
- signature. Default value: `None` (no prior, no variational inference)
- bias_divergence_fn: Python `callable` which takes the surrogate posterior
- distribution, prior distribution and random variate sample(s) from the
- surrogate posterior and computes or approximates the KL divergence. The
- distributions are `tf.distributions.Distribution`-like instances and the
- sample is a `Tensor`.
- seed: Python scalar `int` which initializes the random number
- generator. Default value: `None` (i.e., use global seed).
- name: A string, the name of the layer.
-
Properties:
rank: Python integer, dimensionality of convolution.
filters: Python integer, dimensionality of the output space.
@@ -1694,10 +1400,11 @@ class _ConvFlipout(_ConvVariational):
[1]: "Flipout: Efficient Pseudo-Independent Weight Perturbations on
Mini-Batches."
- Anonymous. OpenReview, 2017.
- https://openreview.net/forum?id=rJnpifWAb
+ Yeming Wen, Paul Vicol, Jimmy Ba, Dustin Tran, Roger Grosse.
+ International Conference on Learning Representations, 2018.
"""
+ @docstring_util.expand_docstring(args=doc_args)
def __init__(
self,
rank,
@@ -1722,6 +1429,31 @@ class _ConvFlipout(_ConvVariational):
seed=None,
name=None,
**kwargs):
+ """Construct layer.
+
+ Args:
+ rank: An integer, the rank of the convolution, e.g. "2" for 2D
+ convolution.
+ filters: Integer, the dimensionality of the output space (i.e. the number
+ of filters in the convolution).
+ kernel_size: An integer or tuple/list of n integers, specifying the
+ length of the convolution window.
+ strides: An integer or tuple/list of n integers,
+ specifying the stride length of the convolution.
+ Specifying any stride value != 1 is incompatible with specifying
+ any `dilation_rate` value != 1.
+ padding: One of `"valid"` or `"same"` (case-insensitive).
+ data_format: A string, one of `channels_last` (default) or
+ `channels_first`. The ordering of the dimensions in the inputs.
+ `channels_last` corresponds to inputs with shape `(batch, ...,
+ channels)` while `channels_first` corresponds to inputs with shape
+ `(batch, channels, ...)`.
+ dilation_rate: An integer or tuple/list of n integers, specifying
+ the dilation rate to use for dilated convolution.
+ Currently, specifying any `dilation_rate` value != 1 is
+ incompatible with specifying any `strides` value != 1.
+ @{args}
+ """
super(_ConvFlipout, self).__init__(
rank=rank,
filters=filters,
@@ -1822,65 +1554,6 @@ class Conv1DFlipout(_ConvFlipout):
(`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias`
distributions.
- Arguments:
- filters: Integer, the dimensionality of the output space (i.e. the number
- of filters in the convolution).
- kernel_size: An integer or tuple/list of a single integer, specifying the
- length of the 1D convolution window.
- strides: An integer or tuple/list of a single integer,
- specifying the stride length of the convolution.
- Specifying any stride value != 1 is incompatible with specifying
- any `dilation_rate` value != 1.
- padding: One of `"valid"` or `"same"` (case-insensitive).
- data_format: A string, one of `channels_last` (default) or `channels_first`.
- The ordering of the dimensions in the inputs.
- `channels_last` corresponds to inputs with shape
- `(batch, length, channels)` while `channels_first` corresponds to
- inputs with shape `(batch, channels, length)`.
- dilation_rate: An integer or tuple/list of a single integer, specifying
- the dilation rate to use for dilated convolution.
- Currently, specifying any `dilation_rate` value != 1 is
- incompatible with specifying any `strides` value != 1.
- activation: Activation function. Set it to None to maintain a
- linear activation.
- activity_regularizer: Optional regularizer function for the output.
- trainable: Boolean, if `True` also add variables to the graph collection
- `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
- kernel_posterior_fn: Python `callable` which creates
- `tf.distributions.Distribution` instance representing the surrogate
- posterior of the `kernel` parameter. Default value:
- `default_mean_field_normal_fn()`.
- kernel_posterior_tensor_fn: Python `callable` which takes a
- `tf.distributions.Distribution` instance and returns a representative
- value. Default value: `lambda d: d.sample()`.
- kernel_prior_fn: Python `callable` which creates `tf.distributions`
- instance. See `default_mean_field_normal_fn` docstring for required
- parameter signature.
- Default value: `tf.distributions.Normal(loc=0., scale=1.)`.
- kernel_divergence_fn: Python `callable` which takes the surrogate posterior
- distribution, prior distribution and random variate sample(s) from the
- surrogate posterior and computes or approximates the KL divergence. The
- distributions are `tf.distributions.Distribution`-like instances and the
- sample is a `Tensor`.
- bias_posterior_fn: Python `callable` which creates
- `tf.distributions.Distribution` instance representing the surrogate
- posterior of the `bias` parameter. Default value:
- `default_mean_field_normal_fn(is_singular=True)` (which creates an
- instance of `tf.distributions.Deterministic`).
- bias_posterior_tensor_fn: Python `callable` which takes a
- `tf.distributions.Distribution` instance and returns a representative
- value. Default value: `lambda d: d.sample()`.
- bias_prior_fn: Python `callable` which creates `tf.distributions` instance.
- See `default_mean_field_normal_fn` docstring for required parameter
- signature. Default value: `None` (no prior, no variational inference)
- bias_divergence_fn: Python `callable` which takes the surrogate posterior
- distribution, prior distribution and random variate sample(s) from the
- surrogate posterior and computes or approximates the KL divergence. The
- distributions are `tf.distributions.Distribution`-like instances and the
- seed: Python scalar `int` which initializes the random number
- generator. Default value: `None` (i.e., use global seed).
- name: A string, the name of the layer.
-
Properties:
filters: Python integer, dimensionality of the output space.
kernel_size: Size of the convolution window.
@@ -1932,10 +1605,11 @@ class Conv1DFlipout(_ConvFlipout):
[1]: "Flipout: Efficient Pseudo-Independent Weight Perturbations on
Mini-Batches."
- Anonymous. OpenReview, 2017.
- https://openreview.net/forum?id=rJnpifWAb
+ Yeming Wen, Paul Vicol, Jimmy Ba, Dustin Tran, Roger Grosse.
+ International Conference on Learning Representations, 2018.
"""
+ @docstring_util.expand_docstring(args=doc_args)
def __init__(
self,
filters,
@@ -1959,6 +1633,29 @@ class Conv1DFlipout(_ConvFlipout):
seed=None,
name=None,
**kwargs):
+ """Construct layer.
+
+ Args:
+ filters: Integer, the dimensionality of the output space (i.e. the number
+ of filters in the convolution).
+ kernel_size: An integer or tuple/list of a single integer, specifying the
+ length of the 1D convolution window.
+ strides: An integer or tuple/list of a single integer,
+ specifying the stride length of the convolution.
+ Specifying any stride value != 1 is incompatible with specifying
+ any `dilation_rate` value != 1.
+ padding: One of `"valid"` or `"same"` (case-insensitive).
+ data_format: A string, one of `channels_last` (default) or
+ `channels_first`. The ordering of the dimensions in the inputs.
+ `channels_last` corresponds to inputs with shape `(batch, length,
+ channels)` while `channels_first` corresponds to inputs with shape
+ `(batch, channels, length)`.
+ dilation_rate: An integer or tuple/list of a single integer, specifying
+ the dilation rate to use for dilated convolution.
+ Currently, specifying any `dilation_rate` value != 1 is
+ incompatible with specifying any `strides` value != 1.
+ @{args}
+ """
super(Conv1DFlipout, self).__init__(
rank=1,
filters=filters,
@@ -1982,6 +1679,7 @@ class Conv1DFlipout(_ConvFlipout):
name=name, **kwargs)
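As a point of reference, a layer built from this constructor might be exercised as in the following sketch. The input shape, filter count, and the expectation that the registered KL terms surface through `layer.losses` are illustrative assumptions, not part of this change.

import tensorflow as tf
from tensorflow.contrib.bayesflow.python.ops import layers_conv_variational as conv_variational

# A batch of 32 univariate sequences of length 128 (channels_last layout).
sequences = tf.random_normal([32, 128, 1])

# Posteriors and priors fall back to the defaults described above; the kernel
# is resampled (via Flipout perturbations) on every forward pass.
layer = conv_variational.Conv1DFlipout(
    filters=16, kernel_size=5, activation=tf.nn.relu)
outputs = layer.apply(sequences)  # [32, 124, 16] with the default "valid" padding

# KL terms the layer registers should be retrievable from `layer.losses`.
kl_penalty = tf.add_n(layer.losses)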
+@docstring_util.expand_docstring(args=doc_args)
def conv1d_flipout(
inputs,
filters,
@@ -2029,7 +1727,7 @@ def conv1d_flipout(
(`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias`
distributions.
- Arguments:
+ Args:
inputs: Tensor input.
filters: Integer, the dimensionality of the output space (i.e. the number
of filters in the convolution).
@@ -2049,45 +1747,7 @@ def conv1d_flipout(
the dilation rate to use for dilated convolution.
Currently, specifying any `dilation_rate` value != 1 is
incompatible with specifying any `strides` value != 1.
- activation: Activation function. Set it to None to maintain a
- linear activation.
- activity_regularizer: Optional regularizer function for the output.
- trainable: Boolean, if `True` also add variables to the graph collection
- `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
- kernel_posterior_fn: Python `callable` which creates
- `tf.distributions.Distribution` instance representing the surrogate
- posterior of the `kernel` parameter. Default value:
- `default_mean_field_normal_fn()`.
- kernel_posterior_tensor_fn: Python `callable` which takes a
- `tf.distributions.Distribution` instance and returns a representative
- value. Default value: `lambda d: d.sample()`.
- kernel_prior_fn: Python `callable` which creates `tf.distributions`
- instance. See `default_mean_field_normal_fn` docstring for required
- parameter signature.
- Default value: `tf.distributions.Normal(loc=0., scale=1.)`.
- kernel_divergence_fn: Python `callable` which takes the surrogate posterior
- distribution, prior distribution and random variate sample(s) from the
- surrogate posterior and computes or approximates the KL divergence. The
- distributions are `tf.distributions.Distribution`-like instances and the
- sample is a `Tensor`.
- bias_posterior_fn: Python `callable` which creates
- `tf.distributions.Distribution` instance representing the surrogate
- posterior of the `bias` parameter. Default value:
- `default_mean_field_normal_fn(is_singular=True)` (which creates an
- instance of `tf.distributions.Deterministic`).
- bias_posterior_tensor_fn: Python `callable` which takes a
- `tf.distributions.Distribution` instance and returns a representative
- value. Default value: `lambda d: d.sample()`.
- bias_prior_fn: Python `callable` which creates `tf.distributions` instance.
- See `default_mean_field_normal_fn` docstring for required parameter
- signature. Default value: `None` (no prior, no variational inference)
- bias_divergence_fn: Python `callable` which takes the surrogate posterior
- distribution, prior distribution and random variate sample(s) from the
- surrogate posterior and computes or approximates the KL divergence. The
- distributions are `tf.distributions.Distribution`-like instances and the
- seed: Python scalar `int` which initializes the random number
- generator. Default value: `None` (i.e., use global seed).
- name: A string, the name of the layer.
+ @{args}
reuse: Boolean, whether to reuse the weights of a previous layer
by the same name.
@@ -2130,8 +1790,8 @@ def conv1d_flipout(
[1]: "Flipout: Efficient Pseudo-Independent Weight Perturbations on
Mini-Batches."
- Anonymous. OpenReview, 2017.
- https://openreview.net/forum?id=rJnpifWAb
+ Yeming Wen, Paul Vicol, Jimmy Ba, Dustin Tran, Roger Grosse.
+ International Conference on Learning Representations, 2018.
"""
layer = Conv1DFlipout(
filters=filters,
@@ -2184,72 +1844,6 @@ class Conv2DFlipout(_ConvFlipout):
(`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias`
distributions.
- Arguments:
- filters: Integer, the dimensionality of the output space (i.e. the number
- of filters in the convolution).
- kernel_size: An integer or tuple/list of 2 integers, specifying the
- height and width of the 2D convolution window.
- Can be a single integer to specify the same value for
- all spatial dimensions.
- strides: An integer or tuple/list of 2 integers,
- specifying the strides of the convolution along the height and width.
- Can be a single integer to specify the same value for
- all spatial dimensions.
- Specifying any stride value != 1 is incompatible with specifying
- any `dilation_rate` value != 1.
- padding: One of `"valid"` or `"same"` (case-insensitive).
- data_format: A string, one of `channels_last` (default) or `channels_first`.
- The ordering of the dimensions in the inputs.
- `channels_last` corresponds to inputs with shape
- `(batch, height, width, channels)` while `channels_first` corresponds to
- inputs with shape `(batch, channels, height, width)`.
-
- dilation_rate: An integer or tuple/list of 2 integers, specifying
- the dilation rate to use for dilated convolution.
- Can be a single integer to specify the same value for
- all spatial dimensions.
- Currently, specifying any `dilation_rate` value != 1 is
- incompatible with specifying any stride value != 1.
- activation: Activation function. Set it to None to maintain a
- linear activation.
- activity_regularizer: Optional regularizer function for the output.
- trainable: Boolean, if `True` also add variables to the graph collection
- `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
- kernel_posterior_fn: Python `callable` which creates
- `tf.distributions.Distribution` instance representing the surrogate
- posterior of the `kernel` parameter. Default value:
- `default_mean_field_normal_fn()`.
- kernel_posterior_tensor_fn: Python `callable` which takes a
- `tf.distributions.Distribution` instance and returns a representative
- value. Default value: `lambda d: d.sample()`.
- kernel_prior_fn: Python `callable` which creates `tf.distributions`
- instance. See `default_mean_field_normal_fn` docstring for required
- parameter signature.
- Default value: `tf.distributions.Normal(loc=0., scale=1.)`.
- kernel_divergence_fn: Python `callable` which takes the surrogate posterior
- distribution, prior distribution and random variate sample(s) from the
- surrogate posterior and computes or approximates the KL divergence. The
- distributions are `tf.distributions.Distribution`-like instances and the
- sample is a `Tensor`.
- bias_posterior_fn: Python `callable` which creates
- `tf.distributions.Distribution` instance representing the surrogate
- posterior of the `bias` parameter. Default value:
- `default_mean_field_normal_fn(is_singular=True)` (which creates an
- instance of `tf.distributions.Deterministic`).
- bias_posterior_tensor_fn: Python `callable` which takes a
- `tf.distributions.Distribution` instance and returns a representative
- value. Default value: `lambda d: d.sample()`.
- bias_prior_fn: Python `callable` which creates `tf.distributions` instance.
- See `default_mean_field_normal_fn` docstring for required parameter
- signature. Default value: `None` (no prior, no variational inference)
- bias_divergence_fn: Python `callable` which takes the surrogate posterior
- distribution, prior distribution and random variate sample(s) from the
- surrogate posterior and computes or approximates the KL divergence. The
- distributions are `tf.distributions.Distribution`-like instances and the
- seed: Python scalar `int` which initializes the random number
- generator. Default value: `None` (i.e., use global seed).
- name: A string, the name of the layer.
-
Properties:
filters: Python integer, dimensionality of the output space.
kernel_size: Size of the convolution window.
@@ -2304,10 +1898,11 @@ class Conv2DFlipout(_ConvFlipout):
[1]: "Flipout: Efficient Pseudo-Independent Weight Perturbations on
Mini-Batches."
- Anonymous. OpenReview, 2017.
- https://openreview.net/forum?id=rJnpifWAb
+ Yeming Wen, Paul Vicol, Jimmy Ba, Dustin Tran, Roger Grosse.
+ International Conference on Learning Representations, 2018.
"""
+ @docstring_util.expand_docstring(args=doc_args)
def __init__(
self,
filters,
@@ -2331,6 +1926,35 @@ class Conv2DFlipout(_ConvFlipout):
seed=None,
name=None,
**kwargs):
+ """Construct layer.
+
+ Args:
+ filters: Integer, the dimensionality of the output space (i.e. the number
+ of filters in the convolution).
+ kernel_size: An integer or tuple/list of 2 integers, specifying the
+ height and width of the 2D convolution window.
+ Can be a single integer to specify the same value for
+ all spatial dimensions.
+ strides: An integer or tuple/list of 2 integers,
+ specifying the strides of the convolution along the height and width.
+ Can be a single integer to specify the same value for
+ all spatial dimensions.
+ Specifying any stride value != 1 is incompatible with specifying
+ any `dilation_rate` value != 1.
+ padding: One of `"valid"` or `"same"` (case-insensitive).
+ data_format: A string, one of `channels_last` (default) or
+ `channels_first`. The ordering of the dimensions in the inputs.
+ `channels_last` corresponds to inputs with shape `(batch, height,
+ width, channels)` while `channels_first` corresponds to inputs with
+ shape `(batch, channels, height, width)`.
+ dilation_rate: An integer or tuple/list of 2 integers, specifying
+ the dilation rate to use for dilated convolution.
+ Can be a single integer to specify the same value for
+ all spatial dimensions.
+ Currently, specifying any `dilation_rate` value != 1 is
+ incompatible with specifying any stride value != 1.
+ @{args}
+ """
super(Conv2DFlipout, self).__init__(
rank=2,
filters=filters,
@@ -2354,6 +1978,7 @@ class Conv2DFlipout(_ConvFlipout):
name=name, **kwargs)
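The functional wrapper defined below can be dropped into an ordinary graph-building pipeline. The image shape and hyperparameters in this sketch are assumed for illustration only and do not come from this diff.

import tensorflow as tf
from tensorflow.contrib.bayesflow.python.ops import layers_conv_variational as conv_variational

# A batch of 28x28 grayscale images, channels_last.
images = tf.random_normal([32, 28, 28, 1])
net = conv_variational.conv2d_flipout(
    images, filters=32, kernel_size=3, padding="same", activation=tf.nn.relu)
net = tf.layers.max_pooling2d(net, pool_size=2, strides=2)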
+@docstring_util.expand_docstring(args=doc_args)
def conv2d_flipout(
inputs,
filters,
@@ -2401,7 +2026,7 @@ def conv2d_flipout(
(`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias`
distributions.
- Arguments:
+ Args:
inputs: Tensor input.
filters: Integer, the dimensionality of the output space (i.e. the number
of filters in the convolution).
@@ -2421,52 +2046,13 @@ def conv2d_flipout(
`channels_last` corresponds to inputs with shape
`(batch, height, width, channels)` while `channels_first` corresponds to
inputs with shape `(batch, channels, height, width)`.
-
dilation_rate: An integer or tuple/list of 2 integers, specifying
the dilation rate to use for dilated convolution.
Can be a single integer to specify the same value for
all spatial dimensions.
Currently, specifying any `dilation_rate` value != 1 is
incompatible with specifying any stride value != 1.
- activation: Activation function. Set it to None to maintain a
- linear activation.
- activity_regularizer: Optional regularizer function for the output.
- trainable: Boolean, if `True` also add variables to the graph collection
- `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
- kernel_posterior_fn: Python `callable` which creates
- `tf.distributions.Distribution` instance representing the surrogate
- posterior of the `kernel` parameter. Default value:
- `default_mean_field_normal_fn()`.
- kernel_posterior_tensor_fn: Python `callable` which takes a
- `tf.distributions.Distribution` instance and returns a representative
- value. Default value: `lambda d: d.sample()`.
- kernel_prior_fn: Python `callable` which creates `tf.distributions`
- instance. See `default_mean_field_normal_fn` docstring for required
- parameter signature.
- Default value: `tf.distributions.Normal(loc=0., scale=1.)`.
- kernel_divergence_fn: Python `callable` which takes the surrogate posterior
- distribution, prior distribution and random variate sample(s) from the
- surrogate posterior and computes or approximates the KL divergence. The
- distributions are `tf.distributions.Distribution`-like instances and the
- sample is a `Tensor`.
- bias_posterior_fn: Python `callable` which creates
- `tf.distributions.Distribution` instance representing the surrogate
- posterior of the `bias` parameter. Default value:
- `default_mean_field_normal_fn(is_singular=True)` (which creates an
- instance of `tf.distributions.Deterministic`).
- bias_posterior_tensor_fn: Python `callable` which takes a
- `tf.distributions.Distribution` instance and returns a representative
- value. Default value: `lambda d: d.sample()`.
- bias_prior_fn: Python `callable` which creates `tf.distributions` instance.
- See `default_mean_field_normal_fn` docstring for required parameter
- signature. Default value: `None` (no prior, no variational inference)
- bias_divergence_fn: Python `callable` which takes the surrogate posterior
- distribution, prior distribution and random variate sample(s) from the
- surrogate posterior and computes or approximates the KL divergence. The
- distributions are `tf.distributions.Distribution`-like instances and the
- seed: Python scalar `int` which initializes the random number
- generator. Default value: `None` (i.e., use global seed).
- name: A string, the name of the layer.
+ @{args}
reuse: Boolean, whether to reuse the weights of a previous layer
by the same name.
@@ -2513,8 +2099,8 @@ def conv2d_flipout(
[1]: "Flipout: Efficient Pseudo-Independent Weight Perturbations on
Mini-Batches."
- Anonymous. OpenReview, 2017.
- https://openreview.net/forum?id=rJnpifWAb
+ Yeming Wen, Paul Vicol, Jimmy Ba, Dustin Tran, Roger Grosse.
+ International Conference on Learning Representations, 2018.
"""
layer = Conv2DFlipout(
filters=filters,
@@ -2567,73 +2153,6 @@ class Conv3DFlipout(_ConvFlipout):
(`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias`
distributions.
- Arguments:
- filters: Integer, the dimensionality of the output space (i.e. the number
- of filters in the convolution).
- kernel_size: An integer or tuple/list of 3 integers, specifying the
- depth, height and width of the 3D convolution window.
- Can be a single integer to specify the same value for
- all spatial dimensions.
- strides: An integer or tuple/list of 3 integers,
- specifying the strides of the convolution along the depth,
- height and width.
- Can be a single integer to specify the same value for
- all spatial dimensions.
- Specifying any stride value != 1 is incompatible with specifying
- any `dilation_rate` value != 1.
- padding: One of `"valid"` or `"same"` (case-insensitive).
- data_format: A string, one of `channels_last` (default) or `channels_first`.
- The ordering of the dimensions in the inputs.
- `channels_last` corresponds to inputs with shape
- `(batch, depth, height, width, channels)` while `channels_first`
- corresponds to inputs with shape
- `(batch, channels, depth, height, width)`.
- dilation_rate: An integer or tuple/list of 3 integers, specifying
- the dilation rate to use for dilated convolution.
- Can be a single integer to specify the same value for
- all spatial dimensions.
- Currently, specifying any `dilation_rate` value != 1 is
- incompatible with specifying any stride value != 1.
- activation: Activation function. Set it to None to maintain a
- linear activation.
- activity_regularizer: Optional regularizer function for the output.
- trainable: Boolean, if `True` also add variables to the graph collection
- `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
- kernel_posterior_fn: Python `callable` which creates
- `tf.distributions.Distribution` instance representing the surrogate
- posterior of the `kernel` parameter. Default value:
- `default_mean_field_normal_fn()`.
- kernel_posterior_tensor_fn: Python `callable` which takes a
- `tf.distributions.Distribution` instance and returns a representative
- value. Default value: `lambda d: d.sample()`.
- kernel_prior_fn: Python `callable` which creates `tf.distributions`
- instance. See `default_mean_field_normal_fn` docstring for required
- parameter signature.
- Default value: `tf.distributions.Normal(loc=0., scale=1.)`.
- kernel_divergence_fn: Python `callable` which takes the surrogate posterior
- distribution, prior distribution and random variate sample(s) from the
- surrogate posterior and computes or approximates the KL divergence. The
- distributions are `tf.distributions.Distribution`-like instances and the
- sample is a `Tensor`.
- bias_posterior_fn: Python `callable` which creates
- `tf.distributions.Distribution` instance representing the surrogate
- posterior of the `bias` parameter. Default value:
- `default_mean_field_normal_fn(is_singular=True)` (which creates an
- instance of `tf.distributions.Deterministic`).
- bias_posterior_tensor_fn: Python `callable` which takes a
- `tf.distributions.Distribution` instance and returns a representative
- value. Default value: `lambda d: d.sample()`.
- bias_prior_fn: Python `callable` which creates `tf.distributions` instance.
- See `default_mean_field_normal_fn` docstring for required parameter
- signature. Default value: `None` (no prior, no variational inference)
- bias_divergence_fn: Python `callable` which takes the surrogate posterior
- distribution, prior distribution and random variate sample(s) from the
- surrogate posterior and computes or approximates the KL divergence. The
- distributions are `tf.distributions.Distribution`-like instances and the
- seed: Python scalar `int` which initializes the random number
- generator. Default value: `None` (i.e., use global seed).
- name: A string, the name of the layer.
-
Properties:
filters: Python integer, dimensionality of the output space.
kernel_size: Size of the convolution window.
@@ -2688,10 +2207,11 @@ class Conv3DFlipout(_ConvFlipout):
[1]: "Flipout: Efficient Pseudo-Independent Weight Perturbations on
Mini-Batches."
- Anonymous. OpenReview, 2017.
- https://openreview.net/forum?id=rJnpifWAb
+ Yeming Wen, Paul Vicol, Jimmy Ba, Dustin Tran, Roger Grosse.
+ International Conference on Learning Representations, 2018.
"""
+ @docstring_util.expand_docstring(args=doc_args)
def __init__(
self,
filters,
@@ -2715,6 +2235,36 @@ class Conv3DFlipout(_ConvFlipout):
seed=None,
name=None,
**kwargs):
+ """Construct layer.
+
+ Args:
+ filters: Integer, the dimensionality of the output space (i.e. the number
+ of filters in the convolution).
+ kernel_size: An integer or tuple/list of 3 integers, specifying the
+ depth, height and width of the 3D convolution window.
+ Can be a single integer to specify the same value for
+ all spatial dimensions.
+ strides: An integer or tuple/list of 3 integers,
+ specifying the strides of the convolution along the depth,
+ height and width.
+ Can be a single integer to specify the same value for
+ all spatial dimensions.
+ Specifying any stride value != 1 is incompatible with specifying
+ any `dilation_rate` value != 1.
+ padding: One of `"valid"` or `"same"` (case-insensitive).
+ data_format: A string, one of `channels_last` (default) or
+ `channels_first`. The ordering of the dimensions in the inputs.
+ `channels_last` corresponds to inputs with shape `(batch, depth,
+ height, width, channels)` while `channels_first` corresponds to inputs
+ with shape `(batch, channels, depth, height, width)`.
+ dilation_rate: An integer or tuple/list of 3 integers, specifying
+ the dilation rate to use for dilated convolution.
+ Can be a single integer to specify the same value for
+ all spatial dimensions.
+ Currently, specifying any `dilation_rate` value != 1 is
+ incompatible with specifying any stride value != 1.
+ @{args}
+ """
super(Conv3DFlipout, self).__init__(
rank=3,
filters=filters,
@@ -2738,6 +2288,7 @@ class Conv3DFlipout(_ConvFlipout):
name=name, **kwargs)
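The Flipout estimator cited in these docstrings shares one sampled weight perturbation across a mini-batch and decorrelates it per example with random sign flips. A rough sketch of that idea for a dense layer (the convolutional layers apply the same trick to their kernels) is below; the shapes and variable names are made up for illustration.

import tensorflow as tf

batch, in_dim, units = 8, 4, 3
x = tf.random_normal([batch, in_dim])

w_loc = tf.get_variable("fo_w_loc", [in_dim, units])
w_scale = tf.nn.softplus(tf.get_variable("fo_w_rho", [in_dim, units]))

# One shared base perturbation for the whole mini-batch ...
delta_w = w_scale * tf.random_normal([in_dim, units])

# ... decorrelated across examples by random sign flips on inputs and outputs.
sign_in = tf.sign(tf.random_uniform([batch, in_dim], minval=-1., maxval=1.))
sign_out = tf.sign(tf.random_uniform([batch, units], minval=-1., maxval=1.))

perturbed = tf.matmul(x * sign_in, delta_w) * sign_out
outputs = tf.matmul(x, w_loc) + perturbed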
+@docstring_util.expand_docstring(args=doc_args)
def conv3d_flipout(
inputs,
filters,
@@ -2785,7 +2336,7 @@ def conv3d_flipout(
(`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias`
distributions.
- Arguments:
+ Args:
inputs: Tensor input.
filters: Integer, the dimensionality of the output space (i.e. the number
of filters in the convolution).
@@ -2813,45 +2364,7 @@ def conv3d_flipout(
all spatial dimensions.
Currently, specifying any `dilation_rate` value != 1 is
incompatible with specifying any stride value != 1.
- activation: Activation function. Set it to None to maintain a
- linear activation.
- activity_regularizer: Optional regularizer function for the output.
- trainable: Boolean, if `True` also add variables to the graph collection
- `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
- kernel_posterior_fn: Python `callable` which creates
- `tf.distributions.Distribution` instance representing the surrogate
- posterior of the `kernel` parameter. Default value:
- `default_mean_field_normal_fn()`.
- kernel_posterior_tensor_fn: Python `callable` which takes a
- `tf.distributions.Distribution` instance and returns a representative
- value. Default value: `lambda d: d.sample()`.
- kernel_prior_fn: Python `callable` which creates `tf.distributions`
- instance. See `default_mean_field_normal_fn` docstring for required
- parameter signature.
- Default value: `tf.distributions.Normal(loc=0., scale=1.)`.
- kernel_divergence_fn: Python `callable` which takes the surrogate posterior
- distribution, prior distribution and random variate sample(s) from the
- surrogate posterior and computes or approximates the KL divergence. The
- distributions are `tf.distributions.Distribution`-like instances and the
- sample is a `Tensor`.
- bias_posterior_fn: Python `callable` which creates
- `tf.distributions.Distribution` instance representing the surrogate
- posterior of the `bias` parameter. Default value:
- `default_mean_field_normal_fn(is_singular=True)` (which creates an
- instance of `tf.distributions.Deterministic`).
- bias_posterior_tensor_fn: Python `callable` which takes a
- `tf.distributions.Distribution` instance and returns a representative
- value. Default value: `lambda d: d.sample()`.
- bias_prior_fn: Python `callable` which creates `tf.distributions` instance.
- See `default_mean_field_normal_fn` docstring for required parameter
- signature. Default value: `None` (no prior, no variational inference)
- bias_divergence_fn: Python `callable` which takes the surrogate posterior
- distribution, prior distribution and random variate sample(s) from the
- surrogate posterior and computes or approximates the KL divergence. The
- distributions are `tf.distributions.Distribution`-like instances and the
- seed: Python scalar `int` which initializes the random number
- generator. Default value: `None` (i.e., use global seed).
- name: A string, the name of the layer.
+ @{args}
reuse: Boolean, whether to reuse the weights of a previous layer
by the same name.
@@ -2898,8 +2411,8 @@ def conv3d_flipout(
[1]: "Flipout: Efficient Pseudo-Independent Weight Perturbations on
Mini-Batches."
- Anonymous. OpenReview, 2017.
- https://openreview.net/forum?id=rJnpifWAb
+ Yeming Wen, Paul Vicol, Jimmy Ba, Dustin Tran, Roger Grosse.
+ International Conference on Learning Representations, 2018.
"""
layer = Conv3DFlipout(
filters=filters,
diff --git a/tensorflow/contrib/bayesflow/python/ops/layers_dense_variational.py b/tensorflow/contrib/bayesflow/python/ops/layers_dense_variational.py
index 591a8e553d..1e4a445a33 100644
--- a/tensorflow/contrib/bayesflow/python/ops/layers_dense_variational.py
+++ b/tensorflow/contrib/bayesflow/python/ops/layers_dense_variational.py
@@ -19,6 +19,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.contrib.bayesflow.python.ops import docstring_util
from tensorflow.contrib.bayesflow.python.ops import layers_util
from tensorflow.contrib.distributions.python.ops import independent as independent_lib
from tensorflow.python.framework import dtypes
@@ -33,6 +34,53 @@ from tensorflow.python.ops.distributions import normal as normal_lib
from tensorflow.python.ops.distributions import util as distribution_util
+doc_args = """ units: Integer or Long, dimensionality of the output space.
+ activation: Activation function (`callable`). Set it to None to maintain a
+ linear activation.
+ activity_regularizer: Regularizer function for the output.
+ trainable: Boolean, if `True` also add variables to the graph collection
+ `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
+ kernel_posterior_fn: Python `callable` which creates
+ `tf.distributions.Distribution` instance representing the surrogate
+ posterior of the `kernel` parameter. Default value:
+ `default_mean_field_normal_fn()`.
+ kernel_posterior_tensor_fn: Python `callable` which takes a
+ `tf.distributions.Distribution` instance and returns a representative
+ value. Default value: `lambda d: d.sample()`.
+ kernel_prior_fn: Python `callable` which creates `tf.distributions`
+ instance. See `default_mean_field_normal_fn` docstring for required
+ parameter signature.
+ Default value: `tf.distributions.Normal(loc=0., scale=1.)`.
+ kernel_divergence_fn: Python `callable` which takes the surrogate posterior
+ distribution, prior distribution and random variate sample(s) from the
+ surrogate posterior and computes or approximates the KL divergence. The
+ distributions are `tf.distributions.Distribution`-like instances and the
+ sample is a `Tensor`.
+ bias_posterior_fn: Python `callable` which creates
+ `tf.distributions.Distribution` instance representing the surrogate
+ posterior of the `bias` parameter. Default value:
+ `default_mean_field_normal_fn(is_singular=True)` (which creates an
+ instance of `tf.distributions.Deterministic`).
+ bias_posterior_tensor_fn: Python `callable` which takes a
+ `tf.distributions.Distribution` instance and returns a representative
+ value. Default value: `lambda d: d.sample()`.
+ bias_prior_fn: Python `callable` which creates `tf.distributions` instance.
+ See `default_mean_field_normal_fn` docstring for required parameter
+    signature. Default value: `None` (no prior, no variational inference).
+ bias_divergence_fn: Python `callable` which takes the surrogate posterior
+ distribution, prior distribution and random variate sample(s) from the
+ surrogate posterior and computes or approximates the KL divergence. The
+ distributions are `tf.distributions.Distribution`-like instances and the
+ sample is a `Tensor`.
+ seed: Python scalar `int` which initializes the random number
+ generator. Default value: `None` (i.e., use global seed).
+ name: Python `str`, the name of the layer. Layers with the same name will
+ share `tf.Variable`s, but to avoid mistakes we require `reuse=True` in
+ such cases.
+ reuse: Python `bool`, whether to reuse the `tf.Variable`s of a previous
+ layer by the same name."""
+
+
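The `@{args}` placeholders in the constructor docstrings below are filled in from this `doc_args` template by `docstring_util.expand_docstring`. The snippet that follows is only a rough sketch of that substitution idea, written independently of the `docstring_util.py` implementation this change adds; the decorator and example function here are hypothetical.

import re


def expand_docstring(**kwargs):
  """Toy decorator: splice each `kwargs[key]` into the @{key} markers of __doc__."""
  def decorator(fn):
    doc = fn.__doc__
    for key, value in kwargs.items():
      # A function replacement keeps backslashes in `value` from being
      # interpreted as regex escape sequences.
      doc = re.sub(r'@\{' + re.escape(key) + r'\}', lambda match: value, doc)
    fn.__doc__ = doc
    return fn
  return decorator


@expand_docstring(args="x: the input\n    y: an offset")
def shift(x, y):
  """Add an offset.

  Args:
    @{args}
  """
  return x + y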
class _DenseVariational(layers_lib.Layer):
"""Abstract densely-connected class (private, used as implementation base).
@@ -50,51 +98,6 @@ class _DenseVariational(layers_lib.Layer):
(`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias`
distributions.
- Args:
- units: Integer or Long, dimensionality of the output space.
- activation: Activation function (`callable`). Set it to None to maintain a
- linear activation.
- activity_regularizer: Regularizer function for the output.
- trainable: Boolean, if `True` also add variables to the graph collection
- `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
- kernel_posterior_fn: Python `callable` which creates
- `tf.distributions.Distribution` instance representing the surrogate
- posterior of the `kernel` parameter. Default value:
- `default_mean_field_normal_fn()`.
- kernel_posterior_tensor_fn: Python `callable` which takes a
- `tf.distributions.Distribution` instance and returns a representative
- value. Default value: `lambda d: d.sample()`.
- kernel_prior_fn: Python `callable` which creates `tf.distributions`
- instance. See `default_mean_field_normal_fn` docstring for required
- parameter signature.
- Default value: `tf.distributions.Normal(loc=0., scale=1.)`.
- kernel_divergence_fn: Python `callable` which takes the surrogate posterior
- distribution, prior distribution and random variate sample(s) from the
- surrogate posterior and computes or approximates the KL divergence. The
- distributions are `tf.distributions.Distribution`-like instances and the
- sample is a `Tensor`.
- bias_posterior_fn: Python `callable` which creates
- `tf.distributions.Distribution` instance representing the surrogate
- posterior of the `bias` parameter. Default value:
- `default_mean_field_normal_fn(is_singular=True)` (which creates an
- instance of `tf.distributions.Deterministic`).
- bias_posterior_tensor_fn: Python `callable` which takes a
- `tf.distributions.Distribution` instance and returns a representative
- value. Default value: `lambda d: d.sample()`.
- bias_prior_fn: Python `callable` which creates `tf.distributions` instance.
- See `default_mean_field_normal_fn` docstring for required parameter
- signature. Default value: `None` (no prior, no variational inference)
- bias_divergence_fn: Python `callable` which takes the surrogate posterior
- distribution, prior distribution and random variate sample(s) from the
- surrogate posterior and computes or approximates the KL divergence. The
- distributions are `tf.distributions.Distribution`-like instances and the
- sample is a `Tensor`.
- name: Python `str`, the name of the layer. Layers with the same name will
- share `tf.Variable`s, but to avoid mistakes we require `reuse=True` in
- such cases.
- reuse: Python `bool`, whether to reuse the `tf.Variable`s of a previous
- layer by the same name.
-
Properties:
units: Python integer, dimensionality of the output space.
activation: Activation function (`callable`).
@@ -109,6 +112,7 @@ class _DenseVariational(layers_lib.Layer):
bias_divergence_fn: `callable` returning divergence.
"""
+ @docstring_util.expand_docstring(args=doc_args)
def __init__(
self,
units,
@@ -126,6 +130,11 @@ class _DenseVariational(layers_lib.Layer):
bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
name=None,
**kwargs):
+ """Construct layer.
+
+ Args:
+ @{args}
+ """
super(_DenseVariational, self).__init__(
trainable=trainable,
name=name,
@@ -274,51 +283,6 @@ class DenseReparameterization(_DenseVariational):
(`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias`
distributions.
- Args:
- units: Integer or Long, dimensionality of the output space.
- activation: Activation function (`callable`). Set it to None to maintain a
- linear activation.
- activity_regularizer: Regularizer function for the output.
- trainable: Boolean, if `True` also add variables to the graph collection
- `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
- kernel_posterior_fn: Python `callable` which creates
- `tf.distributions.Distribution` instance representing the surrogate
- posterior of the `kernel` parameter. Default value:
- `default_mean_field_normal_fn()`.
- kernel_posterior_tensor_fn: Python `callable` which takes a
- `tf.distributions.Distribution` instance and returns a representative
- value. Default value: `lambda d: d.sample()`.
- kernel_prior_fn: Python `callable` which creates `tf.distributions`
- instance. See `default_mean_field_normal_fn` docstring for required
- parameter signature.
- Default value: `tf.distributions.Normal(loc=0., scale=1.)`.
- kernel_divergence_fn: Python `callable` which takes the surrogate posterior
- distribution, prior distribution and random variate sample(s) from the
- surrogate posterior and computes or approximates the KL divergence. The
- distributions are `tf.distributions.Distribution`-like instances and the
- sample is a `Tensor`.
- bias_posterior_fn: Python `callable` which creates
- `tf.distributions.Distribution` instance representing the surrogate
- posterior of the `bias` parameter. Default value:
- `default_mean_field_normal_fn(is_singular=True)` (which creates an
- instance of `tf.distributions.Deterministic`).
- bias_posterior_tensor_fn: Python `callable` which takes a
- `tf.distributions.Distribution` instance and returns a representative
- value. Default value: `lambda d: d.sample()`.
- bias_prior_fn: Python `callable` which creates `tf.distributions` instance.
- See `default_mean_field_normal_fn` docstring for required parameter
- signature. Default value: `None` (no prior, no variational inference)
- bias_divergence_fn: Python `callable` which takes the surrogate posterior
- distribution, prior distribution and random variate sample(s) from the
- surrogate posterior and computes or approximates the KL divergence. The
- distributions are `tf.distributions.Distribution`-like instances and the
- sample is a `Tensor`.
- name: Python `str`, the name of the layer. Layers with the same name will
- share `tf.Variable`s, but to avoid mistakes we require `reuse=True` in
- such cases.
- reuse: Python `bool`, whether to reuse the `tf.Variable`s of a previous
- layer by the same name.
-
Properties:
units: Python integer, dimensionality of the output space.
activation: Activation function (`callable`).
@@ -363,6 +327,7 @@ class DenseReparameterization(_DenseVariational):
International Conference on Learning Representations, 2014.
"""
+ @docstring_util.expand_docstring(args=doc_args)
def __init__(
self,
units,
@@ -381,6 +346,11 @@ class DenseReparameterization(_DenseVariational):
bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
name=None,
**kwargs):
+ """Construct layer.
+
+ Args:
+ @{args}
+ """
super(DenseReparameterization, self).__init__(
units=units,
activation=activation,
@@ -405,6 +375,7 @@ class DenseReparameterization(_DenseVariational):
return self._matmul(inputs, self.kernel_posterior_tensor)
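The one-liner above works because the sampled kernel is a reparameterized draw: a deterministic transform of the variational parameters and unit Gaussian noise, so gradients flow back into the location and scale. A self-contained sketch of that trick, with made-up shapes and variable names, independent of the layer machinery:

import tensorflow as tf

in_dim, units = 4, 3

# Variational parameters of a mean-field normal posterior over the kernel.
w_loc = tf.get_variable("w_loc", [in_dim, units])
w_scale = tf.nn.softplus(tf.get_variable("w_rho", [in_dim, units]))

# Reparameterized sample: differentiable with respect to w_loc and w_scale.
eps = tf.random_normal([in_dim, units])
w_sample = w_loc + w_scale * eps

x = tf.random_normal([8, in_dim])
outputs = tf.matmul(x, w_sample)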
+@docstring_util.expand_docstring(args=doc_args)
def dense_reparameterization(
inputs,
units,
@@ -444,49 +415,7 @@ def dense_reparameterization(
Args:
inputs: Tensor input.
- units: Integer or Long, dimensionality of the output space.
- activation: Activation function (`callable`). Set it to None to maintain a
- linear activation.
- activity_regularizer: Regularizer function for the output.
- trainable: Boolean, if `True` also add variables to the graph collection
- `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
- kernel_posterior_fn: Python `callable` which creates
- `tf.distributions.Distribution` instance representing the surrogate
- posterior of the `kernel` parameter. Default value:
- `default_mean_field_normal_fn()`.
- kernel_posterior_tensor_fn: Python `callable` which takes a
- `tf.distributions.Distribution` instance and returns a representative
- value. Default value: `lambda d: d.sample()`.
- kernel_prior_fn: Python `callable` which creates `tf.distributions`
- instance. See `default_mean_field_normal_fn` docstring for required
- parameter signature.
- Default value: `tf.distributions.Normal(loc=0., scale=1.)`.
- kernel_divergence_fn: Python `callable` which takes the surrogate posterior
- distribution, prior distribution and random variate sample(s) from the
- surrogate posterior and computes or approximates the KL divergence. The
- distributions are `tf.distributions.Distribution`-like instances and the
- sample is a `Tensor`.
- bias_posterior_fn: Python `callable` which creates
- `tf.distributions.Distribution` instance representing the surrogate
- posterior of the `bias` parameter. Default value:
- `default_mean_field_normal_fn(is_singular=True)` (which creates an
- instance of `tf.distributions.Deterministic`).
- bias_posterior_tensor_fn: Python `callable` which takes a
- `tf.distributions.Distribution` instance and returns a representative
- value. Default value: `lambda d: d.sample()`.
- bias_prior_fn: Python `callable` which creates `tf.distributions` instance.
- See `default_mean_field_normal_fn` docstring for required parameter
- signature. Default value: `None` (no prior, no variational inference)
- bias_divergence_fn: Python `callable` which takes the surrogate posterior
- distribution, prior distribution and random variate sample(s) from the
- surrogate posterior and computes or approximates the KL divergence. The
- distributions are `tf.distributions.Distribution`-like instances and the
- sample is a `Tensor`.
- name: Python `str`, the name of the layer. Layers with the same name will
- share `tf.Variable`s, but to avoid mistakes we require `reuse=True` in
- such cases.
- reuse: Python `bool`, whether to reuse the `tf.Variable`s of a previous
- layer by the same name.
+ @{args}
Returns:
    output: `Tensor` representing the affine transformed input under a random
@@ -563,51 +492,6 @@ class DenseLocalReparameterization(_DenseVariational):
(`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias`
distributions.
- Args:
- units: Integer or Long, dimensionality of the output space.
- activation: Activation function (`callable`). Set it to None to maintain a
- linear activation.
- activity_regularizer: Regularizer function for the output.
- trainable: Boolean, if `True` also add variables to the graph collection
- `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
- kernel_posterior_fn: Python `callable` which creates
- `tf.distributions.Distribution` instance representing the surrogate
- posterior of the `kernel` parameter. Default value:
- `default_mean_field_normal_fn()`.
- kernel_posterior_tensor_fn: Python `callable` which takes a
- `tf.distributions.Distribution` instance and returns a representative
- value. Default value: `lambda d: d.sample()`.
- kernel_prior_fn: Python `callable` which creates `tf.distributions`
- instance. See `default_mean_field_normal_fn` docstring for required
- parameter signature.
- Default value: `tf.distributions.Normal(loc=0., scale=1.)`.
- kernel_divergence_fn: Python `callable` which takes the surrogate posterior
- distribution, prior distribution and random variate sample(s) from the
- surrogate posterior and computes or approximates the KL divergence. The
- distributions are `tf.distributions.Distribution`-like instances and the
- sample is a `Tensor`.
- bias_posterior_fn: Python `callable` which creates
- `tf.distributions.Distribution` instance representing the surrogate
- posterior of the `bias` parameter. Default value:
- `default_mean_field_normal_fn(is_singular=True)` (which creates an
- instance of `tf.distributions.Deterministic`).
- bias_posterior_tensor_fn: Python `callable` which takes a
- `tf.distributions.Distribution` instance and returns a representative
- value. Default value: `lambda d: d.sample()`.
- bias_prior_fn: Python `callable` which creates `tf.distributions` instance.
- See `default_mean_field_normal_fn` docstring for required parameter
- signature. Default value: `None` (no prior, no variational inference)
- bias_divergence_fn: Python `callable` which takes the surrogate posterior
- distribution, prior distribution and random variate sample(s) from the
- surrogate posterior and computes or approximates the KL divergence. The
- distributions are `tf.distributions.Distribution`-like instances and the
- sample is a `Tensor`.
- name: Python `str`, the name of the layer. Layers with the same name will
- share `tf.Variable`s, but to avoid mistakes we require `reuse=True` in
- such cases.
- reuse: Python `bool`, whether to reuse the `tf.Variable`s of a previous
- layer by the same name.
-
Properties:
units: Python integer, dimensionality of the output space.
activation: Activation function (`callable`).
@@ -652,6 +536,7 @@ class DenseLocalReparameterization(_DenseVariational):
Neural Information Processing Systems, 2015.
"""
+ @docstring_util.expand_docstring(args=doc_args)
def __init__(
self,
units,
@@ -670,6 +555,11 @@ class DenseLocalReparameterization(_DenseVariational):
bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
name=None,
**kwargs):
+ """Construct layer.
+
+ Args:
+ @{args}
+ """
super(DenseLocalReparameterization, self).__init__(
units=units,
activation=activation,
@@ -705,6 +595,7 @@ class DenseLocalReparameterization(_DenseVariational):
return self.kernel_posterior_affine_tensor
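By contrast, local reparameterization samples pre-activations rather than weights: for a mean-field normal kernel, `matmul(x, W)` is itself normal with mean `matmul(x, loc)` and variance `matmul(x**2, scale**2)`, which lowers gradient variance. A rough sketch of that computation, again with assumed shapes and names:

import tensorflow as tf

in_dim, units = 4, 3
w_loc = tf.get_variable("lr_w_loc", [in_dim, units])
w_scale = tf.nn.softplus(tf.get_variable("lr_w_rho", [in_dim, units]))

x = tf.random_normal([8, in_dim])

# Moments of the induced normal distribution over pre-activations x @ W.
out_mean = tf.matmul(x, w_loc)
out_var = tf.matmul(tf.square(x), tf.square(w_scale))

# Sample directly in "output space"; one draw per example and unit.
outputs = out_mean + tf.sqrt(out_var) * tf.random_normal(tf.shape(out_mean))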
+@docstring_util.expand_docstring(args=doc_args)
def dense_local_reparameterization(
inputs,
units,
@@ -745,49 +636,7 @@ def dense_local_reparameterization(
Args:
inputs: Tensor input.
- units: Integer or Long, dimensionality of the output space.
- activation: Activation function (`callable`). Set it to None to maintain a
- linear activation.
- activity_regularizer: Regularizer function for the output.
- trainable: Boolean, if `True` also add variables to the graph collection
- `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
- kernel_posterior_fn: Python `callable` which creates
- `tf.distributions.Distribution` instance representing the surrogate
- posterior of the `kernel` parameter. Default value:
- `default_mean_field_normal_fn()`.
- kernel_posterior_tensor_fn: Python `callable` which takes a
- `tf.distributions.Distribution` instance and returns a representative
- value. Default value: `lambda d: d.sample()`.
- kernel_prior_fn: Python `callable` which creates `tf.distributions`
- instance. See `default_mean_field_normal_fn` docstring for required
- parameter signature.
- Default value: `tf.distributions.Normal(loc=0., scale=1.)`.
- kernel_divergence_fn: Python `callable` which takes the surrogate posterior
- distribution, prior distribution and random variate sample(s) from the
- surrogate posterior and computes or approximates the KL divergence. The
- distributions are `tf.distributions.Distribution`-like instances and the
- sample is a `Tensor`.
- bias_posterior_fn: Python `callable` which creates
- `tf.distributions.Distribution` instance representing the surrogate
- posterior of the `bias` parameter. Default value:
- `default_mean_field_normal_fn(is_singular=True)` (which creates an
- instance of `tf.distributions.Deterministic`).
- bias_posterior_tensor_fn: Python `callable` which takes a
- `tf.distributions.Distribution` instance and returns a representative
- value. Default value: `lambda d: d.sample()`.
- bias_prior_fn: Python `callable` which creates `tf.distributions` instance.
- See `default_mean_field_normal_fn` docstring for required parameter
- signature. Default value: `None` (no prior, no variational inference)
- bias_divergence_fn: Python `callable` which takes the surrogate posterior
- distribution, prior distribution and random variate sample(s) from the
- surrogate posterior and computes or approximates the KL divergence. The
- distributions are `tf.distributions.Distribution`-like instances and the
- sample is a `Tensor`.
- name: Python `str`, the name of the layer. Layers with the same name will
- share `tf.Variable`s, but to avoid mistakes we require `reuse=True` in
- such cases.
- reuse: Python `bool`, whether to reuse the `tf.Variable`s of a previous
- layer by the same name.
+ @{args}
Returns:
    output: `Tensor` representing the affine transformed input under a random
@@ -866,53 +715,6 @@ class DenseFlipout(_DenseVariational):
(`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias`
distributions.
- Args:
- units: Integer or Long, dimensionality of the output space.
- activation: Activation function (`callable`). Set it to None to maintain a
- linear activation.
- activity_regularizer: Regularizer function for the output.
- trainable: Boolean, if `True` also add variables to the graph collection
- `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
- kernel_posterior_fn: Python `callable` which creates
- `tf.distributions.Distribution` instance representing the surrogate
- posterior of the `kernel` parameter. Default value:
- `default_mean_field_normal_fn()`.
- kernel_posterior_tensor_fn: Python `callable` which takes a
- `tf.distributions.Distribution` instance and returns a representative
- value. Default value: `lambda d: d.sample()`.
- kernel_prior_fn: Python `callable` which creates `tf.distributions`
- instance. See `default_mean_field_normal_fn` docstring for required
- parameter signature.
- Default value: `tf.distributions.Normal(loc=0., scale=1.)`.
- kernel_divergence_fn: Python `callable` which takes the surrogate posterior
- distribution, prior distribution and random variate sample(s) from the
- surrogate posterior and computes or approximates the KL divergence. The
- distributions are `tf.distributions.Distribution`-like instances and the
- sample is a `Tensor`.
- bias_posterior_fn: Python `callable` which creates
- `tf.distributions.Distribution` instance representing the surrogate
- posterior of the `bias` parameter. Default value:
- `default_mean_field_normal_fn(is_singular=True)` (which creates an
- instance of `tf.distributions.Deterministic`).
- bias_posterior_tensor_fn: Python `callable` which takes a
- `tf.distributions.Distribution` instance and returns a representative
- value. Default value: `lambda d: d.sample()`.
- bias_prior_fn: Python `callable` which creates `tf.distributions` instance.
- See `default_mean_field_normal_fn` docstring for required parameter
- signature. Default value: `None` (no prior, no variational inference)
- bias_divergence_fn: Python `callable` which takes the surrogate posterior
- distribution, prior distribution and random variate sample(s) from the
- surrogate posterior and computes or approximates the KL divergence. The
- distributions are `tf.distributions.Distribution`-like instances and the
- sample is a `Tensor`.
- seed: Python scalar `int` which initializes the random number
- generator. Default value: `None` (i.e., use global seed).
- name: Python `str`, the name of the layer. Layers with the same name will
- share `tf.Variable`s, but to avoid mistakes we require `reuse=True` in
- such cases.
- reuse: Python `bool`, whether to reuse the `tf.Variable`s of a previous
- layer by the same name.
-
Properties:
units: Python integer, dimensionality of the output space.
activation: Activation function (`callable`).
@@ -959,6 +761,7 @@ class DenseFlipout(_DenseVariational):
https://openreview.net/forum?id=rJnpifWAb
"""
+ @docstring_util.expand_docstring(args=doc_args)
def __init__(
self,
units,
@@ -978,6 +781,11 @@ class DenseFlipout(_DenseVariational):
seed=None,
name=None,
**kwargs):
+ """Construct layer.
+
+ Args:
+ @{args}
+ """
super(DenseFlipout, self).__init__(
units=units,
activation=activation,
@@ -1031,6 +839,7 @@ class DenseFlipout(_DenseVariational):
return outputs
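Putting the pieces together, a typical training objective combines the layer output with the KL penalties the layer registers. In this sketch the shapes, the dataset size of 50000, and the assumption that the functional interface routes its KL terms into the standard regularization-loss collection are all illustrative rather than taken from this change.

import tensorflow as tf
from tensorflow.contrib.bayesflow.python.ops import layers_dense_variational as dense_variational

features = tf.random_normal([64, 20])
labels = tf.random_uniform([64], maxval=10, dtype=tf.int32)

logits = dense_variational.dense_flipout(features, units=10)

# Negative log-likelihood part of the ELBO.
nll = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits))

# KL penalties registered by the layer; this assumes graph-mode collection
# behavior of the functional tf.layers interface.
kl = tf.add_n(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
loss = nll + kl / 50000.  # 50000: assumed number of training examples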
+@docstring_util.expand_docstring(args=doc_args)
def dense_flipout(
inputs,
units,
@@ -1074,51 +883,7 @@ def dense_flipout(
Args:
inputs: Tensor input.
- units: Integer or Long, dimensionality of the output space.
- activation: Activation function (`callable`). Set it to None to maintain a
- linear activation.
- activity_regularizer: Regularizer function for the output.
- trainable: Boolean, if `True` also add variables to the graph collection
- `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
- kernel_posterior_fn: Python `callable` which creates
- `tf.distributions.Distribution` instance representing the surrogate
- posterior of the `kernel` parameter. Default value:
- `default_mean_field_normal_fn()`.
- kernel_posterior_tensor_fn: Python `callable` which takes a
- `tf.distributions.Distribution` instance and returns a representative
- value. Default value: `lambda d: d.sample()`.
- kernel_prior_fn: Python `callable` which creates `tf.distributions`
- instance. See `default_mean_field_normal_fn` docstring for required
- parameter signature.
- Default value: `tf.distributions.Normal(loc=0., scale=1.)`.
- kernel_divergence_fn: Python `callable` which takes the surrogate posterior
- distribution, prior distribution and random variate sample(s) from the
- surrogate posterior and computes or approximates the KL divergence. The
- distributions are `tf.distributions.Distribution`-like instances and the
- sample is a `Tensor`.
- bias_posterior_fn: Python `callable` which creates
- `tf.distributions.Distribution` instance representing the surrogate
- posterior of the `bias` parameter. Default value:
- `default_mean_field_normal_fn(is_singular=True)` (which creates an
- instance of `tf.distributions.Deterministic`).
- bias_posterior_tensor_fn: Python `callable` which takes a
- `tf.distributions.Distribution` instance and returns a representative
- value. Default value: `lambda d: d.sample()`.
- bias_prior_fn: Python `callable` which creates `tf.distributions` instance.
- See `default_mean_field_normal_fn` docstring for required parameter
- signature. Default value: `None` (no prior, no variational inference)
- bias_divergence_fn: Python `callable` which takes the surrogate posterior
- distribution, prior distribution and random variate sample(s) from the
- surrogate posterior and computes or approximates the KL divergence. The
- distributions are `tf.distributions.Distribution`-like instances and the
- sample is a `Tensor`.
- seed: Python scalar `int` which initializes the random number
- generator. Default value: `None` (i.e., use global seed).
- name: Python `str`, the name of the layer. Layers with the same name will
- share `tf.Variable`s, but to avoid mistakes we require `reuse=True` in
- such cases.
- reuse: Python `bool`, whether to reuse the `tf.Variable`s of a previous
- layer by the same name.
+ @{args}
Returns:
    output: `Tensor` representing the affine transformed input under a random