author     Joshua V. Dillon <jvdillon@google.com>          2017-07-19 13:32:35 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2017-07-19 13:36:46 -0700
commit     b977998827180e3e3ba872db28ff211715cfe73a (patch)
tree       5c1c024f748152fca91c2968d48284628b2e43a2
parent     831a3a46d34e9bd305555889392c016690ff9bc4 (diff)
Add functions for specifying a custom gradient.
PiperOrigin-RevId: 162527721
-rw-r--r--  tensorflow/contrib/bayesflow/BUILD                                    |  27
-rw-r--r--  tensorflow/contrib/bayesflow/__init__.py                              |  10
-rw-r--r--  tensorflow/contrib/bayesflow/python/kernel_tests/custom_grad_test.py  | 157
-rw-r--r--  tensorflow/contrib/bayesflow/python/ops/custom_grad.py                |  34
-rw-r--r--  tensorflow/contrib/bayesflow/python/ops/custom_grad_impl.py           | 110
5 files changed, 334 insertions(+), 4 deletions(-)
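
For orientation, here is a minimal usage sketch of the API this change introduces. It assumes the new module is reachable as `tf.contrib.bayesflow.custom_grad` (per the `__init__.py` update below) and mirrors the values in `test_works_correctly`; it is an illustration, not part of the commit.

```python
import numpy as np
import tensorflow as tf

cg = tf.contrib.bayesflow.custom_grad

x = tf.constant(np.linspace(-3., 3., 7))
# Forward value is f(x) = x**2 / 2, but autodiff is told the gradient is
# g(x) = (x - 1)**3 / 3 rather than the true derivative x.
fx = cg.custom_gradient(fx=x**2 / 2., gx=(x - 1.)**3 / 3., x=x)
gx = tf.gradients(fx, x)[0]

with tf.Session() as sess:
  fx_, gx_ = sess.run([fx, gx])
  # fx_ matches x**2 / 2; gx_ matches (x - 1)**3 / 3.
```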
diff --git a/tensorflow/contrib/bayesflow/BUILD b/tensorflow/contrib/bayesflow/BUILD
index 71828f084d..1cd6e64b32 100644
--- a/tensorflow/contrib/bayesflow/BUILD
+++ b/tensorflow/contrib/bayesflow/BUILD
@@ -42,9 +42,13 @@ cuda_py_test(
"//third_party/py/numpy",
"//tensorflow/contrib/distributions:distributions_py",
"//tensorflow/contrib/layers:layers_py",
+ "//tensorflow/python/ops/distributions",
+ "//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:gradients",
+ "//tensorflow/python:linalg_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:nn_ops",
"//tensorflow/python:platform_test",
@@ -52,6 +56,26 @@ cuda_py_test(
)
cuda_py_test(
+ name = "custom_grad_test",
+ size = "small",
+ srcs = ["python/kernel_tests/custom_grad_test.py"],
+ additional_deps = [
+ ":bayesflow_py",
+ "//third_party/py/numpy",
+ "//tensorflow/contrib/layers:layers_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:gradients",
+ "//tensorflow/python:init_ops",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python:variables",
+ ],
+)
+
+cuda_py_test(
name = "entropy_test",
size = "medium",
srcs = ["python/kernel_tests/entropy_test.py"],
@@ -99,12 +123,15 @@ cuda_py_test(
"//third_party/py/numpy",
"//tensorflow/contrib/distributions:distributions_py",
"//tensorflow/contrib/layers:layers_py",
+ "//tensorflow/python/ops/distributions",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:gradients",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
+ "//tensorflow/python:random_seed",
],
)
diff --git a/tensorflow/contrib/bayesflow/__init__.py b/tensorflow/contrib/bayesflow/__init__.py
index 65a17d742c..15c1614a67 100644
--- a/tensorflow/contrib/bayesflow/__init__.py
+++ b/tensorflow/contrib/bayesflow/__init__.py
@@ -22,6 +22,7 @@ from __future__ import print_function
# pylint: disable=unused-import,line-too-long
from tensorflow.contrib.bayesflow.python.ops import csiszar_divergence
+from tensorflow.contrib.bayesflow.python.ops import custom_grad
from tensorflow.contrib.bayesflow.python.ops import entropy
from tensorflow.contrib.bayesflow.python.ops import monte_carlo
from tensorflow.contrib.bayesflow.python.ops import stochastic_gradient_estimators
@@ -34,9 +35,10 @@ from tensorflow.contrib.bayesflow.python.ops import variational_inference
from tensorflow.python.util.all_util import remove_undocumented
-_allowed_symbols = ['csiszar_divergence', 'entropy', 'monte_carlo',
- 'special_math', 'stochastic_gradient_estimators',
- 'stochastic_graph', 'stochastic_tensor',
- 'stochastic_variables', 'variational_inference']
+_allowed_symbols = ['csiszar_divergence', 'custom_grad', 'entropy',
+ 'monte_carlo', 'special_math',
+ 'stochastic_gradient_estimators', 'stochastic_graph',
+ 'stochastic_tensor', 'stochastic_variables',
+ 'variational_inference']
remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/custom_grad_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/custom_grad_test.py
new file mode 100644
index 0000000000..5f8f7f692c
--- /dev/null
+++ b/tensorflow/contrib/bayesflow/python/kernel_tests/custom_grad_test.py
@@ -0,0 +1,157 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Custom Gradient Ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.bayesflow.python.ops import custom_grad_impl
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+cg = custom_grad_impl
+
+
+class CustomGradientTest(test.TestCase):
+
+ def test_works_correctly(self):
+ with self.test_session() as sess:
+ f = lambda x: x**2 / 2
+ g = lambda x: (x - 1)**3 / 3
+ x_ = np.linspace(-100, 100, int(1e4)) + [0.]
+
+ x = constant_op.constant(x_)
+ fx = cg.custom_gradient(f(x), g(x), x)
+ gx = gradients_impl.gradients(fx, x)[0]
+ [fx_, gx_] = sess.run([fx, gx])
+
+ self.assertAllEqual(f(x_), fx_)
+ self.assertAllEqual(g(x_), gx_)
+
+ def test_works_correctly_both_f_g_zero(self):
+ with self.test_session() as sess:
+ f = lambda x: x**2 / 2
+ g = lambda x: x**3 / 3
+ x_ = np.linspace(-100, 100, int(1e4)) + [0.]
+
+ x = constant_op.constant(x_)
+ fx = cg.custom_gradient(f(x), g(x), x)
+ gx = gradients_impl.gradients(fx, x)[0]
+ [fx_, gx_] = sess.run([fx, gx])
+
+ self.assertAllEqual(f(x_), fx_)
+ self.assertAllEqual(g(x_), gx_)
+
+ def test_works_correctly_vector_of_vars(self):
+ with self.test_session() as sess:
+ x = variable_scope.get_variable(
+ name="x",
+ shape=[],
+ dtype=dtypes.float32,
+ initializer=init_ops.constant_initializer(2))
+ y = variable_scope.get_variable(
+ name="y",
+ shape=[],
+ dtype=dtypes.float32,
+ initializer=init_ops.constant_initializer(3))
+ sess.run([variables.global_variables_initializer()])
+
+ f = lambda z: z[0] * z[1]
+ g = lambda z: z[0]**2 * z[1]**2 / 2
+
+ z = array_ops.stack([x, y])
+ fz = cg.custom_gradient(f(z), g(z), z, axis=0)
+ gz = gradients_impl.gradients(fz, variables.trainable_variables())
+ [z_, fz_, gx_, gy_] = sess.run([z, fz, gz[0], gz[1]])
+
+ self.assertEqual(f(z_), fz_)
+ self.assertEqual(g(z_), gx_)
+ self.assertEqual(g(z_), gy_)
+
+ def test_works_correctly_side_vars(self):
+ with self.test_session() as sess:
+ x_ = np.float32(2.1) # Adding extra tenth to force imprecision.
+ y_ = np.float32(3.1)
+ x = variable_scope.get_variable(
+ name="x",
+ shape=[],
+ dtype=dtypes.float32,
+ initializer=init_ops.constant_initializer(x_))
+ y = variable_scope.get_variable(
+ name="y",
+ shape=[],
+ dtype=dtypes.float32,
+ initializer=init_ops.constant_initializer(y_))
+ sess.run([variables.global_variables_initializer()])
+
+ f = lambda x: x * y
+ g = lambda z: math_ops.square(x) * y
+
+ fx = cg.custom_gradient(f(x), g(x), x)
+ gx = gradients_impl.gradients(fx, variables.trainable_variables())
+ [x_, fx_, gx_] = sess.run([x, fx, gx[0]])
+ gy_ = gx[1]
+
+ self.assertEqual(x_ * y_, fx_)
+ self.assertEqual(np.square(x_) * y_, gx_)
+ self.assertEqual(None, gy_)
+
+ def test_works_correctly_fx_gx_manually_stopped(self):
+ with self.test_session() as sess:
+ x_ = np.float32(2.1) # Adding extra tenth to force imprecision.
+ y_ = np.float32(3.1)
+ x = variable_scope.get_variable(
+ name="x",
+ shape=[],
+ dtype=dtypes.float32,
+ initializer=init_ops.constant_initializer(x_))
+ y = variable_scope.get_variable(
+ name="y",
+ shape=[],
+ dtype=dtypes.float32,
+ initializer=init_ops.constant_initializer(y_))
+ sess.run([variables.global_variables_initializer()])
+
+ stop = array_ops.stop_gradient # For readability.
+
+ # Basically we need to stop the `x` portion of `f`. And when we supply the
+ # arg to `custom_gradient` we need to stop the complement, i.e., the `y`
+ # part.
+ f = lambda x: stop(x) * y
+ g = lambda x: stop(math_ops.square(x)) * y
+ fx = cg.custom_gradient(f(x), g(x), x + stop(y),
+ fx_gx_manually_stopped=True)
+
+ gx = gradients_impl.gradients(fx, variables.trainable_variables())
+ [x_, fx_, gx_, gy_] = sess.run([x, fx, gx[0], gx[1]])
+
+ self.assertEqual(x_ * y_, fx_)
+ self.assertEqual(np.square(x_) * y_, gx_)
+ self.assertEqual(x_, gy_)
+
+
+if __name__ == "__main__":
+ test.main()
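
A user-facing restatement of the `axis` behavior exercised by `test_works_correctly_vector_of_vars` above, again assuming the symbols are exposed as `tf.contrib.bayesflow.custom_grad`; this is a sketch only, not part of the commit.

```python
import tensorflow as tf

cg = tf.contrib.bayesflow.custom_grad

x = tf.get_variable("x", shape=[], initializer=tf.constant_initializer(2.))
y = tf.get_variable("y", shape=[], initializer=tf.constant_initializer(3.))
z = tf.stack([x, y])

# f(z) = z[0] * z[1] is declared to have gradient g(z) = z[0]**2 * z[1]**2 / 2
# with respect to each variable; axis=0 marks z's leading dimension as the
# domain of f.
fz = cg.custom_gradient(z[0] * z[1], z[0]**2 * z[1]**2 / 2., z, axis=0)
gz = tf.gradients(fz, [x, y])

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  fz_, gx_, gy_ = sess.run([fz, gz[0], gz[1]])
  # fz_ == 6.0; gx_ == gy_ == 18.0, i.e. the declared custom gradient.
```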
diff --git a/tensorflow/contrib/bayesflow/python/ops/custom_grad.py b/tensorflow/contrib/bayesflow/python/ops/custom_grad.py
new file mode 100644
index 0000000000..ca1ecb9c40
--- /dev/null
+++ b/tensorflow/contrib/bayesflow/python/ops/custom_grad.py
@@ -0,0 +1,34 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Functions for specifying custom gradients.
+
+See ${python/contrib.bayesflow.custom_gradient}.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# go/tf-wildcard-import
+# pylint: disable=wildcard-import
+from tensorflow.contrib.bayesflow.python.ops.custom_grad_impl import *
+# pylint: enable=wildcard-import
+from tensorflow.python.util.all_util import remove_undocumented
+
+_allowed_symbols = [
+ 'custom_gradient',
+]
+
+remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/contrib/bayesflow/python/ops/custom_grad_impl.py b/tensorflow/contrib/bayesflow/python/ops/custom_grad_impl.py
new file mode 100644
index 0000000000..ee3719232d
--- /dev/null
+++ b/tensorflow/contrib/bayesflow/python/ops/custom_grad_impl.py
@@ -0,0 +1,110 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Functions for specifying custom gradients.
+
+@@custom_gradient
+
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+
+__all__ = [
+ "custom_gradient",
+]
+
+
+def custom_gradient(fx, gx, x, axis=(),
+ fx_gx_manually_stopped=False,
+ name=None):
+ """Enables specifying a custom gradient.
+
+ This function works by clever application of `stop_gradient`. I.e., observe
+ that:
+
+ ```none
+ h(x) = x * stop_gradient(g(x)) + stop_gradient(f(x) - x * g(x))
+ ```
+
+ is such that `h(x) = stop(f(x))` and `grad[h(x), x] = stop_gradient(g(x)).`
+
+ In addition to scalar-domain/scalar-range functions, this function also
+ supports tensor-domain/scalar-range functions. However, in the latter case it
+ is necessary to reduce `x` to a scalar. This can be done by indicating the
+ `axis` over which `f` operates or by appropriately `reduce_sum`-ing `x`, prior
+ to calling this function.
+
+ Partial Custom Gradient:
+
+ Suppose `h(x) = htilde(x, y)`. Note that `dh/dx = stop(g(x))` but `dh/dy =
+ None`. This is because a `Tensor` cannot have only a portion of its gradient
+ stopped. To circumvent this issue, one must manually `stop_gradient` the
+ relevant portions of `f`, `g`. For example see the unit-test,
+ `test_works_correctly_fx_gx_manually_stopped`.
+
+ Args:
+ fx: `Tensor`. Output of function evaluated at `x`.
+ gx: `Tensor`. Gradient of function evaluated at `x`.
+ x: `Tensor`. Point of evaluation for `f, g`.
+ axis: 1D `int` `Tensor` representing dimensions of `x` which are the domain
+ of `f`. If `()` (the default), `f` is assumed scalar-domain/scalar-range.
+ If `None` `f` is assumed to render one scalar given all of `x`. Otherwise
+ `f` is assumed to output one scalar for each of `axis` dimensions of `x`.
+ fx_gx_manually_stopped: Python `bool` indicating that `fx`, `gx` manually
+ have `stop_gradient` applied.
+ name: Python `str` name prefixed to Ops created by this function.
+
+ Returns:
+ fx: Floating-type `Tensor` equal to `f(x)` but which has gradient
+ `stop_gradient(g(x))`.
+ """
+ with ops.name_scope(name, "custom_gradient", [fx, gx, x]):
+ fx = ops.convert_to_tensor(fx, name="fx")
+ # We don't want to bother eagerly computing `gx` since we may not even need
+ # it.
+ with ops.control_dependencies([fx]):
+ gx = ops.convert_to_tensor(gx, dtype=fx.dtype, name="gx")
+ gx = array_ops.identity(gx, name="gx")
+ # Proof of correctness:
+ #
+ # f(x) = x * stop[gx] + stop[fx - x * gx]
+ # = stop[fx]
+ #
+ # g(x) = grad[fx]
+ # = stop[gx] + grad[stop[fx - x * gx]]
+ # = stop[gx] + 0
+ #
+ # Notice that when x is zero it still works:
+ # grad[x * stop(gx) + stop(fx - x * gx)] = 1 * stop[gx] + 0 = stop[gx]
+ #
+ # The proof is similar for the tensor-domain case, except that `x` is
+ # replaced by `reduce_sum(x)`.
+ sum_x = math_ops.reduce_sum(x, axis=axis, name="sum_x")
+ if not fx_gx_manually_stopped:
+ fx = array_ops.stop_gradient(fx)
+ gx = array_ops.stop_gradient(gx)
+ # IEEE754 ensures `(x-x)==0.` and that `0.*x==0.` so we make sure to write
+ # the code this way, rather than, e.g.,
+ # `sum_x * stop(gx) + stop(fx - sum_x * gx)`.
+ # For more discussion regarding the relevant portions of the IEEE754
+ # standard, see the StackOverflow question,
+ # "Is there a floating point value of x, for which x-x == 0 is false?"
+ # http://stackoverflow.com/q/2686644
+ return (sum_x - array_ops.stop_gradient(sum_x)) * gx + fx
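
Finally, a standalone check of the `stop_gradient` identity the docstring builds on, written with core ops only and not part of the change: `h(x) = x * stop_gradient(g(x)) + stop_gradient(f(x) - x * g(x))` evaluates to `f(x)` while its autodiff gradient is `g(x)`. The committed implementation uses the rearranged form `(x - stop_gradient(x)) * g(x) + f(x)` so that, per the IEEE754 comment above, the forward value is exactly `f(x)`.

```python
import tensorflow as tf

x = tf.constant([0., 1., 2., 3.])
f = x ** 2 / 2.          # value we want the op to take
g = (x - 1.) ** 3 / 3.   # gradient we want autodiff to report

# The identity from the docstring: value ~= f(x), gradient == g(x).
h = x * tf.stop_gradient(g) + tf.stop_gradient(f - x * g)
dh_dx = tf.gradients(h, x)[0]

with tf.Session() as sess:
  h_, dh_dx_ = sess.run([h, dh_dx])
  # h_ equals f(x) up to rounding; dh_dx_ equals g(x) exactly, even at x == 0.
```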