aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/bayesflow
diff options
context:
space:
mode:
authorGravatar Joshua V. Dillon <jvdillon@google.com>2018-03-01 17:41:41 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-01 17:45:55 -0800
commit4669767c4c6d830c2234c3ba15944a362b08fa14 (patch)
tree0999730990ad24f1b6d592249780768201e030eb /tensorflow/contrib/bayesflow
parent8a591af6854ee1b010d82d262072b5d3b2cdf7cc (diff)
Add util which creates Python callable with tf.Variables explicitly as
arguments. PiperOrigin-RevId: 187561302
Diffstat (limited to 'tensorflow/contrib/bayesflow')
-rw-r--r--tensorflow/contrib/bayesflow/BUILD17
-rw-r--r--tensorflow/contrib/bayesflow/__init__.py2
-rw-r--r--tensorflow/contrib/bayesflow/python/kernel_tests/variable_utils_test.py135
-rw-r--r--tensorflow/contrib/bayesflow/python/ops/variable_utils.py29
-rw-r--r--tensorflow/contrib/bayesflow/python/ops/variable_utils_impl.py157
5 files changed, 0 insertions, 340 deletions
diff --git a/tensorflow/contrib/bayesflow/BUILD b/tensorflow/contrib/bayesflow/BUILD
index 270c309ec3..3592cff90b 100644
--- a/tensorflow/contrib/bayesflow/BUILD
+++ b/tensorflow/contrib/bayesflow/BUILD
@@ -252,23 +252,6 @@ cuda_py_test(
)
cuda_py_test(
- name = "variable_utils_test",
- size = "small",
- srcs = ["python/kernel_tests/variable_utils_test.py"],
- additional_deps = [
- ":bayesflow_py",
- "//third_party/py/numpy",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:framework",
- "//tensorflow/python:framework_for_generated_wrappers",
- "//tensorflow/python:framework_test_lib",
- "//tensorflow/python:gradients",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:platform_test",
- ],
-)
-
-cuda_py_test(
name = "variational_sgd_optimizer_test",
size = "small",
srcs = ["python/kernel_tests/variational_sgd_optimizer_test.py"],
diff --git a/tensorflow/contrib/bayesflow/__init__.py b/tensorflow/contrib/bayesflow/__init__.py
index 528c4fbacd..c411026346 100644
--- a/tensorflow/contrib/bayesflow/__init__.py
+++ b/tensorflow/contrib/bayesflow/__init__.py
@@ -30,7 +30,6 @@ from tensorflow.contrib.bayesflow.python.ops import mcmc_diagnostics
from tensorflow.contrib.bayesflow.python.ops import metropolis_hastings
from tensorflow.contrib.bayesflow.python.ops import monte_carlo
from tensorflow.contrib.bayesflow.python.ops import optimizers
-from tensorflow.contrib.bayesflow.python.ops import variable_utils
# pylint: enable=unused-import,line-too-long
from tensorflow.python.util.all_util import remove_undocumented
@@ -49,7 +48,6 @@ _allowed_symbols = [
'optimizers',
'special_math',
'stochastic_variables',
- 'variable_utils',
'variational_inference',
]
diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/variable_utils_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/variable_utils_test.py
deleted file mode 100644
index f978cf8641..0000000000
--- a/tensorflow/contrib/bayesflow/python/kernel_tests/variable_utils_test.py
+++ /dev/null
@@ -1,135 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for utility functions related to managing `tf.Variable`s."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import warnings
-
-import numpy as np
-
-from tensorflow.contrib.bayesflow.python.ops import variable_utils
-
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import variable_scope as varscope_ops
-from tensorflow.python.ops import variables as variables_ops
-from tensorflow.python.platform import test
-
-
-def test_fn(x):
- x = ops.convert_to_tensor(x, name="x")
- dtype = x.dtype.as_numpy_dtype
- s = x.shape.as_list()
- z = varscope_ops.get_variable(
- name="z",
- dtype=dtype,
- initializer=np.arange(np.prod(s)).reshape(s).astype(dtype))
- y = varscope_ops.get_variable(
- name="y",
- dtype=dtype,
- initializer=np.arange(np.prod(s)).reshape(s).astype(dtype)**2)
- return x + y + z
-
-
-class _WrapCallableTest(object):
-
- def testDefaultArgsWorkCorrectly(self):
- with self.test_session():
- x = constant_op.constant(self.dtype([0.1, 0.2]))
- wrapped_fn, vars_args = variable_utils.externalize_variables_as_args(
- test_fn, [x])
-
- varscope_ops.get_variable_scope().reuse_variables()
-
- result = wrapped_fn(self.dtype(2), [3, 4, 5], 0.5)
-
- y_actual = varscope_ops.get_variable("y", dtype=self.dtype)
- z_actual = varscope_ops.get_variable("z", dtype=self.dtype)
-
- variables_ops.global_variables_initializer().run()
- result_ = result.eval()
-
- self.assertEqual(self.dtype, result_.dtype)
- self.assertAllEqual([5.5, 6.5, 7.5], result_)
- self.assertAllEqual([y_actual, z_actual], vars_args)
-
- def testNonDefaultArgsWorkCorrectly(self):
- with self.test_session():
- x = constant_op.constant(self.dtype([0.1, 0.2]))
-
- _ = test_fn(self.dtype([0., 0.])) # Needed to create vars.
- varscope_ops.get_variable_scope().reuse_variables()
-
- y_actual = varscope_ops.get_variable("y", dtype=self.dtype)
-
- wrapped_fn, vars_args = variable_utils.externalize_variables_as_args(
- test_fn, [x], possible_ancestor_vars=[y_actual])
-
- result = wrapped_fn(self.dtype([2, 3]), 0.5) # x, y
-
- variables_ops.global_variables_initializer().run()
- result_ = result.eval()
-
- self.assertEqual(self.dtype, result_.dtype)
- self.assertAllEqual([2.5, 4.5], result_)
- self.assertAllEqual([y_actual], vars_args)
-
- def testWarnings(self):
- with self.test_session():
- x = constant_op.constant(self.dtype([0.1, 0.2]))
- wrapped_fn, _ = variable_utils.externalize_variables_as_args(
- test_fn, [x], possible_ancestor_vars=[])
- varscope_ops.get_variable_scope().reuse_variables()
- with warnings.catch_warnings(record=True) as w:
- wrapped_fn(self.dtype(2))
- w = sorted(w, key=lambda w: str(w.message))
- self.assertEqual(2, len(w))
- self.assertRegexpMatches(
- str(w[0].message),
- r"Variable .* 'y:0' .* not found in bypass dict.")
- self.assertRegexpMatches(
- str(w[1].message),
- r"Variable .* 'z:0' .* not found in bypass dict.")
-
- def testExceptions(self):
- with self.test_session():
- x = constant_op.constant(self.dtype([0.1, 0.2]))
- wrapped_fn, _ = variable_utils.externalize_variables_as_args(
- test_fn,
- [x],
- possible_ancestor_vars=[],
- assert_variable_override=True)
- varscope_ops.get_variable_scope().reuse_variables()
- with self.assertRaisesRegexp(ValueError, r"not found"):
- wrapped_fn(self.dtype(2))
-
-
-class WrapCallableTest16(test.TestCase, _WrapCallableTest):
- dtype = np.float16
-
-
-class WrapCallableTest32(test.TestCase, _WrapCallableTest):
- dtype = np.float32
-
-
-class WrapCallableTest64(test.TestCase, _WrapCallableTest):
- dtype = np.float64
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/bayesflow/python/ops/variable_utils.py b/tensorflow/contrib/bayesflow/python/ops/variable_utils.py
deleted file mode 100644
index eadf6f4d5f..0000000000
--- a/tensorflow/contrib/bayesflow/python/ops/variable_utils.py
+++ /dev/null
@@ -1,29 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Utility functions related to managing `tf.Variable`s."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-# go/tf-wildcard-import
-from tensorflow.contrib.bayesflow.python.ops.variable_utils_impl import * # pylint: disable=wildcard-import,unused-wildcard-import,g-importing-member
-from tensorflow.python.util import all_util
-
-_allowed_symbols = [
- "externalize_variables_as_args",
-]
-
-all_util.remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/contrib/bayesflow/python/ops/variable_utils_impl.py b/tensorflow/contrib/bayesflow/python/ops/variable_utils_impl.py
deleted file mode 100644
index ca3d75b5bf..0000000000
--- a/tensorflow/contrib/bayesflow/python/ops/variable_utils_impl.py
+++ /dev/null
@@ -1,157 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Utility functions related to managing `tf.Variable`s.
-
-@@externalize_variables_as_args
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import warnings
-
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import gradients_impl as gradients_ops
-from tensorflow.python.ops import variable_scope as varscope_ops
-from tensorflow.python.ops import variables as variables_ops
-
-__all__ = [
- "externalize_variables_as_args",
-]
-
-
-# Cause all warnings to always be triggered.
-# Not having this means subsequent calls wont trigger the warning.
-warnings.simplefilter("always")
-
-
-def externalize_variables_as_args(fn,
- fn_args=(),
- ancestor_variables=None,
- possible_ancestor_vars=None,
- assert_variable_override=False,
- name=None):
- """"Converts variables within a callable into explicit args.
-
- Makes a new callable from `fn` which has arguments `list(fn_args) +
- list(ancestor_variables)`. If `ancestor_variables` is not specified, it is
- inferred by checking which of `possible_ancestor_vars` actually influences the
- return value of `fn` (concretely, gradient of `fn(*fn_args)` is not `None`).
- By default `possible_ancestor_vars` is `tf.trainable_variables() +
- tf.get_collection(tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES)`.
-
- #### Examples:
-
- ```python
- num_samples = 2
- num_dims = 1
- dtype = np.float32
-
- def foo(x):
- x = tf.convert_to_tensor(x, dtype=dtype, name="x")
- s = x.shape.as_list()
- y = tf.get_variable(
- name="y",
- dtype=dtype,
- initializer=np.arange(np.prod(s)).reshape(s).astype(dtype))
- return x + y
-
- x = tf.constant(dtype([0.1, 0.2]))
-
- wrapped_foo, discovered_ancestor_variables = (
- externalize_variables_as_args(foo, [x]))
-
- new_x = dtype([[1.], [2.]])
- new_y = dtype([[3.], [4.]])
- new_result = wrapped_foo(new_x, new_y)
- # ==> [[4.], [6.]]
-
- discovered_ancestor_variables == [tf.get_variable("y", dtype)]
- # ==> [True]
- ```
-
- Args:
- fn: Python callable which returns a `Tensor` and accepts `*fn_args`.
- fn_args: Python list of args to `fn`. Represents dummy arguments passed to
- `fn` to trace its execution; actual values are unimportant. These args are
- only used to construct the output of `fn` and to resolve the ancestor
- `tf.Variable`s.
- Default value: `()` (i.e., `fn` takes no args).
- ancestor_variables: Python list of `tf.Variable`s. When `None` the list is
- expanded to non-`None` gradients of `fn(*fn_args)`. By directly providing
- the `ancestor_variables` the internal call to `fn` is avoided.
- Default value: `None` (i.e., `tf.Variable` dependencies are discovered).
- possible_ancestor_vars: Python list of possible `tf.Variable`s which might
- be a dependency of computing `fn(*fn_args)`.
- Default value: `None` (i.e., expanded as described above).
- assert_variable_override: Python `bool` indicating that not finding a
- `tf.Variable` in the override list is an exception.
- Default value: `False` (i.e., missing a `Variable` triggers a `warning`).
- name: Python `str` name prefixed to Ops created by this function.
- Default value: `None` (i.e., "externalize_variables_as_args").
-
- Returns:
- wrapped_fn: Python callable taking arguments like
- `*(list(fn_args) + discovered_ancestor_variables)`.
- discovered_ancestor_variables: Python list of `tf.Variable`s known to be a
- dependency of `fn(*fn_args)`.
-
- Raises:
- ValueError: if `assert_variable_override` is `True` and `Variable` is
- requested but not overridden.
- """
- def _make_bypassing_custom_getter_fn(new_var_dict):
- """Return dict value rather than what would otherwise be dict key."""
- def _custom_getter(getter, *args, **kwargs):
- v = getter(*args, **kwargs)
- new_v = new_var_dict.get(v, None)
- if new_v is None:
- msg = "Variable \"{}\" not found in bypass dict.".format(v)
- if assert_variable_override:
- raise ValueError(msg)
- warnings.warn(msg)
- return v
- return new_v
- return _custom_getter
-
- with ops.name_scope(name, "externalize_variables_as_args"):
- if ancestor_variables is not None and not ancestor_variables:
- return fn, ()
- if ancestor_variables is None:
- y = fn(*fn_args) # Side-effect: adds trainable vars.
- if possible_ancestor_vars is None:
- possible_ancestor_vars = (
- variables_ops.trainable_variables() +
- ops.get_collection(ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
- # TODO(b/72873296): Add a dedicated op for identifying ancestors.
- ancestors = [v for g, v
- in zip(gradients_ops.gradients(y, possible_ancestor_vars),
- possible_ancestor_vars)
- if g is not None]
- ancestor_variables = sorted(ancestors, key=lambda v: v.name)
- n = len(fn_args)
- def _fn(*args):
- with ops.name_scope("wrapped_fn"):
- vars_dict = dict(
- (k, ops.convert_to_tensor(
- v, dtype=k.dtype.base_dtype, name=k.op.name))
- for k, v in zip(ancestor_variables, args[n:]))
- with varscope_ops.variable_scope(
- varscope_ops.get_variable_scope(),
- reuse=True,
- custom_getter=_make_bypassing_custom_getter_fn(vars_dict)):
- return fn(*args[:n])
- return _fn, ancestor_variables