aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/bayesflow
diff options
context:
space:
mode:
authorGravatar Joshua V. Dillon <jvdillon@google.com>2018-02-06 13:03:03 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-06 13:11:16 -0800
commitd7b8bce5c6e84f2a3cb5e7f5f5255a1d30afffb2 (patch)
treeee1c0701970c23cd84f521672fde1557eebf2ef5 /tensorflow/contrib/bayesflow
parent6f7c6763b80088152b6f9de5ba1046e75d28a26a (diff)
Add utility function which makes implicit `tf.get_variable` dependencies an
explicit argument of a callable. PiperOrigin-RevId: 184725878
Diffstat (limited to 'tensorflow/contrib/bayesflow')
-rw-r--r--tensorflow/contrib/bayesflow/BUILD17
-rw-r--r--tensorflow/contrib/bayesflow/__init__.py2
-rw-r--r--tensorflow/contrib/bayesflow/python/kernel_tests/variable_utils_test.py135
-rw-r--r--tensorflow/contrib/bayesflow/python/ops/variable_utils.py29
-rw-r--r--tensorflow/contrib/bayesflow/python/ops/variable_utils_impl.py157
5 files changed, 340 insertions, 0 deletions
diff --git a/tensorflow/contrib/bayesflow/BUILD b/tensorflow/contrib/bayesflow/BUILD
index 82944f5363..34156c28fe 100644
--- a/tensorflow/contrib/bayesflow/BUILD
+++ b/tensorflow/contrib/bayesflow/BUILD
@@ -240,6 +240,23 @@ cuda_py_test(
)
cuda_py_test(
+ name = "variable_utils_test",
+ size = "small",
+ srcs = ["python/kernel_tests/variable_utils_test.py"],
+ additional_deps = [
+ ":bayesflow_py",
+ "//third_party/py/numpy",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:gradients",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+cuda_py_test(
name = "variational_sgd_optimizer_test",
size = "small",
srcs = ["python/kernel_tests/variational_sgd_optimizer_test.py"],
diff --git a/tensorflow/contrib/bayesflow/__init__.py b/tensorflow/contrib/bayesflow/__init__.py
index c411026346..528c4fbacd 100644
--- a/tensorflow/contrib/bayesflow/__init__.py
+++ b/tensorflow/contrib/bayesflow/__init__.py
@@ -30,6 +30,7 @@ from tensorflow.contrib.bayesflow.python.ops import mcmc_diagnostics
from tensorflow.contrib.bayesflow.python.ops import metropolis_hastings
from tensorflow.contrib.bayesflow.python.ops import monte_carlo
from tensorflow.contrib.bayesflow.python.ops import optimizers
+from tensorflow.contrib.bayesflow.python.ops import variable_utils
# pylint: enable=unused-import,line-too-long
from tensorflow.python.util.all_util import remove_undocumented
@@ -48,6 +49,7 @@ _allowed_symbols = [
'optimizers',
'special_math',
'stochastic_variables',
+ 'variable_utils',
'variational_inference',
]
diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/variable_utils_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/variable_utils_test.py
new file mode 100644
index 0000000000..f978cf8641
--- /dev/null
+++ b/tensorflow/contrib/bayesflow/python/kernel_tests/variable_utils_test.py
@@ -0,0 +1,135 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for utility functions related to managing `tf.Variable`s."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import warnings
+
+import numpy as np
+
+from tensorflow.contrib.bayesflow.python.ops import variable_utils
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import variable_scope as varscope_ops
+from tensorflow.python.ops import variables as variables_ops
+from tensorflow.python.platform import test
+
+
+def test_fn(x):
+ x = ops.convert_to_tensor(x, name="x")
+ dtype = x.dtype.as_numpy_dtype
+ s = x.shape.as_list()
+ z = varscope_ops.get_variable(
+ name="z",
+ dtype=dtype,
+ initializer=np.arange(np.prod(s)).reshape(s).astype(dtype))
+ y = varscope_ops.get_variable(
+ name="y",
+ dtype=dtype,
+ initializer=np.arange(np.prod(s)).reshape(s).astype(dtype)**2)
+ return x + y + z
+
+
+class _WrapCallableTest(object):
+
+ def testDefaultArgsWorkCorrectly(self):
+ with self.test_session():
+ x = constant_op.constant(self.dtype([0.1, 0.2]))
+ wrapped_fn, vars_args = variable_utils.externalize_variables_as_args(
+ test_fn, [x])
+
+ varscope_ops.get_variable_scope().reuse_variables()
+
+ result = wrapped_fn(self.dtype(2), [3, 4, 5], 0.5)
+
+ y_actual = varscope_ops.get_variable("y", dtype=self.dtype)
+ z_actual = varscope_ops.get_variable("z", dtype=self.dtype)
+
+ variables_ops.global_variables_initializer().run()
+ result_ = result.eval()
+
+ self.assertEqual(self.dtype, result_.dtype)
+ self.assertAllEqual([5.5, 6.5, 7.5], result_)
+ self.assertAllEqual([y_actual, z_actual], vars_args)
+
+ def testNonDefaultArgsWorkCorrectly(self):
+ with self.test_session():
+ x = constant_op.constant(self.dtype([0.1, 0.2]))
+
+ _ = test_fn(self.dtype([0., 0.])) # Needed to create vars.
+ varscope_ops.get_variable_scope().reuse_variables()
+
+ y_actual = varscope_ops.get_variable("y", dtype=self.dtype)
+
+ wrapped_fn, vars_args = variable_utils.externalize_variables_as_args(
+ test_fn, [x], possible_ancestor_vars=[y_actual])
+
+ result = wrapped_fn(self.dtype([2, 3]), 0.5) # x, y
+
+ variables_ops.global_variables_initializer().run()
+ result_ = result.eval()
+
+ self.assertEqual(self.dtype, result_.dtype)
+ self.assertAllEqual([2.5, 4.5], result_)
+ self.assertAllEqual([y_actual], vars_args)
+
+ def testWarnings(self):
+ with self.test_session():
+ x = constant_op.constant(self.dtype([0.1, 0.2]))
+ wrapped_fn, _ = variable_utils.externalize_variables_as_args(
+ test_fn, [x], possible_ancestor_vars=[])
+ varscope_ops.get_variable_scope().reuse_variables()
+ with warnings.catch_warnings(record=True) as w:
+ wrapped_fn(self.dtype(2))
+ w = sorted(w, key=lambda w: str(w.message))
+ self.assertEqual(2, len(w))
+ self.assertRegexpMatches(
+ str(w[0].message),
+ r"Variable .* 'y:0' .* not found in bypass dict.")
+ self.assertRegexpMatches(
+ str(w[1].message),
+ r"Variable .* 'z:0' .* not found in bypass dict.")
+
+ def testExceptions(self):
+ with self.test_session():
+ x = constant_op.constant(self.dtype([0.1, 0.2]))
+ wrapped_fn, _ = variable_utils.externalize_variables_as_args(
+ test_fn,
+ [x],
+ possible_ancestor_vars=[],
+ assert_variable_override=True)
+ varscope_ops.get_variable_scope().reuse_variables()
+ with self.assertRaisesRegexp(ValueError, r"not found"):
+ wrapped_fn(self.dtype(2))
+
+
+class WrapCallableTest16(test.TestCase, _WrapCallableTest):
+ dtype = np.float16
+
+
+class WrapCallableTest32(test.TestCase, _WrapCallableTest):
+ dtype = np.float32
+
+
+class WrapCallableTest64(test.TestCase, _WrapCallableTest):
+ dtype = np.float64
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/bayesflow/python/ops/variable_utils.py b/tensorflow/contrib/bayesflow/python/ops/variable_utils.py
new file mode 100644
index 0000000000..eadf6f4d5f
--- /dev/null
+++ b/tensorflow/contrib/bayesflow/python/ops/variable_utils.py
@@ -0,0 +1,29 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utility functions related to managing `tf.Variable`s."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# go/tf-wildcard-import
+from tensorflow.contrib.bayesflow.python.ops.variable_utils_impl import * # pylint: disable=wildcard-import,unused-wildcard-import,g-importing-member
+from tensorflow.python.util import all_util
+
+_allowed_symbols = [
+ "externalize_variables_as_args",
+]
+
+all_util.remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/contrib/bayesflow/python/ops/variable_utils_impl.py b/tensorflow/contrib/bayesflow/python/ops/variable_utils_impl.py
new file mode 100644
index 0000000000..ca3d75b5bf
--- /dev/null
+++ b/tensorflow/contrib/bayesflow/python/ops/variable_utils_impl.py
@@ -0,0 +1,157 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utility functions related to managing `tf.Variable`s.
+
+@@externalize_variables_as_args
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import warnings
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import gradients_impl as gradients_ops
+from tensorflow.python.ops import variable_scope as varscope_ops
+from tensorflow.python.ops import variables as variables_ops
+
+__all__ = [
+ "externalize_variables_as_args",
+]
+
+
+# Cause all warnings to always be triggered.
+# Not having this means subsequent calls wont trigger the warning.
+warnings.simplefilter("always")
+
+
+def externalize_variables_as_args(fn,
+ fn_args=(),
+ ancestor_variables=None,
+ possible_ancestor_vars=None,
+ assert_variable_override=False,
+ name=None):
+ """"Converts variables within a callable into explicit args.
+
+ Makes a new callable from `fn` which has arguments `list(fn_args) +
+ list(ancestor_variables)`. If `ancestor_variables` is not specified, it is
+ inferred by checking which of `possible_ancestor_vars` actually influences the
+ return value of `fn` (concretely, gradient of `fn(*fn_args)` is not `None`).
+ By default `possible_ancestor_vars` is `tf.trainable_variables() +
+ tf.get_collection(tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES)`.
+
+ #### Examples:
+
+ ```python
+ num_samples = 2
+ num_dims = 1
+ dtype = np.float32
+
+ def foo(x):
+ x = tf.convert_to_tensor(x, dtype=dtype, name="x")
+ s = x.shape.as_list()
+ y = tf.get_variable(
+ name="y",
+ dtype=dtype,
+ initializer=np.arange(np.prod(s)).reshape(s).astype(dtype))
+ return x + y
+
+ x = tf.constant(dtype([0.1, 0.2]))
+
+ wrapped_foo, discovered_ancestor_variables = (
+ externalize_variables_as_args(foo, [x]))
+
+ new_x = dtype([[1.], [2.]])
+ new_y = dtype([[3.], [4.]])
+ new_result = wrapped_foo(new_x, new_y)
+ # ==> [[4.], [6.]]
+
+ discovered_ancestor_variables == [tf.get_variable("y", dtype)]
+ # ==> [True]
+ ```
+
+ Args:
+ fn: Python callable which returns a `Tensor` and accepts `*fn_args`.
+ fn_args: Python list of args to `fn`. Represents dummy arguments passed to
+ `fn` to trace its execution; actual values are unimportant. These args are
+ only used to construct the output of `fn` and to resolve the ancestor
+ `tf.Variable`s.
+ Default value: `()` (i.e., `fn` takes no args).
+ ancestor_variables: Python list of `tf.Variable`s. When `None` the list is
+ expanded to non-`None` gradients of `fn(*fn_args)`. By directly providing
+ the `ancestor_variables` the internal call to `fn` is avoided.
+ Default value: `None` (i.e., `tf.Variable` dependencies are discovered).
+ possible_ancestor_vars: Python list of possible `tf.Variable`s which might
+ be a dependency of computing `fn(*fn_args)`.
+ Default value: `None` (i.e., expanded as described above).
+ assert_variable_override: Python `bool` indicating that not finding a
+ `tf.Variable` in the override list is an exception.
+ Default value: `False` (i.e., missing a `Variable` triggers a `warning`).
+ name: Python `str` name prefixed to Ops created by this function.
+ Default value: `None` (i.e., "externalize_variables_as_args").
+
+ Returns:
+ wrapped_fn: Python callable taking arguments like
+ `*(list(fn_args) + discovered_ancestor_variables)`.
+ discovered_ancestor_variables: Python list of `tf.Variable`s known to be a
+ dependency of `fn(*fn_args)`.
+
+ Raises:
+ ValueError: if `assert_variable_override` is `True` and `Variable` is
+ requested but not overridden.
+ """
+ def _make_bypassing_custom_getter_fn(new_var_dict):
+ """Return dict value rather than what would otherwise be dict key."""
+ def _custom_getter(getter, *args, **kwargs):
+ v = getter(*args, **kwargs)
+ new_v = new_var_dict.get(v, None)
+ if new_v is None:
+ msg = "Variable \"{}\" not found in bypass dict.".format(v)
+ if assert_variable_override:
+ raise ValueError(msg)
+ warnings.warn(msg)
+ return v
+ return new_v
+ return _custom_getter
+
+ with ops.name_scope(name, "externalize_variables_as_args"):
+ if ancestor_variables is not None and not ancestor_variables:
+ return fn, ()
+ if ancestor_variables is None:
+ y = fn(*fn_args) # Side-effect: adds trainable vars.
+ if possible_ancestor_vars is None:
+ possible_ancestor_vars = (
+ variables_ops.trainable_variables() +
+ ops.get_collection(ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
+ # TODO(b/72873296): Add a dedicated op for identifying ancestors.
+ ancestors = [v for g, v
+ in zip(gradients_ops.gradients(y, possible_ancestor_vars),
+ possible_ancestor_vars)
+ if g is not None]
+ ancestor_variables = sorted(ancestors, key=lambda v: v.name)
+ n = len(fn_args)
+ def _fn(*args):
+ with ops.name_scope("wrapped_fn"):
+ vars_dict = dict(
+ (k, ops.convert_to_tensor(
+ v, dtype=k.dtype.base_dtype, name=k.op.name))
+ for k, v in zip(ancestor_variables, args[n:]))
+ with varscope_ops.variable_scope(
+ varscope_ops.get_variable_scope(),
+ reuse=True,
+ custom_getter=_make_bypassing_custom_getter_fn(vars_dict)):
+ return fn(*args[:n])
+ return _fn, ancestor_variables