path: root/tensorflow/contrib/bayesflow
author     Joshua V. Dillon <jvdillon@google.com>           2018-03-12 11:29:24 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-03-12 11:33:33 -0700
commit     617d1f01d60b677536f988be35dc4f02885e6f1e (patch)
tree       0cf0d88465195137e3f96650bdb86ba037a9f8f9 /tensorflow/contrib/bayesflow
parent     402fb8c97db05b51587c6fc999c690d548fd4496 (diff)
Improve usability of `tf.contrib.bayesflow.custom_gradient` by removing the need for the `axis` arg and supporting lists of `Tensor`s.
PiperOrigin-RevId: 188751894
Diffstat (limited to 'tensorflow/contrib/bayesflow')
-rw-r--r--  tensorflow/contrib/bayesflow/python/kernel_tests/custom_grad_test.py |   2
-rw-r--r--  tensorflow/contrib/bayesflow/python/ops/custom_grad_impl.py          | 122
2 files changed, 76 insertions(+), 48 deletions(-)
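To make the call-site change concrete, here is a minimal sketch (not part of the diff) of the simplification described in the commit message. The alias `cg` and the placeholder tensors `x`, `fx`, `gx` are hypothetical; the import path follows the module touched below and may differ in other releases.

```python
from tensorflow.contrib.bayesflow.python.ops import custom_grad_impl as cg
import tensorflow as tf

x = tf.constant([1., 2.])
fx = tf.reduce_sum(x) ** 2.                    # f(x), scalar-range
gx = 2. * tf.reduce_sum(x) * tf.ones_like(x)   # claimed gradient of f at x

# Before this change, a tensor-domain `x` needed an `axis` hint:
#   hx = cg.custom_gradient(fx, gx, x, axis=0)
# After this change, the `axis` arg is gone:
hx = cg.custom_gradient(fx, gx, x)
```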
diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/custom_grad_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/custom_grad_test.py
index a95df31ac1..1250765d09 100644
--- a/tensorflow/contrib/bayesflow/python/kernel_tests/custom_grad_test.py
+++ b/tensorflow/contrib/bayesflow/python/kernel_tests/custom_grad_test.py
@@ -83,7 +83,7 @@ class CustomGradientTest(test.TestCase):
g = lambda z: z[0]**2 * z[1]**2 / 2
z = array_ops.stack([x, y])
- fz = cg.custom_gradient(f(z), g(z), z, axis=0)
+ fz = cg.custom_gradient(f(z), g(z), z)
gz = gradients_impl.gradients(fz, variables.trainable_variables())
[z_, fz_, gx_, gy_] = sess.run([z, fz, gz[0], gz[1]])
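The hunk above simply drops `axis=0` from the test's call. The other half of the change lets `x` and `gx` be lists of `Tensor`s, one gradient per input with matching shapes. The sketch below (hypothetical scalars, TF 1.x graph mode assumed) shows that list-taking form; `tf.gradients` is used only to read back the overridden gradient.

```python
from tensorflow.contrib.bayesflow.python.ops import custom_grad_impl as cg
import tensorflow as tf

x = tf.constant(2.)
y = tf.constant(3.)
fxy = x * y                 # f(x, y)
gx = tf.constant(5.)        # claimed df/dx
gy = tf.constant(7.)        # claimed df/dy

# New list-taking form: one gradient Tensor per input Tensor, same shapes.
h = cg.custom_gradient(fxy, [gx, gy], [x, y])
grads = tf.gradients(h, [x, y])   # evaluates to [5., 7.]
```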
diff --git a/tensorflow/contrib/bayesflow/python/ops/custom_grad_impl.py b/tensorflow/contrib/bayesflow/python/ops/custom_grad_impl.py
index d44fe6529a..927cc28f67 100644
--- a/tensorflow/contrib/bayesflow/python/ops/custom_grad_impl.py
+++ b/tensorflow/contrib/bayesflow/python/ops/custom_grad_impl.py
@@ -24,32 +24,38 @@ from __future__ import print_function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
__all__ = [
- "custom_gradient",
+ 'custom_gradient',
]
-def custom_gradient(fx, gx, x, axis=(), fx_gx_manually_stopped=False,
- name=None):
- """Enables specifying a custom gradient.
+def is_list_like(x):
+ return isinstance(x, (tuple, list))
+
+
+def identity(x, dtype=None, name=None):
+ return array_ops.identity(ops.convert_to_tensor(
+ x, dtype=dtype, name=name), name=name)
+
+
+def custom_gradient(fx, gx, x, fx_gx_manually_stopped=False, name=None):
+ """Embeds a custom gradient into a `Tensor`.
This function works by clever application of `stop_gradient`. I.e., observe
that:
```none
- h(x) = x * stop_gradient(g(x)) + stop_gradient(f(x) - x * g(x))
+ h(x) = stop_gradient(f(x)) + stop_gradient(g(x)) * (x - stop_gradient(x))
```
- is such that `h(x) = stop_gradient(f(x))` and `grad[h(x), x] =
- stop_gradient(g(x)).`
+ is such that `h(x) == stop_gradient(f(x))` and
+ `grad[h(x), x] == stop_gradient(g(x)).`
In addition to scalar-domain/scalar-range functions, this function also
- supports tensor-domain/scalar-range functions. However, in the latter case it
- is necessary to reduce `x` to a scalar. This can be done by indicating the
- `axis` over which `f` operates or by appropriately `reduce_sum`-ing `x`, prior
- to calling this function.
+ supports tensor-domain/scalar-range functions.
Partial Custom Gradient:
@@ -61,12 +67,8 @@ def custom_gradient(fx, gx, x, axis=(), fx_gx_manually_stopped=False,
Args:
fx: `Tensor`. Output of function evaluated at `x`.
- gx: `Tensor`. Gradient of function evaluated at `x`.
- x: `Tensor`. Point of evaluation for `f, g`.
- axis: 1D `int` `Tensor` representing dimensions of `x` which are the domain
- of `f`. If `()` (the default), `f` is assumed scalar-domain/scalar-range.
- If `None` `f` is assumed to render one scalar given all of `x`. Otherwise
- `f` is assumed to output one scalar for each of `axis` dimensions of `x`.
+ gx: `Tensor` or list of `Tensor`s. Gradient of function at (each) `x`.
+ x: `Tensor` or list of `Tensor`s. Args of evaluation for `f`.
fx_gx_manually_stopped: Python `bool` indicating that `fx`, `gx` manually
have `stop_gradient` applied.
name: Python `str` name prefixed to Ops created by this function.
@@ -75,36 +77,62 @@ def custom_gradient(fx, gx, x, axis=(), fx_gx_manually_stopped=False,
fx: Floating-type `Tensor` equal to `f(x)` but which has gradient
`stop_gradient(g(x))`.
"""
- with ops.name_scope(name, "custom_gradient", [fx, gx, x]):
- fx = ops.convert_to_tensor(fx, name="fx")
+ def maybe_stop(x):
+ if fx_gx_manually_stopped:
+ return x
+ return array_ops.stop_gradient(x)
+ with ops.name_scope(name, 'custom_gradient', [fx, gx, x]):
+ fx = ops.convert_to_tensor(fx, name='fx')
# We don't want to bother eagerly computing `gx` since we may not even need
# it.
with ops.control_dependencies([fx]):
- gx = ops.convert_to_tensor(gx, dtype=fx.dtype, name="gx")
- gx = array_ops.identity(gx, name="gx")
- # Proof of correctness:
- #
- # f(x) = x * stop[gx] + stop[fx - x * gx]
- # = stop[fx]
- #
- # g(x) = grad[fx]
- # = stop[gx] + grad[stop[fx - x * gx]]
- # = stop[gx] + 0
- #
- # Notice that when x is zero it still works:
- # grad[x * stop(gx) + stop(fx - x * gx)] = 1 * stop[gx] + 0 = stop[gx]
- #
- # The proof is similar for the tensor-domain case, except that `x` is
- # replaced by `reduce_sum(x)`.
- sum_x = math_ops.reduce_sum(x, axis=axis, name="sum_x")
- if not fx_gx_manually_stopped:
- fx = array_ops.stop_gradient(fx)
- gx = array_ops.stop_gradient(gx)
- # IEEE754 ensures `(x-x)==0.` and that `0.*x==0.` so we make sure to write
- # the code this way, rather than, e.g.,
- # `sum_x * stop(gx) + stop(fx - sum_x * gx)`.
- # For more discussion regarding the relevant portions of the IEEE754
- # standard, see the StackOverflow question,
- # "Is there a floating point value of x, for which x-x == 0 is false?"
- # http://stackoverflow.com/q/2686644
- return (sum_x - array_ops.stop_gradient(sum_x)) * gx + fx
+ if is_list_like(x):
+ x = [identity(x_, name='x') for x_ in x]
+ else:
+ x = [identity(x, name='x')]
+
+ if is_list_like(gx):
+ gx = [identity(gx_, dtype=fx.dtype, name='gx')
+ for gx_ in gx]
+ else:
+ gx = [identity(gx, dtype=fx.dtype, name='gx')]
+
+ override_grad = []
+ for x_, gx_ in zip(x, gx):
+ # Observe: tf.gradients(f(x), x)[i].shape == x[i].shape
+ # thus we check that the user is supplying correct shapes.
+ equal_shape = check_ops.assert_equal(
+ array_ops.shape(x_),
+ array_ops.shape(gx_),
+ message='Each `x` must have the same shape as each `gx`.')
+ with ops.control_dependencies([equal_shape]):
+ # IEEE754 ensures `(x-x)==0.` and that `0.*x==0.` so we make sure to
+ # write the code this way, rather than, e.g.,
+ # `sum_x * stop(gx) + stop(fx - sum_x * gx)`.
+ # For more discussion regarding the relevant portions of the IEEE754
+ # standard, see the StackOverflow question,
+ # "Is there a floating point value of x, for which x-x == 0 is false?"
+ # http://stackoverflow.com/q/2686644
+ zeros_like_x_ = x_ - array_ops.stop_gradient(x_)
+ override_grad.append(math_ops.reduce_sum(
+ maybe_stop(gx_) * zeros_like_x_))
+ override_grad = sum(override_grad)
+ override_grad /= math_ops.cast(array_ops.size(fx),
+ dtype=fx.dtype.base_dtype)
+
+ # Proof of correctness:
+ #
+ # f(x) = x * stop[gx] + stop[fx - x * gx]
+ # = stop[fx]
+ #
+ # g(x) = grad[fx]
+ # = stop[gx] + grad[stop[fx - x * gx]]
+ # = stop[gx] + 0
+ #
+ # Notice that when x is zero it still works:
+ # grad[x * stop(gx) + stop(fx - x * gx)] = 1 * stop[gx] + 0 = stop[gx]
+ #
+ # The proof is similar for the tensor-domain case, except that we
+ # `reduce_sum` the `stop[gx] * (x - stop[x])` then rescale by
+ # `tf.size(fx)` since this reduced version is broadcast to `fx`.
+ return maybe_stop(fx) + override_grad
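The new docstring and the proof-of-correctness comment both rest on the same identity, `h(x) = stop_gradient(f(x)) + stop_gradient(g(x)) * (x - stop_gradient(x))`: the value of `h` equals `f(x)`, while its gradient is the supplied `g(x)`. A minimal sketch checking both halves of that claim (TF 1.x graph mode assumed; `f`, `g`, and the constants are made up):

```python
import tensorflow as tf

x = tf.constant(3.)
fx = tf.exp(x)           # f(x) = exp(x)
gx = tf.constant(7.)     # pretend custom gradient g(x) = 7

# h(x) = stop[f(x)] + stop[g(x)] * (x - stop[x])
hx = tf.stop_gradient(fx) + tf.stop_gradient(gx) * (x - tf.stop_gradient(x))
dhdx = tf.gradients(hx, x)[0]

with tf.Session() as sess:
  h_val, g_val = sess.run([hx, dhdx])
  # h_val == exp(3.) up to float error: the value matches f(x).
  # g_val == 7.0: the gradient is the supplied g(x), not exp(3.).
```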