diff options
author | Joshua V. Dillon <jvdillon@google.com> | 2018-03-12 11:29:24 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-03-12 11:33:33 -0700 |
commit | 617d1f01d60b677536f988be35dc4f02885e6f1e (patch) | |
tree | 0cf0d88465195137e3f96650bdb86ba037a9f8f9 /tensorflow/contrib/bayesflow | |
parent | 402fb8c97db05b51587c6fc999c690d548fd4496 (diff) |
Improve usability of `tf.contrib.bayesflow.custom_gradient` by removing the need for the `axis` arg and supporting list-valued inputs.
PiperOrigin-RevId: 188751894
Diffstat (limited to 'tensorflow/contrib/bayesflow')
-rw-r--r-- | tensorflow/contrib/bayesflow/python/kernel_tests/custom_grad_test.py | 2 | ||||
-rw-r--r-- | tensorflow/contrib/bayesflow/python/ops/custom_grad_impl.py | 122 |
2 files changed, 76 insertions, 48 deletions
diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/custom_grad_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/custom_grad_test.py index a95df31ac1..1250765d09 100644 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/custom_grad_test.py +++ b/tensorflow/contrib/bayesflow/python/kernel_tests/custom_grad_test.py @@ -83,7 +83,7 @@ class CustomGradientTest(test.TestCase): g = lambda z: z[0]**2 * z[1]**2 / 2 z = array_ops.stack([x, y]) - fz = cg.custom_gradient(f(z), g(z), z, axis=0) + fz = cg.custom_gradient(f(z), g(z), z) gz = gradients_impl.gradients(fz, variables.trainable_variables()) [z_, fz_, gx_, gy_] = sess.run([z, fz, gz[0], gz[1]]) diff --git a/tensorflow/contrib/bayesflow/python/ops/custom_grad_impl.py b/tensorflow/contrib/bayesflow/python/ops/custom_grad_impl.py index d44fe6529a..927cc28f67 100644 --- a/tensorflow/contrib/bayesflow/python/ops/custom_grad_impl.py +++ b/tensorflow/contrib/bayesflow/python/ops/custom_grad_impl.py @@ -24,32 +24,38 @@ from __future__ import print_function from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops from tensorflow.python.ops import math_ops __all__ = [ - "custom_gradient", + 'custom_gradient', ] -def custom_gradient(fx, gx, x, axis=(), fx_gx_manually_stopped=False, - name=None): - """Enables specifying a custom gradient. +def is_list_like(x): + return isinstance(x, (tuple, list)) + + +def identity(x, dtype=None, name=None): + return array_ops.identity(ops.convert_to_tensor( + x, dtype=dtype, name=name), name=name) + + +def custom_gradient(fx, gx, x, fx_gx_manually_stopped=False, name=None): + """Embeds a custom gradient into a `Tensor`. This function works by clever application of `stop_gradient`. 
I.e., observe that: ```none - h(x) = x * stop_gradient(g(x)) + stop_gradient(f(x) - x * g(x)) + h(x) = stop_gradient(f(x)) + stop_gradient(g(x)) * (x - stop_gradient(x)) ``` - is such that `h(x) = stop_gradient(f(x))` and `grad[h(x), x] = - stop_gradient(g(x)).` + is such that `h(x) == stop_gradient(f(x))` and + `grad[h(x), x] == stop_gradient(g(x)).` In addition to scalar-domain/scalar-range functions, this function also - supports tensor-domain/scalar-range functions. However, in the latter case it - is necessary to reduce `x` to a scalar. This can be done by indicating the - `axis` over which `f` operates or by appropriately `reduce_sum`-ing `x`, prior - to calling this function. + supports tensor-domain/scalar-range functions. Partial Custom Gradient: @@ -61,12 +67,8 @@ def custom_gradient(fx, gx, x, axis=(), fx_gx_manually_stopped=False, Args: fx: `Tensor`. Output of function evaluated at `x`. - gx: `Tensor`. Gradient of function evaluated at `x`. - x: `Tensor`. Point of evaluation for `f, g`. - axis: 1D `int` `Tensor` representing dimensions of `x` which are the domain - of `f`. If `()` (the default), `f` is assumed scalar-domain/scalar-range. - If `None` `f` is assumed to render one scalar given all of `x`. Otherwise - `f` is assumed to output one scalar for each of `axis` dimensions of `x`. + gx: `Tensor` or list of `Tensor`s. Gradient of function at (each) `x`. + x: `Tensor` or list of `Tensor`s. Args of evaluation for `f`. fx_gx_manually_stopped: Python `bool` indicating that `fx`, `gx` manually have `stop_gradient` applied. name: Python `str` name prefixed to Ops created by this function. @@ -75,36 +77,62 @@ def custom_gradient(fx, gx, x, axis=(), fx_gx_manually_stopped=False, fx: Floating-type `Tensor` equal to `f(x)` but which has gradient `stop_gradient(g(x))`. 
""" - with ops.name_scope(name, "custom_gradient", [fx, gx, x]): - fx = ops.convert_to_tensor(fx, name="fx") + def maybe_stop(x): + if fx_gx_manually_stopped: + return x + return array_ops.stop_gradient(x) + with ops.name_scope(name, 'custom_gradient', [fx, gx, x]): + fx = ops.convert_to_tensor(fx, name='fx') # We don't want to bother eagerly computing `gx` since we may not even need # it. with ops.control_dependencies([fx]): - gx = ops.convert_to_tensor(gx, dtype=fx.dtype, name="gx") - gx = array_ops.identity(gx, name="gx") - # Proof of correctness: - # - # f(x) = x * stop[gx] + stop[fx - x * gx] - # = stop[fx] - # - # g(x) = grad[fx] - # = stop[gx] + grad[stop[fx - x * gx]] - # = stop[gx] + 0 - # - # Notice that when x is zero it still works: - # grad[x * stop(gx) + stop(fx - x * gx)] = 1 * stop[gx] + 0 = stop[gx] - # - # The proof is similar for the tensor-domain case, except that `x` is - # replaced by `reduce_sum(x)`. - sum_x = math_ops.reduce_sum(x, axis=axis, name="sum_x") - if not fx_gx_manually_stopped: - fx = array_ops.stop_gradient(fx) - gx = array_ops.stop_gradient(gx) - # IEEE754 ensures `(x-x)==0.` and that `0.*x==0.` so we make sure to write - # the code this way, rather than, e.g., - # `sum_x * stop(gx) + stop(fx - sum_x * gx)`. - # For more discussion regarding the relevant portions of the IEEE754 - # standard, see the StackOverflow question, - # "Is there a floating point value of x, for which x-x == 0 is false?" - # http://stackoverflow.com/q/2686644 - return (sum_x - array_ops.stop_gradient(sum_x)) * gx + fx + if is_list_like(x): + x = [identity(x_, name='x') for x_ in x] + else: + x = [identity(x, name='x')] + + if is_list_like(gx): + gx = [identity(gx_, dtype=fx.dtype, name='gx') + for gx_ in gx] + else: + gx = [identity(gx, dtype=fx.dtype, name='gx')] + + override_grad = [] + for x_, gx_ in zip(x, gx): + # Observe: tf.gradients(f(x), x)[i].shape == x[i].shape + # thus we check that the user is supplying correct shapes. 
+ equal_shape = check_ops.assert_equal( + array_ops.shape(x_), + array_ops.shape(gx_), + message='Each `x` must have the same shape as each `gx`.') + with ops.control_dependencies([equal_shape]): + # IEEE754 ensures `(x-x)==0.` and that `0.*x==0.` so we make sure to + # write the code this way, rather than, e.g., + # `sum_x * stop(gx) + stop(fx - sum_x * gx)`. + # For more discussion regarding the relevant portions of the IEEE754 + # standard, see the StackOverflow question, + # "Is there a floating point value of x, for which x-x == 0 is false?" + # http://stackoverflow.com/q/2686644 + zeros_like_x_ = x_ - array_ops.stop_gradient(x_) + override_grad.append(math_ops.reduce_sum( + maybe_stop(gx_) * zeros_like_x_)) + override_grad = sum(override_grad) + override_grad /= math_ops.cast(array_ops.size(fx), + dtype=fx.dtype.base_dtype) + + # Proof of correctness: + # + # f(x) = x * stop[gx] + stop[fx - x * gx] + # = stop[fx] + # + # g(x) = grad[fx] + # = stop[gx] + grad[stop[fx - x * gx]] + # = stop[gx] + 0 + # + # Notice that when x is zero it still works: + # grad[x * stop(gx) + stop(fx - x * gx)] = 1 * stop[gx] + 0 = stop[gx] + # + # The proof is similar for the tensor-domain case, except that we + # `reduce_sum` the `stop[gx] * (x - stop[x])` then rescale by + # `tf.size(fx)` since this reduced version is broadcast to `fx`. + return maybe_stop(fx) + override_grad |