author | 2017-07-19 13:32:35 -0700
---|---
committer | 2017-07-19 13:36:46 -0700
commit | b977998827180e3e3ba872db28ff211715cfe73a (patch)
tree | 5c1c024f748152fca91c2968d48284628b2e43a2
parent | 831a3a46d34e9bd305555889392c016690ff9bc4 (diff)
Add functions for specifying a custom gradient.
PiperOrigin-RevId: 162527721
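In brief, the new `custom_gradient(fx, gx, x)` op returns a `Tensor` whose forward value is `f(x)` but whose reported gradient is `g(x)`. The construction, restated in LaTeX from the `custom_grad_impl.py` docstring below:

$$h(x) = x \cdot \mathrm{stop}[g(x)] + \mathrm{stop}[f(x) - x \, g(x)]$$

In the forward pass `stop_gradient` acts as the identity, so $h(x) = f(x)$; in the backward pass the second term is a constant, so $\nabla_x h(x) = \mathrm{stop}[g(x)]$.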
-rw-r--r-- | tensorflow/contrib/bayesflow/BUILD | 27
-rw-r--r-- | tensorflow/contrib/bayesflow/__init__.py | 10
-rw-r--r-- | tensorflow/contrib/bayesflow/python/kernel_tests/custom_grad_test.py | 157
-rw-r--r-- | tensorflow/contrib/bayesflow/python/ops/custom_grad.py | 34
-rw-r--r-- | tensorflow/contrib/bayesflow/python/ops/custom_grad_impl.py | 110
5 files changed, 334 insertions, 4 deletions
diff --git a/tensorflow/contrib/bayesflow/BUILD b/tensorflow/contrib/bayesflow/BUILD
index 71828f084d..1cd6e64b32 100644
--- a/tensorflow/contrib/bayesflow/BUILD
+++ b/tensorflow/contrib/bayesflow/BUILD
@@ -42,9 +42,13 @@ cuda_py_test(
         "//third_party/py/numpy",
         "//tensorflow/contrib/distributions:distributions_py",
         "//tensorflow/contrib/layers:layers_py",
+        "//tensorflow/python/ops/distributions",
+        "//tensorflow/python:array_ops",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:gradients",
+        "//tensorflow/python:linalg_ops",
         "//tensorflow/python:math_ops",
         "//tensorflow/python:nn_ops",
         "//tensorflow/python:platform_test",
@@ -52,6 +56,26 @@ cuda_py_test(
 )
 
 cuda_py_test(
+    name = "custom_grad_test",
+    size = "small",
+    srcs = ["python/kernel_tests/custom_grad_test.py"],
+    additional_deps = [
+        ":bayesflow_py",
+        "//third_party/py/numpy",
+        "//tensorflow/contrib/layers:layers_py",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:framework_for_generated_wrappers",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:gradients",
+        "//tensorflow/python:init_ops",
+        "//tensorflow/python:platform_test",
+        "//tensorflow/python:variable_scope",
+        "//tensorflow/python:variables",
+    ],
+)
+
+cuda_py_test(
     name = "entropy_test",
     size = "medium",
     srcs = ["python/kernel_tests/entropy_test.py"],
@@ -99,12 +123,15 @@ cuda_py_test(
         "//third_party/py/numpy",
         "//tensorflow/contrib/distributions:distributions_py",
         "//tensorflow/contrib/layers:layers_py",
+        "//tensorflow/python/ops/distributions",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:framework",
         "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:gradients",
         "//tensorflow/python:math_ops",
         "//tensorflow/python:platform_test",
+        "//tensorflow/python:random_seed",
     ],
 )
diff --git a/tensorflow/contrib/bayesflow/__init__.py b/tensorflow/contrib/bayesflow/__init__.py
index 65a17d742c..15c1614a67 100644
--- a/tensorflow/contrib/bayesflow/__init__.py
+++ b/tensorflow/contrib/bayesflow/__init__.py
@@ -22,6 +22,7 @@ from __future__ import print_function
 
 # pylint: disable=unused-import,line-too-long
 from tensorflow.contrib.bayesflow.python.ops import csiszar_divergence
+from tensorflow.contrib.bayesflow.python.ops import custom_grad
 from tensorflow.contrib.bayesflow.python.ops import entropy
 from tensorflow.contrib.bayesflow.python.ops import monte_carlo
 from tensorflow.contrib.bayesflow.python.ops import stochastic_gradient_estimators
@@ -34,9 +35,10 @@ from tensorflow.contrib.bayesflow.python.ops import variational_inference
 
 from tensorflow.python.util.all_util import remove_undocumented
 
-_allowed_symbols = ['csiszar_divergence', 'entropy', 'monte_carlo',
-                    'special_math', 'stochastic_gradient_estimators',
-                    'stochastic_graph', 'stochastic_tensor',
-                    'stochastic_variables', 'variational_inference']
+_allowed_symbols = ['csiszar_divergence', 'custom_grad', 'entropy',
+                    'monte_carlo', 'special_math',
+                    'stochastic_gradient_estimators', 'stochastic_graph',
+                    'stochastic_tensor', 'stochastic_variables',
+                    'variational_inference']
 
 remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/custom_grad_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/custom_grad_test.py
new file mode 100644
index 0000000000..5f8f7f692c
--- /dev/null
+++ b/tensorflow/contrib/bayesflow/python/kernel_tests/custom_grad_test.py
@@ -0,0 +1,157 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Custom Gradient Ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.bayesflow.python.ops import custom_grad_impl
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+cg = custom_grad_impl
+
+
+class CustomGradientTest(test.TestCase):
+
+  def test_works_correctly(self):
+    with self.test_session() as sess:
+      f = lambda x: x**2 / 2
+      g = lambda x: (x - 1)**3 / 3
+      x_ = np.linspace(-100, 100, int(1e4)) + [0.]
+
+      x = constant_op.constant(x_)
+      fx = cg.custom_gradient(f(x), g(x), x)
+      gx = gradients_impl.gradients(fx, x)[0]
+      [fx_, gx_] = sess.run([fx, gx])
+
+      self.assertAllEqual(f(x_), fx_)
+      self.assertAllEqual(g(x_), gx_)
+
+  def test_works_correctly_both_f_g_zero(self):
+    with self.test_session() as sess:
+      f = lambda x: x**2 / 2
+      g = lambda x: x**3 / 3
+      x_ = np.linspace(-100, 100, int(1e4)) + [0.]
+
+      x = constant_op.constant(x_)
+      fx = cg.custom_gradient(f(x), g(x), x)
+      gx = gradients_impl.gradients(fx, x)[0]
+      [fx_, gx_] = sess.run([fx, gx])
+
+      self.assertAllEqual(f(x_), fx_)
+      self.assertAllEqual(g(x_), gx_)
+
+  def test_works_correctly_vector_of_vars(self):
+    with self.test_session() as sess:
+      x = variable_scope.get_variable(
+          name="x",
+          shape=[],
+          dtype=dtypes.float32,
+          initializer=init_ops.constant_initializer(2))
+      y = variable_scope.get_variable(
+          name="y",
+          shape=[],
+          dtype=dtypes.float32,
+          initializer=init_ops.constant_initializer(3))
+      sess.run([variables.global_variables_initializer()])
+
+      f = lambda z: z[0] * z[1]
+      g = lambda z: z[0]**2 * z[1]**2 / 2
+
+      z = array_ops.stack([x, y])
+      fz = cg.custom_gradient(f(z), g(z), z, axis=0)
+      gz = gradients_impl.gradients(fz, variables.trainable_variables())
+      [z_, fz_, gx_, gy_] = sess.run([z, fz, gz[0], gz[1]])
+
+      self.assertEqual(f(z_), fz_)
+      self.assertEqual(g(z_), gx_)
+      self.assertEqual(g(z_), gy_)
+
+  def test_works_correctly_side_vars(self):
+    with self.test_session() as sess:
+      x_ = np.float32(2.1)  # Adding extra tenth to force imprecision.
+      y_ = np.float32(3.1)
+      x = variable_scope.get_variable(
+          name="x",
+          shape=[],
+          dtype=dtypes.float32,
+          initializer=init_ops.constant_initializer(x_))
+      y = variable_scope.get_variable(
+          name="y",
+          shape=[],
+          dtype=dtypes.float32,
+          initializer=init_ops.constant_initializer(y_))
+      sess.run([variables.global_variables_initializer()])
+
+      f = lambda x: x * y
+      g = lambda z: math_ops.square(x) * y
+
+      fx = cg.custom_gradient(f(x), g(x), x)
+      gx = gradients_impl.gradients(fx, variables.trainable_variables())
+      [x_, fx_, gx_] = sess.run([x, fx, gx[0]])
+      gy_ = gx[1]
+
+      self.assertEqual(x_ * y_, fx_)
+      self.assertEqual(np.square(x_) * y_, gx_)
+      self.assertEqual(None, gy_)
+
+  def test_works_correctly_fx_gx_manually_stopped(self):
+    with self.test_session() as sess:
+      x_ = np.float32(2.1)  # Adding extra tenth to force imprecision.
+      y_ = np.float32(3.1)
+      x = variable_scope.get_variable(
+          name="x",
+          shape=[],
+          dtype=dtypes.float32,
+          initializer=init_ops.constant_initializer(x_))
+      y = variable_scope.get_variable(
+          name="y",
+          shape=[],
+          dtype=dtypes.float32,
+          initializer=init_ops.constant_initializer(y_))
+      sess.run([variables.global_variables_initializer()])
+
+      stop = array_ops.stop_gradient  # For readability.
+
+      # Basically we need to stop the `x` portion of `f`. And when we supply
+      # the arg to `custom_gradient` we need to stop the complement, i.e.,
+      # the `y` part.
+      f = lambda x: stop(x) * y
+      g = lambda x: stop(math_ops.square(x)) * y
+      fx = cg.custom_gradient(f(x), g(x), x + stop(y),
+                              fx_gx_manually_stopped=True)
+
+      gx = gradients_impl.gradients(fx, variables.trainable_variables())
+      [x_, fx_, gx_, gy_] = sess.run([x, fx, gx[0], gx[1]])
+
+      self.assertEqual(x_ * y_, fx_)
+      self.assertEqual(np.square(x_) * y_, gx_)
+      self.assertEqual(x_, gy_)
+
+
+if __name__ == "__main__":
+  test.main()
diff --git a/tensorflow/contrib/bayesflow/python/ops/custom_grad.py b/tensorflow/contrib/bayesflow/python/ops/custom_grad.py
new file mode 100644
index 0000000000..ca1ecb9c40
--- /dev/null
+++ b/tensorflow/contrib/bayesflow/python/ops/custom_grad.py
@@ -0,0 +1,34 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Functions for specifying custom gradients.
+
+See ${python/contrib.bayesflow.custom_gradient}.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# go/tf-wildcard-import
+# pylint: disable=wildcard-import
+from tensorflow.contrib.bayesflow.python.ops.custom_grad_impl import *
+# pylint: enable=wildcard-import
+from tensorflow.python.util.all_util import remove_undocumented
+
+_allowed_symbols = [
+    'custom_gradient',
+]
+
+remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/contrib/bayesflow/python/ops/custom_grad_impl.py b/tensorflow/contrib/bayesflow/python/ops/custom_grad_impl.py
new file mode 100644
index 0000000000..ee3719232d
--- /dev/null
+++ b/tensorflow/contrib/bayesflow/python/ops/custom_grad_impl.py
@@ -0,0 +1,110 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Functions for specifying custom gradients.
+
+@@custom_gradient
+
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+
+__all__ = [
+    "custom_gradient",
+]
+
+
+def custom_gradient(fx, gx, x, axis=(),
+                    fx_gx_manually_stopped=False,
+                    name=None):
+  """Enables specifying a custom gradient.
+
+  This function works by clever application of `stop_gradient`. I.e., observe
+  that:
+
+  ```none
+  h(x) = x * stop_gradient(g(x)) + stop_gradient(f(x) - x * g(x))
+  ```
+
+  is such that `h(x) = stop(f(x))` and
+  `grad[h(x), x] = stop_gradient(g(x))`.
+
+  In addition to scalar-domain/scalar-range functions, this function also
+  supports tensor-domain/scalar-range functions. However, in the latter case
+  it is necessary to reduce `x` to a scalar. This can be done by indicating
+  the `axis` over which `f` operates or by appropriately `reduce_sum`-ing
+  `x`, prior to calling this function.
+
+  Partial Custom Gradient:
+
+  Suppose `h(x) = htilde(x, y)`. Note that `dh/dx = stop(g(x))` but
+  `dh/dy = None`. This is because a `Tensor` cannot have only a portion of
+  its gradient stopped. To circumvent this issue, one must manually
+  `stop_gradient` the relevant portions of `f`, `g`. For example see the
+  unit-test, `test_works_correctly_fx_gx_manually_stopped`.
+
+  Args:
+    fx: `Tensor`. Output of function evaluated at `x`.
+    gx: `Tensor`. Gradient of function evaluated at `x`.
+    x: `Tensor`. Point of evaluation for `f, g`.
+    axis: 1D `int` `Tensor` representing dimensions of `x` which are the
+      domain of `f`. If `()` (the default), `f` is assumed
+      scalar-domain/scalar-range. If `None` `f` is assumed to render one
+      scalar given all of `x`. Otherwise `f` is assumed to output one scalar
+      for each of `axis` dimensions of `x`.
+    fx_gx_manually_stopped: Python `bool` indicating that `fx`, `gx` manually
+      have `stop_gradient` applied.
+    name: Python `str` name prefixed to Ops created by this function.
+
+  Returns:
+    fx: Floating-type `Tensor` equal to `f(x)` but which has gradient
+      `stop_gradient(g(x))`.
+  """
+  with ops.name_scope(name, "custom_gradient", [fx, gx, x]):
+    fx = ops.convert_to_tensor(fx, name="fx")
+    # We don't want to bother eagerly computing `gx` since we may not even
+    # need it.
+    with ops.control_dependencies([fx]):
+      gx = ops.convert_to_tensor(gx, dtype=fx.dtype, name="gx")
+      gx = array_ops.identity(gx, name="gx")
+    # Proof of correctness:
+    #
+    #  f(x) = x * stop[gx] + stop[fx - x * gx]
+    #       = stop[fx]
+    #
+    #  g(x) = grad[fx]
+    #       = stop[gx] + grad[stop[fx - x * gx]]
+    #       = stop[gx] + 0
+    #
+    # Notice that when x is zero it still works:
+    #  grad[x * stop(gx) + stop(fx - x * gx)] = 1 * stop[gx] + 0 = stop[gx]
+    #
+    # The proof is similar for the tensor-domain case, except that `x` is
+    # replaced by `reduce_sum(x)`.
+    sum_x = math_ops.reduce_sum(x, axis=axis, name="sum_x")
+    if not fx_gx_manually_stopped:
+      fx = array_ops.stop_gradient(fx)
+      gx = array_ops.stop_gradient(gx)
+    # IEEE754 ensures `(x-x)==0.` and that `0.*x==0.` so we make sure to
+    # write the code this way, rather than, e.g.,
+    # `sum_x * stop(gx) + stop(fx - sum_x * gx)`.
+    # For more discussion regarding the relevant portions of the IEEE754
+    # standard, see the StackOverflow question,
+    # "Is there a floating point value of x, for which x-x == 0 is false?"
+    # http://stackoverflow.com/q/2686644
+    return (sum_x - array_ops.stop_gradient(sum_x)) * gx + fx
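A minimal usage sketch, distilled from the `test_works_correctly` unit test above. It assumes a TF 1.x build that includes this commit; the particular `f`, `g`, and evaluation grid are illustrative, not part of the commit.

```python
import numpy as np
import tensorflow as tf
from tensorflow.contrib.bayesflow import custom_grad

# Forward value f(x), plus the gradient g(x) we want autodiff to report
# in place of the true derivative f'(x) = x.
f = lambda x: x**2 / 2.
g = lambda x: (x - 1.)**3 / 3.

x = tf.constant(np.linspace(-3., 3., 7))
fx = custom_grad.custom_gradient(f(x), g(x), x)  # value: f(x)
gx = tf.gradients(fx, x)[0]                      # gradient: g(x), not f'(x)

with tf.Session() as sess:
  fx_, gx_ = sess.run([fx, gx])  # fx_ == f(x_) and gx_ == g(x_) pointwise
```

To stop only part of a gradient (e.g., with respect to `x` but not a side variable `y`), pass `fx_gx_manually_stopped=True` and apply `stop_gradient` by hand, as in `test_works_correctly_fx_gx_manually_stopped`.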