aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/bayesflow/python/ops/stochastic_graph.py
blob: 2260b6b0b0a0bde5ec576998f059d35cf27f77da (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Classes and helper functions for Stochastic Computation Graphs.

## Stochastic Computation Graph Classes

@@StochasticTensor
@@DistributionTensor

## Stochastic Computation Value Types

@@MeanValue
@@SampleValue
@@SampleAndReshapeValue
@@value_type
@@get_current_value_type

## Stochastic Computation Graph Helper Functions

@@surrogate_losses
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import collections
import contextlib
import threading

import six

from tensorflow.contrib import distributions
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import tf_logging as logging

# Name of the graph collection that every StochasticTensor adds itself to when
# constructed (see StochasticTensor.__init__); surrogate_losses reads it back
# through _stochastic_dependencies_map.
STOCHASTIC_TENSOR_COLLECTION = "_stochastic_tensor_collection_"


@six.add_metaclass(abc.ABCMeta)
class StochasticTensor(object):
  """Base Class for Tensor-like objects that emit stochastic values.

  Subclasses must implement `name`, `dtype`, `graph`, `input_dict`, `value`
  and `surrogate_loss`.  Every instance adds itself to the current graph's
  `STOCHASTIC_TENSOR_COLLECTION` when constructed.  A tensor conversion
  function is registered for this class with
  `ops.register_tensor_conversion_function`; it converts an instance to the
  `Tensor` returned by `value()`.
  """

  def __init__(self, **kwargs):
    # The keyword arguments are kept as-is; subclasses expose their inputs
    # through `input_dict`.
    self._inputs = kwargs

    # Add self to this graph's Stochastic Tensor collection for
    # purposes of later performing correct surrogate loss calculation.
    ops.add_to_collection(STOCHASTIC_TENSOR_COLLECTION, self)

  @abc.abstractproperty
  def name(self):
    """The name of this `StochasticTensor`."""
    pass

  @abc.abstractproperty
  def dtype(self):
    """The `DType` of the values emitted by this `StochasticTensor`."""
    pass

  @abc.abstractproperty
  def graph(self):
    """The `Graph` this `StochasticTensor` belongs to."""
    pass

  @abc.abstractproperty
  def input_dict(self):
    """The inputs this `StochasticTensor` was constructed from."""
    pass

  @abc.abstractmethod
  def value(self, name=None):
    """Returns the `Tensor` holding this `StochasticTensor`'s value."""
    pass

  @abc.abstractmethod
  def surrogate_loss(self, sample_losses):
    """Returns the surrogate loss given the list of sample_losses.

    This method is called by `surrogate_losses`.  The input `sample_losses`
    presumably have already had `stop_gradient` applied to them.  This is
    because the surrogate_loss usually provides a monte carlo sample term
    of the form `differentiable_surrogate * sum(sample_losses)` where
    `sample_losses` is considered constant with respect to the input
    for purposes of the gradient.

    Args:
      sample_losses: a list of Tensors, the sample losses downstream of this
        `StochasticTensor`.

    Returns:
      Either `None` or a `Tensor` whose gradient is the score function.
    """
    raise NotImplementedError("surrogate_loss not implemented")

  @staticmethod
  def _tensor_conversion_function(v, dtype=None, name=None, as_ref=False):
    """Converts a `StochasticTensor` to a `Tensor` by returning `v.value()`.

    Args:
      v: the `StochasticTensor` to convert.
      dtype: optional requested `DType`; it must be compatible with `v.dtype`.
      name: unused; part of the conversion-function signature.
      as_ref: must be `False`; reference conversion is not supported.

    Returns:
      `v.value()`.

    Raises:
      ValueError: if `dtype` is incompatible with `v.dtype`, or if `as_ref` is
        `True`.
    """
    _ = name
    if dtype and not dtype.is_compatible_with(v.dtype):
      raise ValueError(
          "Incompatible type conversion requested to type '%s' for "
          "StochasticTensor of type '%s'" % (dtype.name, v.dtype.name))
    if as_ref:
      raise ValueError("%s: Ref type is not supported." % v)
    return v.value()


# Allow StochasticTensor instances to be used wherever a Tensor is expected
# (e.g. `tf.identity(dt)`); the conversion returns the instance's `value()`.
# pylint: disable=protected-access
ops.register_tensor_conversion_function(
    StochasticTensor, StochasticTensor._tensor_conversion_function)
# pylint: enable=protected-access


@six.add_metaclass(abc.ABCMeta)
class _StochasticValueType(object):
  """Base class for value types used with the `value_type` context manager.

  A value type decides how a StochasticTensor created inside its context
  produces its value.  Subclasses must implement the `stop_gradient`
  property.  The `ABCMeta` metaclass is what makes `abc.abstractproperty`
  enforceable; without it the decorator has no effect and a subclass that
  forgot `stop_gradient` would silently report `None`.
  """

  def pushed_above(self, unused_value_type):
    """Hook called by `value_type` when a new value type is entered.

    Only invoked on the innermost active value type, just before the new one
    is pushed on top of it.  The default implementation does nothing.
    """
    pass

  def popped_above(self, unused_value_type):
    """Hook called by `value_type` when the value type above this one exits.

    Only invoked on the value type that becomes innermost again.  The default
    implementation does nothing.
    """
    pass

  @abc.abstractproperty
  def stop_gradient(self):
    """Whether the value should be wrapped in stop_gradient.

    StochasticTensors must respect this property.
    """
    pass


class MeanValue(_StochasticValueType):
  """Value type requesting the mean of a StochasticTensor.

  For example, a `DistributionTensor` created in this context uses its
  distribution's `mean()` as its value.
  """

  def __init__(self, stop_gradient=False):
    """Creates a mean value type.

    Args:
      stop_gradient: If `True`, StochasticTensors' values are wrapped in
        `stop_gradient`, to avoid backpropagation through.
    """
    self._stop_gradient = stop_gradient

  @property
  def stop_gradient(self):
    """Whether StochasticTensors should wrap their value in `stop_gradient`."""
    return self._stop_gradient


class SampleValue(_StochasticValueType):
  """Draw n samples along a new outer dimension.

  StochasticTensors created inside a context using this value type draw `n`
  samples, so their value gains one extra, outermost dimension of size `n`.

  Example:

  ```python
  mu = tf.zeros((2, 3))
  sigma = tf.ones((2, 3))
  with sg.value_type(sg.SampleValue(n=4)):
    dt = sg.DistributionTensor(
      distributions.Normal, mu=mu, sigma=sigma)
  # draws 4 samples, each of shape (2, 3), stacked along a new outer dimension
  assertEqual(dt.value().get_shape(), (4, 2, 3))
  ```
  """

  def __init__(self, n=1, stop_gradient=False):
    """Configures drawing `n` samples along a new outer dimension.

    Args:
      n: A python integer or int32 tensor. The number of samples to take.
      stop_gradient: If `True`, StochasticTensors' values are wrapped in
        `stop_gradient`, to avoid backpropagation through.
    """
    self._n, self._stop_gradient = n, stop_gradient

  @property
  def n(self):
    """The number of samples to draw."""
    return self._n

  @property
  def stop_gradient(self):
    """Whether StochasticTensors should wrap their value in `stop_gradient`."""
    return self._stop_gradient


class SampleAndReshapeValue(_StochasticValueType):
  """Ask the StochasticTensor for n samples and reshape the result.

  Sampling from a StochasticTensor increases the rank of the value by 1
  (because each sample represents a new outer dimension).

  This ValueType requests `n` samples from StochasticTensors run within its
  context, then merges the two outermost dimensions so that the samples are
  intermixed with the outermost (usually batch) dimension and the rank is
  unchanged.

  Example:

  ```python
  # mu and sigma are both shaped (2, 3)
  mu = [[0.0, -1.0, 1.0], [0.0, -1.0, 1.0]]
  sigma = tf.constant([[1.1, 1.2, 1.3], [1.1, 1.2, 1.3]])

  with sg.value_type(sg.SampleAndReshapeValue(n=2)):
    dt = sg.DistributionTensor(
        distributions.Normal, mu=mu, sigma=sigma)

  # sample(2) creates a (2, 2, 3) tensor, and the two outermost dimensions
  # are reshaped into one: the final value is a (4, 3) tensor.
  dt_value = dt.value()
  assertEqual(dt_value.get_shape(), (4, 3))

  dt_value_val = sess.run([dt_value])[0]  # or e.g. run([tf.identity(dt)])[0]
  assertEqual(dt_value_val.shape, (4, 3))
  ```
  """

  def __init__(self, n=1, stop_gradient=False):
    """Configures drawing `n` samples and folding them into the outer axis.

    Args:
      n: A python integer or int32 tensor.  The number of samples to take.
      stop_gradient: If `True`, StochasticTensors' values are wrapped in
        `stop_gradient`, to avoid backpropagation through.
    """
    self._n, self._stop_gradient = n, stop_gradient

  @property
  def n(self):
    """The number of samples to draw."""
    return self._n

  @property
  def stop_gradient(self):
    """Whether StochasticTensors should wrap their value in `stop_gradient`."""
    return self._stop_gradient


# Keeps track of how a StochasticTensor's value should be accessed.
# Used by value_type and get_current_value_type below.
# Maps a thread ident (threading.current_thread().ident) to that thread's
# list of active value types; the innermost (most recently entered) is last.
_STOCHASTIC_VALUE_STACK = collections.defaultdict(list)


@contextlib.contextmanager
def value_type(dist_value_type):
  """Creates a value type context for any StochasticTensor created within.

  Typical usage:

  ```
  with sg.value_type(sg.MeanValue(stop_gradient=True)):
    dt = sg.DistributionTensor(distributions.Normal, mu=mu, sigma=sigma)
  ```

  In the example above, `dt.value()` (or equivalently, `tf.identity(dt)`) will
  be the mean value of the Normal distribution, i.e., `mu` (possibly
  broadcasted to the shape of `sigma`).  Furthermore, because the `MeanValue`
  was marked with `stop_gradient=True`, this value will have been wrapped
  in a `stop_gradient` call to disable any possible backpropagation.

  Value types are kept on a per-thread stack.  The stack is restored when the
  context exits, even if the body raises an exception, so a failure inside
  the `with` block does not leave a stale value type active on this thread.

  Args:
    dist_value_type: An instance of `MeanValue`, `SampleAndReshapeValue`, or
      any other stochastic value type.

  Yields:
    A context for `StochasticTensor` objects that controls the
    value created when they are initialized.

  Raises:
    TypeError: if `dist_value_type` is not an instance of a stochastic value
      type.
  """
  if not isinstance(dist_value_type, _StochasticValueType):
    raise TypeError("dist_value_type must be a Distribution Value Type")
  thread_id = threading.current_thread().ident
  stack = _STOCHASTIC_VALUE_STACK[thread_id]
  if stack:
    stack[-1].pushed_above(dist_value_type)
  stack.append(dist_value_type)
  try:
    yield
  finally:
    # Always unwind, even when the body raised; otherwise this value type
    # would stay active for every later StochasticTensor on this thread.
    stack.pop()
    if stack:
      stack[-1].popped_above(dist_value_type)


def get_current_value_type():
  """Returns the innermost value type set via `value_type` on this thread.

  Returns:
    The most recently entered, still active, stochastic value type.

  Raises:
    ValueError: if no `value_type` context is active on the calling thread.
  """
  thread_id = threading.current_thread().ident
  active_stack = _STOCHASTIC_VALUE_STACK[thread_id]
  if active_stack:
    return active_stack[-1]
  raise ValueError(
      "No value type currently set for this thread (%s).  Did you forget to "
      "wrap 'with stochastic_graph.value_type(...)'?" % thread_id)


class DistributionTensor(StochasticTensor):
  """The DistributionTensor is a StochasticTensor backed by a distribution.

  On construction it instantiates `dist_cls(**dist_args)` and builds its value
  `Tensor` according to the value type in effect: `dist_value_type` if given,
  otherwise the innermost `value_type` context on the current thread.
  """

  def __init__(self, dist_cls, name=None, dist_value_type=None, **dist_args):
    """Constructs the distribution and its value `Tensor`.

    Args:
      dist_cls: a distribution class; it is called as `dist_cls(**dist_args)`.
      name: optional name for the op scope (defaults to "DistributionTensor").
      dist_value_type: optional stochastic value type overriding the one set
        by the surrounding `value_type` context.
      **dist_args: keyword arguments forwarded to `dist_cls`.

    Raises:
      ValueError: if `dist_value_type` is `None` and no `value_type` context
        is active on this thread.
      TypeError: if the value type is not a recognized stochastic value type.
    """
    self._dist_cls = dist_cls
    self._dist_args = dist_args
    if dist_value_type is not None:
      # We want to enforce a value type here, but use the value_type()
      # context manager to enforce some error checking.
      with value_type(dist_value_type):
        self._value_type = get_current_value_type()
    else:
      self._value_type = get_current_value_type()

    with ops.op_scope(dist_args.values(), name, "DistributionTensor") as scope:
      self._name = scope
      self._dist = dist_cls(**dist_args)
      self._value = self._create_value()

    # Registers self in STOCHASTIC_TENSOR_COLLECTION.
    super(DistributionTensor, self).__init__()

  @property
  def input_dict(self):
    """The keyword arguments the distribution was constructed with."""
    return self._dist_args

  @property
  def distribution(self):
    """The underlying distribution instance."""
    return self._dist

  def clone(self, name=None, **dist_args):
    """Returns a new DistributionTensor of the same class with new arguments.

    The clone's value type comes from the current `value_type` context, not
    from this tensor.
    """
    return DistributionTensor(self._dist_cls, name=name, **dist_args)

  def _create_value(self):
    """Creates and returns the value Tensor based on the value type.

    `__init__` stores the result as `self._value`.

    Raises:
      TypeError: if the value type is not one of `MeanValue`, `SampleValue`
        or `SampleAndReshapeValue`.
    """

    if isinstance(self._value_type, MeanValue):
      value_tensor = self._dist.mean()
    elif isinstance(self._value_type, SampleValue):
      value_tensor = self._dist.sample(self._value_type.n)
    elif isinstance(self._value_type, SampleAndReshapeValue):
      if self._value_type.n == 1:
        value_tensor = array_ops.squeeze(self._dist.sample(1), [0])
      else:
        samples = self._dist.sample(self._value_type.n)
        samples_shape = array_ops.shape(samples)
        samples_static_shape = samples.get_shape()
        # Fold the sample dimension into the outermost (batch) dimension.
        new_batch_size = samples_shape[0] * samples_shape[1]
        value_tensor = array_ops.reshape(
            samples, array_ops.concat(0, ([new_batch_size], samples_shape[2:])))
        if samples_static_shape.ndims is not None:
          # Update the static shape for shape inference purposes
          shape_list = samples_static_shape.as_list()
          new_shape = tensor_shape.vector(
              shape_list[0] * shape_list[1]
              if shape_list[0] is not None and shape_list[1] is not None
              else None)
          new_shape = new_shape.concatenate(samples_static_shape[2:])
          value_tensor.set_shape(new_shape)
    else:
      raise TypeError(
          "Unrecognized Distribution Value Type: %s" % (self._value_type,))

    stop_gradient = self._value_type.stop_gradient

    if stop_gradient:
      # stop_gradient is being enforced by the value type
      return array_ops.stop_gradient(value_tensor)

    if isinstance(self._value_type, MeanValue):
      return value_tensor  # Using pathwise-derivative for this one.
    if (isinstance(self._dist, distributions.ContinuousDistribution)
        and self._dist.is_reparameterized):
      return value_tensor  # Using pathwise-derivative for this one.
    else:
      # Will have to perform some variant of score function
      # estimation.  Call stop_gradient on the sampler just in case we
      # may accidentally leak some gradient from it.
      return array_ops.stop_gradient(value_tensor)

  @property
  def name(self):
    """The op scope name this tensor's ops were created under."""
    return self._name

  @property
  def graph(self):
    """The graph of the value `Tensor`."""
    return self._value.graph

  @property
  def dtype(self):
    """The dtype of the underlying distribution."""
    return self._dist.dtype

  def entropy(self, name="entropy"):
    """Returns the entropy of the underlying distribution."""
    return self._dist.entropy(name=name)

  def mean(self, name="mean"):
    """Returns the mean of the underlying distribution."""
    return self._dist.mean(name=name)

  def value(self, name="value"):
    """Returns the value `Tensor` built at construction; `name` is unused."""
    return self._value

  def surrogate_loss(self, losses, name=None):
    """Returns a score-function surrogate loss for this tensor, or `None`.

    Args:
      losses: a list of loss Tensors downstream of this tensor (presumably
        already wrapped in `stop_gradient` by `surrogate_losses`).
      name: optional name for the op scope.

    Returns:
      `None` for reparameterized continuous distributions and for
      `MeanValue`, whose gradients flow through the value directly; otherwise
      `log_likelihood(value) * sum(losses)`.

    Raises:
      TypeError: if the value type is not recognized.
    """
    # Return a loss term based on losses and the distribution.  Return
    # None if pathwise derivatives are supported
    if (isinstance(self._dist, distributions.ContinuousDistribution)
        and self._dist.is_reparameterized):
      # Can perform pathwise-derivative on this one; no surrogate loss needed.
      return None

    with ops.op_scope(losses, name, "DistributionSurrogateLoss"):
      if isinstance(self._value_type, (SampleAndReshapeValue, SampleValue)):
        # TODO(ebrevdo): use add_n instead of sum(losses) if shapes all match?
        return self._dist.log_likelihood(self._value) * sum(losses)
      elif isinstance(self._value_type, MeanValue):
        return None  # MeanValue generally provides its own gradient
      else:
        raise TypeError(
            "Unrecognized Distribution Value Type: %s" % (self._value_type,))


def _stochastic_dependencies_map(fixed_losses):
  """Map stochastic tensors to the fixed losses that depend on them.

  Args:
    fixed_losses: a list of Tensors.

  Returns:
    A dict `dependencies` that maps `StochasticTensor` objects to subsets of
    `fixed_losses`.

    If `loss in dependencies[st]`, for some `loss` in `fixed_losses` then there
    is a path from `st.value()` to `loss` in the graph (following op data
    inputs).  StochasticTensors that no loss depends on are absent.  Returns
    an empty dict if the graph's stochastic tensor collection is empty.
  """
  stoch_value_collection = ops.get_collection(
      STOCHASTIC_TENSOR_COLLECTION)

  if not stoch_value_collection:
    return {}

  stoch_value_map = {node.value(): node for node in stoch_value_collection}

  # Step backwards through the graph to see which surrogate losses correspond
  # to which fixed_losses.
  stoch_dependencies_map = collections.defaultdict(set)
  for loss in fixed_losses:
    # Track visited tensors so a shared subgraph is walked once per loss
    # rather than once per path reaching it, which can blow up
    # combinatorially on graphs with many reconverging branches.
    visited = set()
    boundary = [loss]
    while boundary:
      edge = boundary.pop()
      if edge in visited:
        continue
      visited.add(edge)
      edge_stoch_node = stoch_value_map.get(edge)
      if edge_stoch_node is not None:
        stoch_dependencies_map[edge_stoch_node].add(loss)
      boundary.extend(edge.op.inputs)

  return stoch_dependencies_map


def surrogate_losses(sample_losses, name=None):
  """Builds the surrogate loss terms for the given sample losses.

  Each loss is wrapped in `stop_gradient`.  Then, for every StochasticTensor
  in the graph's `STOCHASTIC_TENSOR_COLLECTION` that at least one loss depends
  on, its `surrogate_loss` is called with the (stopped) losses that depend on
  it.

  Args:
    sample_losses: a list or tuple of scalar or vector (rank <= 1) Tensors.
    name: optional name for the op scope (defaults to "SampleLosses").

  Returns:
    A list of surrogate loss Tensors, one per StochasticTensor whose
    `surrogate_loss` returned a non-`None` value, ordered by StochasticTensor
    name.  Empty if no StochasticTensor is upstream of any loss.

  Raises:
    TypeError: if `sample_losses` is not a list or tuple, or contains an
      element that is not a `Tensor`.
    ValueError: if a loss has unknown rank or rank greater than 1.
  """
  # Validate before entering op_scope so callers always get this explicit
  # TypeError, regardless of how op_scope treats a non-sequence argument.
  if not isinstance(sample_losses, (list, tuple)):
    raise TypeError("sample_losses must be a list or tuple")
  with ops.op_scope(sample_losses, name, "SampleLosses"):
    fixed_losses = []
    for loss in sample_losses:
      if not isinstance(loss, ops.Tensor):
        raise TypeError("loss is not a Tensor: %s" % loss)
      ndims = loss.get_shape().ndims
      if not (ndims is not None and ndims <= 1):
        raise ValueError(
            "loss must be a scalar or batch-length vector loss: %s" % loss)
      fixed_losses.append(array_ops.stop_gradient(loss))

    stoch_dependencies_map = _stochastic_dependencies_map(fixed_losses)
    if not stoch_dependencies_map:
      logging.warn(
          "No collection of Stochastic Tensors found for current graph.")
      return []

    surrogate_loss_losses = []

    # Iterate through all of the stochastic dependencies, adding surrogate
    # terms where necessary.  Visit StochasticTensors by name and pass their
    # dependent losses in input order, so the constructed graph (including
    # the order of the `sum` over losses) does not depend on the hash order
    # of the underlying dict and sets.
    for stoch_node in sorted(stoch_dependencies_map,
                             key=lambda node: str(node.name)):
      dependent_losses = stoch_dependencies_map[stoch_node]
      ordered_losses = [
          loss for loss in fixed_losses if loss in dependent_losses]
      surrogate_loss = stoch_node.surrogate_loss(ordered_losses)
      if surrogate_loss is not None:
        with ops.name_scope("SurrogateLoss_%s" % stoch_node.name):
          surrogate_loss_losses.append(array_ops.identity(surrogate_loss))

    return surrogate_loss_losses