aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/bayesflow/python/ops/stochastic_gradient_estimators.py
blob: 7cb8ef06f936a11940f24d5b5f78bd7a3cdd01a2 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Stochastic gradient estimators.

These functions are meant to be used in conjunction with `StochasticTensor`
(`loss_fn` parameter) and `surrogate_loss`.

See Gradient Estimation Using Stochastic Computation Graphs
(http://arxiv.org/abs/1506.05254) by Schulman et al., eq. 1 and section 4, for
mathematical details.

## Score function estimator

The score function is an unbiased estimator of the gradient of `E_p(x)[f(x)]`,
where `f(x)` can be considered to be a "loss" term. It is computed as
`E_p(x)[f(x) grad(log p(x))]`. A constant `b`, referred to here as the
"baseline", can be subtracted from `f(x)` without affecting the expectation. The
term `(f(x) - b)` is referred to here as the "advantage".

Note that the methods defined in this module actually compute the integrand of
the score function, such that when taking the gradient, the true score function
is computed.

@@score_function
@@get_score_function_with_baseline
@@get_score_function_with_constant_baseline
@@get_score_function_with_advantage

## Baseline functions

Baselines reduce the variance of a Monte Carlo estimate of an expectation. The
baseline for a stochastic node can be a function of all non-influenced nodes
(see section 4 of Schulman et al., linked above). Baselines are also known as
"control variates."

In the context of a MC estimate of `E_p(x)[f(x) - b]`, baseline functions have
the signature `(st, fx) => Tensor`, where `st` is a `StochasticTensor` backed by
the distribution `p(x)` and `fx` is the influenced loss.

@@get_mean_baseline

"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.training import training
from tensorflow.python.util.all_util import make_all


def score_function(dist_tensor, value, loss, baseline=None,
                   name="ScoreFunction"):
  """Integrand of the score function estimator, with an optional baseline.

  Builds `p.log_prob(value) * stop_gradient(loss - baseline)`. When no
  `baseline` is given the advantage is simply `loss`. Gradients are blocked
  through the advantage term, so they only flow through the `log_prob` factor.

  Args:
    dist_tensor: `DistributionTensor` p(x); its `distribution.log_prob` is
      evaluated at `value`.
    value: `Tensor` x. Samples from p(x).
    loss: `Tensor`.
    baseline: optional `Tensor` broadcastable to `loss`.
    name: name to prepend ops with.

  Returns:
    `Tensor` `p.log_prob(x) * (loss - b)`. Differentiating it yields the score
    function estimator.
  """
  with ops.name_scope(name, values=[value, loss, baseline]):
    value = ops.convert_to_tensor(value)
    loss = ops.convert_to_tensor(loss)

    # Without a baseline the advantage is just the loss.
    advantage = (
        loss if baseline is None else loss - ops.convert_to_tensor(baseline))
    # The advantage acts as a constant weight on the log-probability.
    advantage = array_ops.stop_gradient(advantage)

    log_prob = dist_tensor.distribution.log_prob(value)
    return log_prob * advantage


def get_score_function_with_advantage(advantage_fn=None,
                                      name="ScoreFunctionWithAdvantage"):
  """Score function estimator with advantage function.

  Args:
    advantage_fn: callable that takes the `DistributionTensor` and the
      downstream `loss` and returns a `Tensor` advantage
      (e.g. `loss - baseline`). Required: the `None` default has no fallback.
    name: name to prepend ops with.

  Returns:
    Callable score function estimator that takes the `DistributionTensor`, the
    sampled `value`, and the downstream `loss`, and uses the provided advantage.
    The advantage is wrapped in `stop_gradient` before being multiplied with
    `log_prob(value)`.

  Raises:
    TypeError: if `advantage_fn` is not callable (including the default `None`).
  """
  # There is no default advantage to fall back on, so reject a missing or
  # non-callable `advantage_fn` here instead of failing with an opaque
  # "'NoneType' object is not callable" when the estimator is first invoked.
  if not callable(advantage_fn):
    raise TypeError(
        "advantage_fn must be a callable taking (dist_tensor, loss); got %r"
        % (advantage_fn,))

  def score_function_with_advantage(dist_tensor, value, loss):
    with ops.name_scope(name, values=[value, loss]):
      advantage = advantage_fn(dist_tensor, loss)
      # Treat the advantage as a constant weight: gradients only flow through
      # the log-probability term.
      advantage = array_ops.stop_gradient(advantage)
      return dist_tensor.distribution.log_prob(value) * advantage

  return score_function_with_advantage


def get_score_function_with_constant_baseline(baseline, name="ScoreFunction"):
  """Build a score function estimator that uses a fixed baseline.

  Args:
    baseline: `Tensor` subtracted from the `loss` on every call of the
      returned estimator.
    name: name to prepend ops with; forwarded to `score_function`.

  Returns:
    Callable taking the `DistributionTensor`, the sampled `value`, and the
    downstream `loss`. It delegates to `score_function` with the `baseline`
    and `name` captured here.
  """

  def score_function_with_constant_baseline(dist_tensor, value, loss):
    # `baseline` and `name` are closed over, so every call reuses the same ones.
    return score_function(
        dist_tensor, value, loss, baseline=baseline, name=name)

  return score_function_with_constant_baseline


def get_score_function_with_baseline(baseline_fn=None, name="ScoreFunction"):
  """Build a score function estimator whose baseline comes from a callable.

  Args:
    baseline_fn: callable that receives the `DistributionTensor` and the
      downstream `loss` and returns the `Tensor` baseline to subtract from the
      `loss`. When None, `get_mean_baseline()` is used.
    name: name to prepend ops with.

  Returns:
    Callable taking the `DistributionTensor`, the sampled `value`, and the
    downstream `loss`. It computes the baseline with `baseline_fn` and passes
    it to `score_function`.
  """
  # Resolve the default once, at construction time, so every call of the
  # returned estimator goes through the same baseline callable.
  baseline_fn = get_mean_baseline() if baseline_fn is None else baseline_fn

  def score_function_with_baseline(dist_tensor, value, loss):
    with ops.name_scope(name):
      return score_function(
          dist_tensor, value, loss, baseline_fn(dist_tensor, loss))

  return score_function_with_baseline


def get_mean_baseline(ema_decay=0.99, name="MeanBaseline"):
  """ExponentialMovingAverage baseline.

  Args:
    ema_decay: decay rate for the ExponentialMovingAverage.
    name: name to prepend ops with.

  Returns:
    Callable baseline function that takes the `DistributionTensor` (unused) and
    the downstream `loss`, and returns a `Tensor` holding an EMA of the mean of
    the loss. Evaluating that `Tensor` also runs the EMA update.
  """

  def mean_baseline(_, loss):
    with ops.name_scope(name):
      ema = training.ExponentialMovingAverage(decay=ema_decay)
      # The average is tracked over the scalar mean of the loss.
      reduced_loss = math_ops.reduce_mean(loss)
      update_op = ema.apply([reduced_loss])
      with ops.control_dependencies([update_op]):
        # Control dependencies only attach to ops created inside this context.
        # `ema.average` returns the shadow variable that `ema.apply` already
        # created, so on its own it would never trigger `update_op`. The
        # `identity` adds an op here, which makes reading the baseline run the
        # moving-average update.
        # TODO(rsepassi): Possibly implement the initialization bias correction
        # term from Adam (section 3 of https://arxiv.org/pdf/1412.6980v8.pdf).
        baseline = array_ops.identity(ema.average(reduced_loss))
      return baseline

  return mean_baseline


# Public API of this module, built by `make_all` from the module name.
# NOTE(review): presumably `make_all` collects the `@@name` entries listed in
# the module docstring above -- confirm against `all_util.make_all` before
# adding or renaming public functions.
__all__ = make_all(__name__)