aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/distributions/distribution.py
blob: c03ef967e68474b0313de01d48252c8274e37a21 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base classes for probability distributions."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import contextlib
import types

import numpy as np
import six

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import util
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export


# Public names of this module (what `from ... import *` exposes).
__all__ = [
    "ReparameterizationType",
    "FULLY_REPARAMETERIZED",
    "NOT_REPARAMETERIZED",
    "Distribution",
]

# Public `Distribution` members whose docstrings `_DistributionMeta` may
# extend when a subclass is created: for each name `foo` here that the
# subclass does not override itself, if the subclass defines a documented
# `_foo`, the inherited `foo` is copied into the subclass and `_foo`'s
# docstring is appended to the copy's docstring.
_DISTRIBUTION_PUBLIC_METHOD_WRAPPERS = [
    "batch_shape",
    "batch_shape_tensor",
    "cdf",
    "covariance",
    "cross_entropy",
    "entropy",
    "event_shape",
    "event_shape_tensor",
    "kl_divergence",
    "log_cdf",
    "log_prob",
    "log_survival_function",
    "mean",
    "mode",
    "prob",
    "sample",
    "stddev",
    "survival_function",
    "variance",
]


@six.add_metaclass(abc.ABCMeta)
class _BaseDistribution(object):
  """Abstract root of the `Distribution` class hierarchy.

  `Distribution` derives directly from this class. `_DistributionMeta` uses
  that fact to recognize the root class and skip docstring propagation when
  the selected parent is `_BaseDistribution`. It defines no members.
  """


def _copy_fn(fn):
  """Create a deep copy of fn.

  Args:
    fn: a callable

  Returns:
    A `FunctionType`: a deep copy of fn.

  Raises:
    TypeError: if `fn` is not a callable.
  """
  if not callable(fn):
    raise TypeError("fn is not callable: %s" % fn)
  # The blessed way to copy a function. copy.deepcopy fails to create a
  # non-reference copy. Since:
  #   types.FunctionType == type(lambda: None),
  # and the docstring for the function type states:
  #
  #   function(code, globals[, name[, argdefs[, closure]]])
  #
  #   Create a function object from a code object and a dictionary.
  #   ...
  #
  # Here we can use this to create a new function with the old function's
  # code, globals, closure, etc.
  return types.FunctionType(
      code=fn.__code__, globals=fn.__globals__,
      name=fn.__name__, argdefs=fn.__defaults__,
      closure=fn.__closure__)


def _update_docstring(old_str, append_str):
  """Update old_str by inserting append_str just before the "Args:" section."""
  old_str = old_str or ""
  old_str_lines = old_str.split("\n")

  # Step 0: Prepend spaces to all lines of append_str. This is
  # necessary for correct markdown generation.
  append_str = "\n".join("    %s" % line for line in append_str.split("\n"))

  # Step 1: Find mention of "Args":
  has_args_ix = [
      ix for ix, line in enumerate(old_str_lines)
      if line.strip().lower() == "args:"]
  if has_args_ix:
    final_args_ix = has_args_ix[-1]
    return ("\n".join(old_str_lines[:final_args_ix])
            + "\n\n" + append_str + "\n\n"
            + "\n".join(old_str_lines[final_args_ix:]))
  else:
    return old_str + "\n\n" + append_str


class _DistributionMeta(abc.ABCMeta):
  """Metaclass that extends public `Distribution` docstrings per subclass."""

  def __new__(mcs, classname, baseclasses, attrs):
    """Control the creation of subclasses of the Distribution class.

    The main purpose of this method is to properly propagate docstrings
    from private Distribution methods, like `_log_prob`, into their
    public wrappers as inherited by the Distribution base class
    (e.g. `log_prob`).

    For each name in `_DISTRIBUTION_PUBLIC_METHOD_WRAPPERS` that the new class
    does not define itself, if the new class defines a documented
    leading-underscore counterpart (e.g. `_log_prob`), the inherited wrapper
    is copied into the new class. The counterpart's docstring is inserted
    before the copy's "Args:" section. Wrappers that are properties on the
    base class (e.g. `batch_shape`) are re-created as properties around a
    copy of the base getter.

    Args:
      classname: The name of the subclass being created.
      baseclasses: A tuple of parent classes.
      attrs: A dict mapping new attributes to their values. Updated in place
        with the copied wrappers.

    Returns:
      The class object.

    Raises:
      TypeError: If `baseclasses` is empty (e.g. `Distribution` does not
        subclass `_BaseDistribution`), or none of the parent classes is
        `_BaseDistribution` or a subclass of `Distribution`.
      AttributeError:  If `Distribution` does not implement e.g. `log_prob`.
      ValueError:  If a `Distribution` public method lacks a docstring.
    """
    if not baseclasses:  # Nothing to be done for Distribution
      raise TypeError("Expected non-empty baseclass. Does Distribution "
                      "not subclass _BaseDistribution?")
    which_base = [
        base for base in baseclasses
        if base == _BaseDistribution or issubclass(base, Distribution)]
    if not which_base:
      # Report a descriptive TypeError instead of letting `which_base[0]`
      # fail with an opaque IndexError.
      raise TypeError("First parent class declared for %s must be "
                      "Distribution, but saw '%s'"
                      % (classname, baseclasses[0].__name__))
    base = which_base[0]
    if base == _BaseDistribution:  # Nothing to be done for Distribution
      return abc.ABCMeta.__new__(mcs, classname, baseclasses, attrs)
    for attr in _DISTRIBUTION_PUBLIC_METHOD_WRAPPERS:
      if attr in attrs:
        # The method is being overridden, do not update its docstring.
        continue
      base_attr_value = getattr(base, attr, None)
      if not base_attr_value:
        raise AttributeError(
            "Internal error: expected base class '%s' to implement method '%s'"
            % (base.__name__, attr))
      class_special_attr_value = attrs.get("_%s" % attr, None)
      if class_special_attr_value is None:
        # No _special method available, no need to update the docstring.
        continue
      class_special_attr_docstring = tf_inspect.getdoc(class_special_attr_value)
      if not class_special_attr_docstring:
        # No docstring to append.
        continue
      # Some wrappers (`batch_shape`, `event_shape`) are properties, and
      # `_copy_fn` rejects non-callables. Copy the getter instead and rebuild
      # the property around it below.
      is_property = isinstance(base_attr_value, property)
      class_attr_value = _copy_fn(
          base_attr_value.fget if is_property else base_attr_value)
      class_attr_docstring = tf_inspect.getdoc(base_attr_value)
      if class_attr_docstring is None:
        raise ValueError(
            "Expected base class fn to contain a docstring: %s.%s"
            % (base.__name__, attr))
      class_attr_value.__doc__ = _update_docstring(
          class_attr_value.__doc__,
          ("Additional documentation from `%s`:\n\n%s"
           % (classname, class_special_attr_docstring)))
      if is_property:
        class_attr_value = property(
            class_attr_value, base_attr_value.fset, base_attr_value.fdel,
            class_attr_value.__doc__)
      attrs[attr] = class_attr_value

    return abc.ABCMeta.__new__(mcs, classname, baseclasses, attrs)


@tf_export("distributions.ReparameterizationType")
class ReparameterizationType(object):
  """Instances of this class represent how sampling is reparameterized.

  Two static instances exist in the distributions library, signifying
  one of two possible properties for samples from a distribution:

  `FULLY_REPARAMETERIZED`: Samples from the distribution are fully
    reparameterized, and straight-through gradients are supported.

  `NOT_REPARAMETERIZED`: Samples from the distribution are not fully
    reparameterized, and straight-through gradients are either partially
    unsupported or are not supported at all. In this case, for purposes of
    e.g. RL or variational inference, it is generally safest to wrap the
    sample results in a `stop_gradient` call and use policy
    gradients / surrogate loss instead.

  Instances compare equal only to themselves and hash by identity, so they
  can be used as `dict` keys and `set` members.
  """

  def __init__(self, rep_type):
    """Creates a reparameterization type.

    Args:
      rep_type: Label (e.g. "FULLY_REPARAMETERIZED") shown by `repr()`.
    """
    self._rep_type = rep_type

  def __repr__(self):
    return "<Reparameterization Type: %s>" % self._rep_type

  def __eq__(self, other):
    """Determine if this `ReparameterizationType` is equal to another.

    Since ReparameterizationType instances are constant static global
    instances, equality checks if two instances' id() values are equal.

    Args:
      other: Object to compare against.

    Returns:
      `self is other`.
    """
    return self is other

  def __ne__(self, other):
    # Python 2 does not derive `!=` from `__eq__`; keep both identity-based.
    return self is not other

  # On Python 3, defining `__eq__` implicitly sets `__hash__` to None, which
  # would make the constants unhashable. Equality is identity, so identity
  # hashing is consistent with it.
  __hash__ = object.__hash__


# Fully reparameterized distribution: samples from a fully
# reparameterized distribution support straight-through gradients with
# respect to all parameters.
FULLY_REPARAMETERIZED = ReparameterizationType("FULLY_REPARAMETERIZED")
# Presumably registers this module-level constant in the exported API as
# `distributions.FULLY_REPARAMETERIZED` (see `tf_export`); must run after the
# assignment above.
tf_export("distributions.FULLY_REPARAMETERIZED").export_constant(
    __name__, "FULLY_REPARAMETERIZED")


# Not reparameterized distribution: samples from a non-
# reparameterized distribution do not support straight-through gradients for
# at least some of the parameters.
NOT_REPARAMETERIZED = ReparameterizationType("NOT_REPARAMETERIZED")
# Presumably registers this module-level constant in the exported API as
# `distributions.NOT_REPARAMETERIZED` (see `tf_export`); must run after the
# assignment above.
tf_export("distributions.NOT_REPARAMETERIZED").export_constant(
    __name__, "NOT_REPARAMETERIZED")


@six.add_metaclass(_DistributionMeta)
@tf_export("distributions.Distribution")
class Distribution(_BaseDistribution):
  """A generic probability distribution base class.

  `Distribution` is a base class for constructing and organizing properties
  (e.g., mean, variance) of random variables (e.g, Bernoulli, Gaussian).

  #### Subclassing

  Subclasses are expected to implement a leading-underscore version of the
  same-named function. The argument signature should be identical except for
  the omission of `name="..."`. For example, to enable `log_prob(value,
  name="log_prob")` a subclass should implement `_log_prob(value)`.

  Subclasses can append to public-level docstrings by providing
  docstrings for their method specializations. For example:

  ```python
  @util.AppendDocstring("Some other details.")
  def _log_prob(self, value):
    ...
  ```

  would add the string "Some other details." to the `log_prob` function
  docstring. This is implemented as a simple decorator to avoid python
  linter complaining about missing Args/Returns/Raises sections in the
  partial docstrings.

  #### Broadcasting, batching, and shapes

  All distributions support batches of independent distributions of that type.
  The batch shape is determined by broadcasting together the parameters.

  The shape of arguments to `__init__`, `cdf`, `log_cdf`, `prob`, and
  `log_prob` reflect this broadcasting, as does the return value of `sample` and
  `sample_n`.

  `sample_n_shape = [n] + batch_shape + event_shape`, where `sample_n_shape` is
  the shape of the `Tensor` returned from `sample_n`, `n` is the number of
  samples, `batch_shape` defines how many independent distributions there are,
  and `event_shape` defines the shape of samples from each of those independent
  distributions. Samples are independent along the `batch_shape` dimensions, but
  not necessarily so along the `event_shape` dimensions (depending on the
  particulars of the underlying distribution).

  Using the `Uniform` distribution as an example:

  ```python
  minval = 3.0
  maxval = [[4.0, 6.0],
            [10.0, 12.0]]

  # Broadcasting:
  # This instance represents 4 Uniform distributions. Each has a lower bound at
  # 3.0 as the `minval` parameter was broadcasted to match `maxval`'s shape.
  u = Uniform(minval, maxval)

  # `event_shape` is `TensorShape([])`.
  event_shape = u.event_shape
  # `event_shape_t` is a `Tensor` which will evaluate to [].
  event_shape_t = u.event_shape_tensor()

  # Sampling returns a sample per distribution. `samples` has shape
  # [5, 2, 2], which is [n] + batch_shape + event_shape, where n=5,
  # batch_shape=[2, 2], and event_shape=[].
  samples = u.sample_n(5)

  # The broadcasting holds across methods. Here we use `cdf` as an example. The
  # same holds for `log_cdf` and the likelihood functions.

  # `cum_prob` has shape [2, 2] as the `value` argument was broadcasted to the
  # shape of the `Uniform` instance.
  cum_prob_broadcast = u.cdf(4.0)

  # `cum_prob`'s shape is [2, 2], one per distribution. No broadcasting
  # occurred.
  cum_prob_per_dist = u.cdf([[4.0, 5.0],
                             [6.0, 7.0]])

  # INVALID as the `value` argument is not broadcastable to the distribution's
  # shape.
  cum_prob_invalid = u.cdf([4.0, 5.0, 6.0])
  ```

  #### Shapes

  There are three important concepts associated with TensorFlow Distributions
  shapes:
  - Event shape describes the shape of a single draw from the distribution;
    it may be dependent across dimensions. For scalar distributions, the event
    shape is `[]`. For a 5-dimensional MultivariateNormal, the event shape is
    `[5]`.
  - Batch shape describes independent, not identically distributed draws, aka a
    "collection" or "bunch" of distributions.
  - Sample shape describes independent, identically distributed draws of batches
    from the distribution family.

  The event shape and the batch shape are properties of a Distribution object,
  whereas the sample shape is associated with a specific call to `sample` or
  `log_prob`.

  For detailed usage examples of TensorFlow Distributions shapes, see
  [this tutorial](
  https://github.com/tensorflow/probability/blob/master/tensorflow_probability/examples/jupyter_notebooks/Understanding_TensorFlow_Distributions_Shapes.ipynb)

  #### Parameter values leading to undefined statistics or distributions.

  Some distributions do not have well-defined statistics for all initialization
  parameter values. For example, the beta distribution is parameterized by
  positive real numbers `concentration1` and `concentration0`, and does not have
  well-defined mode if `concentration1 < 1` or `concentration0 < 1`.

  The user is given the option of raising an exception or returning `NaN`.

  ```python
  a = tf.exp(tf.matmul(logits, weights_a))
  b = tf.exp(tf.matmul(logits, weights_b))

  # Will raise exception if ANY batch member has a < 1 or b < 1.
  dist = distributions.beta(a, b, allow_nan_stats=False)
  mode = dist.mode().eval()

  # Will return NaN for batch members with either a < 1 or b < 1.
  dist = distributions.beta(a, b, allow_nan_stats=True)  # Default behavior
  mode = dist.mode().eval()
  ```

  In all cases, an exception is raised if *invalid* parameters are passed, e.g.

  ```python
  # Will raise an exception if any Op is run.
  negative_a = -1.0 * a  # beta distribution by definition has a > 0.
  dist = distributions.beta(negative_a, b, allow_nan_stats=True)
  dist.mean().eval()
  ```

  """

  def __init__(self,
               dtype,
               reparameterization_type,
               validate_args,
               allow_nan_stats,
               parameters=None,
               graph_parents=None,
               name=None):
    """Initializes the state shared by every `Distribution`.

    **This is a private method for subclass use.**

    Args:
      dtype: The type of the event samples. `None` implies no type-enforcement.
      reparameterization_type: Instance of `ReparameterizationType`.
        With `distributions.FULLY_REPARAMETERIZED`, this `Distribution` can be
        reparameterized in terms of some standard distribution with a function
        whose Jacobian is constant for the support of the standard
        distribution. With `distributions.NOT_REPARAMETERIZED`, no such
        reparameterization is available.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      parameters: Python `dict` of parameters used to instantiate this
        `Distribution`. A falsy value is stored as an empty `dict`.
      graph_parents: Python `list` of graph prerequisites of this
        `Distribution`. `None` is treated as an empty list.
      name: Python `str` name prefixed to Ops created by this class. Default:
        subclass name. A name ending in "/" is kept verbatim. Any other name
        is replaced by the scope string produced by `ops.name_scope`.

    Raises:
      ValueError: if any member of graph_parents is `None` or not a `Tensor`.
    """
    parents = [] if graph_parents is None else graph_parents
    for position, parent in enumerate(parents):
      if parent is None or not tensor_util.is_tensor(parent):
        raise ValueError(
            "Graph parent item %d is not a Tensor; %s." % (position, parent))

    already_scoped = bool(name) and name[-1] == "/"
    if not already_scoped:
      # Enter a name scope only to capture the scope string it yields; that
      # string becomes this distribution's name.
      with ops.name_scope(name or type(self).__name__) as name:
        pass

    self._name = name
    self._dtype = dtype
    self._parameters = parameters or {}
    self._graph_parents = parents
    self._reparameterization_type = reparameterization_type
    self._validate_args = validate_args
    self._allow_nan_stats = allow_nan_stats

  @classmethod
  def param_shapes(cls, sample_shape, name="DistributionParamShapes"):
    """Shapes of parameters given the desired shape of a call to `sample()`.

    This class method reports which key/value arguments instantiate this
    `Distribution` so that its `sample()` returns the requested shape.
    Subclasses customize the result by overriding class method
    `_param_shapes`. This wrapper only opens a name scope around that call.

    Args:
      sample_shape: `Tensor` or python list/tuple. Desired shape of a call to
        `sample()`.
      name: name to prepend ops with.

    Returns:
      `dict` of parameter name to `Tensor` shapes.
    """
    scope = ops.name_scope(name, values=[sample_shape])
    with scope:
      return cls._param_shapes(sample_shape)

  @classmethod
  def param_static_shapes(cls, sample_shape):
    """`param_shapes` with static (i.e. `TensorShape`) shapes.

    This class method reports which key/value arguments instantiate this
    `Distribution` so that its `sample()` returns the requested shape. It
    assumes the sample's shape is known statically.

    Subclasses should override class method `_param_shapes` to return
    constant-valued tensors when constant values are fed.

    Args:
      sample_shape: `TensorShape` or python list/tuple. Desired shape of a call
        to `sample()`.

    Returns:
      `dict` of parameter name to `TensorShape`.

    Raises:
      ValueError: if `sample_shape` is a `TensorShape` and is not fully defined,
        or if any shape returned by `param_shapes` is not a constant.
    """
    if isinstance(sample_shape, tensor_shape.TensorShape):
      if not sample_shape.is_fully_defined():
        raise ValueError("TensorShape sample_shape must be fully defined")
      sample_shape = sample_shape.as_list()

    def _to_static(shape_tensor):
      # Constant-fold one parameter shape; non-constant shapes are rejected.
      folded = tensor_util.constant_value(shape_tensor)
      if folded is None:
        raise ValueError(
            "sample_shape must be a fully-defined TensorShape or list/tuple")
      return tensor_shape.TensorShape(folded)

    dynamic_shapes = cls.param_shapes(sample_shape)
    return {param_name: _to_static(shape_tensor)
            for param_name, shape_tensor in dynamic_shapes.items()}

  @staticmethod
  def _param_shapes(sample_shape):
    """Subclass hook behind `param_shapes`; the base class provides none.

    Raises:
      NotImplementedError: always, in the base class.
    """
    del sample_shape  # Unused by the base implementation.
    raise NotImplementedError("_param_shapes not implemented")

  @property
  def name(self):
    """Name prepended to all ops created by this `Distribution`.

    This is the name (or name-scope string) recorded at construction time.
    """
    return self._name

  @property
  def dtype(self):
    """The `DType` of `Tensor`s handled by this `Distribution`.

    `None` means no type-enforcement, as passed to the constructor.
    """
    return self._dtype

  @property
  def parameters(self):
    """Dictionary of parameters used to instantiate this `Distribution`.

    A new `dict` is built on every access, so mutating it does not affect
    the distribution.
    """
    # Subclasses commonly pass `parameters = dict(locals())`, which captures
    # "self" and dunder entries such as "__class__"; drop those.
    return {key: value for key, value in self._parameters.items()
            if key != "self" and not key.startswith("__")}

  @property
  def reparameterization_type(self):
    """How samples from this distribution are reparameterized.

    Currently the value is one of the static instances
    `distributions.FULLY_REPARAMETERIZED` or
    `distributions.NOT_REPARAMETERIZED`, as supplied to the constructor.

    Returns:
      An instance of `ReparameterizationType`.
    """
    return self._reparameterization_type

  @property
  def allow_nan_stats(self):
    """Python `bool` describing behavior when a stat is undefined.

    Stats return +/- infinity when that is meaningful; for example, the
    variance of a Cauchy distribution is infinity. Some statistics are
    genuinely undefined, though. If a distribution's pdf does not reach a
    maximum inside its support, the mode is undefined. If the mean is
    undefined, the variance is too. For instance, the mean of Student's T
    with df = 1 cannot be called either + or - infinity, so the variance
    E[(X - mean)**2] is undefined as well.

    Returns:
      allow_nan_stats: Python `bool`.
    """
    return self._allow_nan_stats

  @property
  def validate_args(self):
    """Python `bool`; `True` enables possibly expensive validity checks."""
    return self._validate_args

  def copy(self, **override_parameters_kwargs):
    """Creates a deep copy of the distribution.

    Note: the copy distribution may continue to depend on the original
    initialization arguments.

    Args:
      **override_parameters_kwargs: String/value dictionary of initialization
        arguments to override with new values.

    Returns:
      distribution: A new instance of `type(self)` built from
        `self.parameters`, with entries from `override_parameters_kwargs`
        taking precedence (i.e.
        `dict(self.parameters, **override_parameters_kwargs)`).
    """
    merged = dict(self.parameters)
    merged.update(override_parameters_kwargs)
    return type(self)(**merged)

  def _batch_shape_tensor(self):
    """Subclass hook used by `batch_shape_tensor` for non-static shapes.

    Raises:
      NotImplementedError: always, in the base class.
    """
    raise NotImplementedError("batch_shape_tensor is not implemented")

  def batch_shape_tensor(self, name="batch_shape_tensor"):
    """Shape of a single sample from a single event index as a 1-D `Tensor`.

    The batch dimensions are indexes into independent, non-identical
    parameterizations of this distribution.

    When `batch_shape` is fully defined statically, the result is a constant
    `int32` tensor built from it. Otherwise `_batch_shape_tensor()` computes
    it.

    Args:
      name: name to give to the op

    Returns:
      batch_shape: `Tensor`.
    """
    with self._name_scope(name):
      if not self.batch_shape.is_fully_defined():
        return self._batch_shape_tensor()
      return ops.convert_to_tensor(
          self.batch_shape.as_list(), dtype=dtypes.int32, name="batch_shape")

  def _batch_shape(self):
    """Subclass hook for `batch_shape`; the base reports an unknown shape."""
    return tensor_shape.TensorShape(None)

  @property
  def batch_shape(self):
    """Shape of a single sample from a single event index as a `TensorShape`.

    May be partially defined or unknown.

    The batch dimensions are indexes into independent, non-identical
    parameterizations of this distribution.

    Returns:
      batch_shape: `TensorShape`, possibly unknown.
    """
    static_shape = self._batch_shape()
    return tensor_shape.as_shape(static_shape)

  def _event_shape_tensor(self):
    """Subclass hook used by `event_shape_tensor` for non-static shapes.

    Raises:
      NotImplementedError: always, in the base class.
    """
    raise NotImplementedError("event_shape_tensor is not implemented")

  def event_shape_tensor(self, name="event_shape_tensor"):
    """Shape of a single sample from a single batch as a 1-D int32 `Tensor`.

    When `event_shape` is fully defined statically, the result is a constant
    built from it. Otherwise `_event_shape_tensor()` computes it.

    Args:
      name: name to give to the op

    Returns:
      event_shape: `Tensor`.
    """
    with self._name_scope(name):
      if not self.event_shape.is_fully_defined():
        return self._event_shape_tensor()
      return ops.convert_to_tensor(
          self.event_shape.as_list(), dtype=dtypes.int32, name="event_shape")

  def _event_shape(self):
    """Subclass hook for `event_shape`; the base reports an unknown shape."""
    return tensor_shape.TensorShape(None)

  @property
  def event_shape(self):
    """Shape of a single sample from a single batch as a `TensorShape`.

    May be partially defined or unknown.

    Returns:
      event_shape: `TensorShape`, possibly unknown.
    """
    static_shape = self._event_shape()
    return tensor_shape.as_shape(static_shape)

  def is_scalar_event(self, name="is_scalar_event"):
    """Indicates that `event_shape == []`.

    Args:
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      is_scalar_event: `bool` scalar `Tensor`.
    """
    with self._name_scope(name):
      flag = self._is_scalar_helper(self.event_shape, self.event_shape_tensor)
      return ops.convert_to_tensor(flag, name="is_scalar_event")

  def is_scalar_batch(self, name="is_scalar_batch"):
    """Indicates that `batch_shape == []`.

    Args:
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      is_scalar_batch: `bool` scalar `Tensor`.
    """
    with self._name_scope(name):
      flag = self._is_scalar_helper(self.batch_shape, self.batch_shape_tensor)
      return ops.convert_to_tensor(flag, name="is_scalar_batch")

  def _sample_n(self, n, seed=None):
    """Subclass hook that draws `n` samples; called by `_call_sample_n`.

    Raises:
      NotImplementedError: always, in the base class.
    """
    del n, seed  # Unused by the base implementation.
    raise NotImplementedError("sample_n is not implemented")

  def _call_sample_n(self, sample_shape, seed, name, **kwargs):
    """Draws samples via `_sample_n` and reshapes them to `sample_shape`.

    `_expand_sample_shape_to_vector` yields the shape as a vector plus a
    sample count `n` (presumably the number of elements in `sample_shape`).
    `_sample_n` draws `n` samples, and its leading dimension is replaced by
    `sample_shape`.
    """
    with self._name_scope(name, values=[sample_shape]):
      shape_tensor = ops.convert_to_tensor(
          sample_shape, dtype=dtypes.int32, name="sample_shape")
      shape_vector, num_draws = self._expand_sample_shape_to_vector(
          shape_tensor, "sample_shape")
      flat_draws = self._sample_n(num_draws, seed, **kwargs)
      # Keep everything after the leading sample axis (batch + event dims).
      trailing_dims = array_ops.shape(flat_draws)[1:]
      full_shape = array_ops.concat([shape_vector, trailing_dims], 0)
      reshaped = array_ops.reshape(flat_draws, full_shape)
      return self._set_sample_static_shape(reshaped, shape_vector)

  def sample(self, sample_shape=(), seed=None, name="sample"):
    """Generate samples of the specified shape.

    Calling `sample()` with no arguments draws a single sample.

    Args:
      sample_shape: 0D or 1D `int32` `Tensor`. Shape of the generated samples.
      seed: Python integer seed for RNG
      name: name to give to the op.

    Returns:
      samples: a `Tensor` with prepended dimensions `sample_shape`.
    """
    draws = self._call_sample_n(sample_shape, seed, name)
    return draws

  def _log_prob(self, value):
    """Subclass hook for `log_prob`; `_call_log_prob` falls back to `_prob`.

    Raises:
      NotImplementedError: always, in the base class.
    """
    del value  # Unused by the base implementation.
    raise NotImplementedError("log_prob is not implemented")

  def _call_log_prob(self, value, name, **kwargs):
    """Evaluates `_log_prob`, falling back to `log(_prob)` if unimplemented."""
    with self._name_scope(name, values=[value]):
      x = ops.convert_to_tensor(value, name="value")
      try:
        return self._log_prob(x, **kwargs)
      except NotImplementedError:
        # No direct log-density available; derive it from the density.
        return math_ops.log(self._prob(x, **kwargs))

  def log_prob(self, value, name="log_prob"):
    """Log probability density/mass function.

    If the subclass does not implement `_log_prob`, the result is the
    logarithm of `_prob`.

    Args:
      value: `float` or `double` `Tensor`.
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      log_prob: a `Tensor` of shape `sample_shape(x) + self.batch_shape` with
        values of type `self.dtype`.
    """
    result = self._call_log_prob(value, name)
    return result

  def _prob(self, value):
    """Subclass hook for `prob`; `_call_prob` falls back to `_log_prob`.

    Raises:
      NotImplementedError: always, in the base class.
    """
    del value  # Unused by the base implementation.
    raise NotImplementedError("prob is not implemented")

  def _call_prob(self, value, name, **kwargs):
    """Shared implementation of `prob`.

    Uses `self._prob` when a subclass provides it; otherwise falls back to
    `exp(self._log_prob(value))`.

    Args:
      value: point(s) at which to evaluate; converted with
        `ops.convert_to_tensor`.
      name: name scope for the created ops.
      **kwargs: extra keyword arguments forwarded to `_prob` / `_log_prob`.

    Returns:
      `Tensor` of probabilities.

    Raises:
      NotImplementedError: if neither `_prob` nor `_log_prob` is implemented.
        The exception from `_prob` is re-raised, so the message names the
        method the caller actually requested.
    """
    with self._name_scope(name, values=[value]):
      value = ops.convert_to_tensor(value, name="value")
      try:
        return self._prob(value, **kwargs)
      except NotImplementedError as original_exception:
        try:
          return math_ops.exp(self._log_prob(value, **kwargs))
        except NotImplementedError:
          # Report the missing `prob`, not the fallback's `log_prob`.
          raise original_exception

  def prob(self, value, name="prob"):
    """Probability density/mass function evaluated at `value`.

    Args:
      value: `float` or `double` `Tensor` of points at which to evaluate.
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      prob: `Tensor` of shape `sample_shape(x) + self.batch_shape` whose
        values have type `self.dtype`.
    """
    result = self._call_prob(value, name)
    return result

  def _log_cdf(self, value):
    raise NotImplementedError("log_cdf is not implemented")

  def _call_log_cdf(self, value, name, **kwargs):
    """Shared implementation of `log_cdf`.

    Uses `self._log_cdf` when a subclass provides it; otherwise falls back to
    `log(self._cdf(value))`.

    Args:
      value: point(s) at which to evaluate; converted with
        `ops.convert_to_tensor`.
      name: name scope for the created ops.
      **kwargs: extra keyword arguments forwarded to `_log_cdf` / `_cdf`.

    Returns:
      `Tensor` of log cumulative probabilities.

    Raises:
      NotImplementedError: if neither `_log_cdf` nor `_cdf` is implemented.
        The exception from `_log_cdf` is re-raised, so the message names the
        method the caller actually requested.
    """
    with self._name_scope(name, values=[value]):
      value = ops.convert_to_tensor(value, name="value")
      try:
        return self._log_cdf(value, **kwargs)
      except NotImplementedError as original_exception:
        try:
          return math_ops.log(self._cdf(value, **kwargs))
        except NotImplementedError:
          # Report the missing `log_cdf`, not the fallback's `cdf`.
          raise original_exception

  def log_cdf(self, value, name="log_cdf"):
    """Logarithm of the cumulative distribution function.

    For a random variable `X`, the log cumulative distribution function is:

    ```none
    log_cdf(x) := Log[ P[X <= x] ]
    ```

    When `x << -1`, a dedicated numerical approximation of `log_cdf(x)` can
    be more accurate than simply taking the logarithm of the `cdf`.

    Args:
      value: `float` or `double` `Tensor` of points at which to evaluate.
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      logcdf: `Tensor` of shape `sample_shape(x) + self.batch_shape` whose
        values have type `self.dtype`.
    """
    result = self._call_log_cdf(value, name)
    return result

  def _cdf(self, value):
    raise NotImplementedError("cdf is not implemented")

  def _call_cdf(self, value, name, **kwargs):
    """Shared implementation of `cdf`.

    Uses `self._cdf` when a subclass provides it; otherwise falls back to
    `exp(self._log_cdf(value))`.

    Args:
      value: point(s) at which to evaluate; converted with
        `ops.convert_to_tensor`.
      name: name scope for the created ops.
      **kwargs: extra keyword arguments forwarded to `_cdf` / `_log_cdf`.

    Returns:
      `Tensor` of cumulative probabilities.

    Raises:
      NotImplementedError: if neither `_cdf` nor `_log_cdf` is implemented.
        The exception from `_cdf` is re-raised, so the message names the
        method the caller actually requested.
    """
    with self._name_scope(name, values=[value]):
      value = ops.convert_to_tensor(value, name="value")
      try:
        return self._cdf(value, **kwargs)
      except NotImplementedError as original_exception:
        try:
          return math_ops.exp(self._log_cdf(value, **kwargs))
        except NotImplementedError:
          # Report the missing `cdf`, not the fallback's `log_cdf`.
          raise original_exception

  def cdf(self, value, name="cdf"):
    """Cumulative distribution function.

    For a random variable `X`, the cumulative distribution function is:

    ```none
    cdf(x) := P[X <= x]
    ```

    Args:
      value: `float` or `double` `Tensor` of points at which to evaluate.
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      cdf: `Tensor` of shape `sample_shape(x) + self.batch_shape` whose
        values have type `self.dtype`.
    """
    result = self._call_cdf(value, name)
    return result

  def _log_survival_function(self, value):
    raise NotImplementedError("log_survival_function is not implemented")

  def _call_log_survival_function(self, value, name, **kwargs):
    """Shared implementation of `log_survival_function`.

    Uses `self._log_survival_function` when a subclass provides it;
    otherwise falls back to `log1p(-cdf(value))`.

    Args:
      value: point(s) at which to evaluate; converted with
        `ops.convert_to_tensor`.
      name: name scope for the created ops.
      **kwargs: extra keyword arguments forwarded to
        `_log_survival_function`, or to the cdf fallback.

    Returns:
      `Tensor` of log survival probabilities.

    Raises:
      NotImplementedError: if neither `_log_survival_function` nor a cdf
        implementation is available. The exception from
        `_log_survival_function` is re-raised, so the message names the
        method the caller actually requested.
    """
    with self._name_scope(name, values=[value]):
      value = ops.convert_to_tensor(value, name="value")
      try:
        return self._log_survival_function(value, **kwargs)
      except NotImplementedError as original_exception:
        try:
          # Go through `_call_cdf` rather than the public `cdf`, which takes
          # no `**kwargs` (forwarding them there raised TypeError). The scope
          # name "cdf" matches what `self.cdf(value)` would create.
          return math_ops.log1p(-self._call_cdf(value, "cdf", **kwargs))
        except NotImplementedError:
          raise original_exception

  def log_survival_function(self, value, name="log_survival_function"):
    """Logarithm of the survival function.

    For a random variable `X`, the log survival function is:

    ```none
    log_survival_function(x) = Log[ P[X > x] ]
                             = Log[ 1 - P[X <= x] ]
                             = Log[ 1 - cdf(x) ]
    ```

    When `x >> 1`, dedicated numerical approximations of the log survival
    function are typically more accurate than computing `1 - cdf(x)`.

    Args:
      value: `float` or `double` `Tensor` of points at which to evaluate.
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      `Tensor` of shape `sample_shape(x) + self.batch_shape` whose values
        have type `self.dtype`.
    """
    result = self._call_log_survival_function(value, name)
    return result

  def _survival_function(self, value):
    raise NotImplementedError("survival_function is not implemented")

  def _call_survival_function(self, value, name, **kwargs):
    """Shared implementation of `survival_function`.

    Uses `self._survival_function` when a subclass provides it; otherwise
    falls back to `1 - cdf(value)`.

    Args:
      value: point(s) at which to evaluate; converted with
        `ops.convert_to_tensor`.
      name: name scope for the created ops.
      **kwargs: extra keyword arguments forwarded to `_survival_function`,
        or to the cdf fallback.

    Returns:
      `Tensor` of survival probabilities.

    Raises:
      NotImplementedError: if neither `_survival_function` nor a cdf
        implementation is available. The exception from `_survival_function`
        is re-raised, so the message names the method the caller actually
        requested.
    """
    with self._name_scope(name, values=[value]):
      value = ops.convert_to_tensor(value, name="value")
      try:
        return self._survival_function(value, **kwargs)
      except NotImplementedError as original_exception:
        try:
          # Go through `_call_cdf` rather than the public `cdf`, which takes
          # no `**kwargs` (forwarding them there raised TypeError). The scope
          # name "cdf" matches what `self.cdf(value)` would create.
          return 1. - self._call_cdf(value, "cdf", **kwargs)
        except NotImplementedError:
          raise original_exception

  def survival_function(self, value, name="survival_function"):
    """Survival function.

    For a random variable `X`, the survival function is:

    ```none
    survival_function(x) = P[X > x]
                         = 1 - P[X <= x]
                         = 1 - cdf(x).
    ```

    Args:
      value: `float` or `double` `Tensor` of points at which to evaluate.
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      `Tensor` of shape `sample_shape(x) + self.batch_shape` whose values
        have type `self.dtype`.
    """
    result = self._call_survival_function(value, name)
    return result

  def _entropy(self):
    raise NotImplementedError("entropy is not implemented")

  def entropy(self, name="entropy"):
    """Shannon entropy in nats.

    Args:
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      The value computed by `self._entropy()` inside the name scope.
    """
    with self._name_scope(name):
      value = self._entropy()
    return value

  def _mean(self):
    raise NotImplementedError("mean is not implemented")

  def mean(self, name="mean"):
    """Mean.

    Args:
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      The value computed by `self._mean()` inside the name scope.
    """
    with self._name_scope(name):
      value = self._mean()
    return value

  def _quantile(self, value):
    raise NotImplementedError("quantile is not implemented")

  def _call_quantile(self, value, name, **kwargs):
    """Shared implementation of `quantile`; delegates to `self._quantile`.

    Args:
      value: probabilities to invert; converted with `ops.convert_to_tensor`.
      name: name scope for the created ops.
      **kwargs: extra keyword arguments forwarded to `_quantile`.

    Returns:
      Whatever `self._quantile` returns.
    """
    with self._name_scope(name, values=[value]):
      value_tensor = ops.convert_to_tensor(value, name="value")
      return self._quantile(value_tensor, **kwargs)

  def quantile(self, value, name="quantile"):
    """Quantile function, also called "inverse cdf" or "percent point function".

    For a random variable `X` and `p in [0, 1]`, the quantile is:

    ```none
    quantile(p) := x such that P[X <= x] == p
    ```

    Args:
      value: `float` or `double` `Tensor` of probabilities.
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      quantile: `Tensor` of shape `sample_shape(x) + self.batch_shape` whose
        values have type `self.dtype`.
    """
    result = self._call_quantile(value, name)
    return result

  def _variance(self):
    raise NotImplementedError("variance is not implemented")

  def variance(self, name="variance"):
    """Variance.

    Variance is defined as,

    ```none
    Var = E[(X - E[X])**2]
    ```

    where `X` is the random variable associated with this distribution, `E`
    denotes expectation, and `Var.shape = batch_shape + event_shape`.

    Uses `self._variance` when a subclass provides it; otherwise squares
    `self._stddev()`.

    Args:
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      variance: Floating-point `Tensor` with shape identical to
        `batch_shape + event_shape`, i.e., the same shape as `self.mean()`.

    Raises:
      NotImplementedError: if neither `_variance` nor `_stddev` is
        implemented. The exception from `_variance` is re-raised, so the
        message names the method the caller actually requested.
    """
    with self._name_scope(name):
      try:
        return self._variance()
      except NotImplementedError as original_exception:
        try:
          stddev = self._stddev()
        except NotImplementedError:
          # Report the missing `variance`, not the fallback's `stddev`.
          raise original_exception
        return math_ops.square(stddev)

  def _stddev(self):
    raise NotImplementedError("stddev is not implemented")

  def stddev(self, name="stddev"):
    """Standard deviation.

    Standard deviation is defined as,

    ```none
    stddev = E[(X - E[X])**2]**0.5
    ```

    where `X` is the random variable associated with this distribution, `E`
    denotes expectation, and `stddev.shape = batch_shape + event_shape`.

    Uses `self._stddev` when a subclass provides it; otherwise takes the
    square root of `self._variance()`.

    Args:
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      stddev: Floating-point `Tensor` with shape identical to
        `batch_shape + event_shape`, i.e., the same shape as `self.mean()`.

    Raises:
      NotImplementedError: if neither `_stddev` nor `_variance` is
        implemented. The exception from `_stddev` is re-raised, so the
        message names the method the caller actually requested.
    """
    with self._name_scope(name):
      try:
        return self._stddev()
      except NotImplementedError as original_exception:
        try:
          variance = self._variance()
        except NotImplementedError:
          # Report the missing `stddev`, not the fallback's `variance`.
          raise original_exception
        return math_ops.sqrt(variance)

  def _covariance(self):
    raise NotImplementedError("covariance is not implemented")

  def covariance(self, name="covariance"):
    """Covariance.

    Covariance is (possibly) defined only for distributions whose events are
    not scalars.

    For a length-`k`, vector-valued distribution, it is computed as,

    ```none
    Cov[i, j] = Covariance(X_i, X_j) = E[(X_i - E[X_i]) (X_j - E[X_j])]
    ```

    where `Cov` is a (batch of) `k x k` matrix, `0 <= (i, j) < k`, and `E`
    denotes expectation.

    For non-vector multivariate distributions (e.g., matrix-valued ones such
    as Wishart), `Covariance` returns a (batch of) matrices under some
    vectorization of the events, i.e.,

    ```none
    Cov[i, j] = Covariance(Vec(X)_i, Vec(X)_j) = [as above]
    ```

    where `Cov` is a (batch of) `k' x k'` matrices,
    `0 <= (i, j) < k' = reduce_prod(event_shape)`, and `Vec` is some function
    mapping indices of this distribution's event dimensions to indices of a
    length-`k'` vector.

    Args:
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      covariance: Floating-point `Tensor` with shape `[B1, ..., Bn, k', k']`
        where the first `n` dimensions are batch coordinates and
        `k' = reduce_prod(self.event_shape)`.
    """
    with self._name_scope(name):
      value = self._covariance()
    return value

  def _mode(self):
    raise NotImplementedError("mode is not implemented")

  def mode(self, name="mode"):
    """Mode.

    Args:
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      The value computed by `self._mode()` inside the name scope.
    """
    with self._name_scope(name):
      value = self._mode()
    return value

  def _cross_entropy(self, other):
    """Default cross entropy: defers to `kullback_leibler.cross_entropy`."""
    nan_policy = self.allow_nan_stats
    return kullback_leibler.cross_entropy(
        self, other, allow_nan_stats=nan_policy)

  def cross_entropy(self, other, name="cross_entropy"):
    """Computes the (Shannon) cross entropy.

    Denote this distribution (`self`) by `P` and the `other` distribution by
    `Q`. Assuming `P, Q` are absolutely continuous with respect to
    one another and permit densities `p(x) dr(x)` and `q(x) dr(x)`, (Shannon)
    cross entropy is defined as:

    ```none
    H[P, Q] = E_p[-log q(X)] = -int_F p(x) log q(x) dr(x)
    ```

    where `F` denotes the support of the random variable `X ~ P`.

    Args:
      other: `tf.distributions.Distribution` instance.
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      cross_entropy: `self.dtype` `Tensor` with shape `[B1, ..., Bn]`
        representing `n` different calculations of (Shannon) cross entropy.
    """
    with self._name_scope(name):
      return self._cross_entropy(other)

  def _kl_divergence(self, other):
    """Default KL divergence: defers to `kullback_leibler.kl_divergence`."""
    nan_policy = self.allow_nan_stats
    return kullback_leibler.kl_divergence(
        self, other, allow_nan_stats=nan_policy)

  def kl_divergence(self, other, name="kl_divergence"):
    """Computes the Kullback-Leibler divergence.

    Denote this distribution (`self`) by `p` and the `other` distribution by
    `q`. Assuming `p, q` are absolutely continuous with respect to reference
    measure `r`, the KL divergence is defined as:

    ```none
    KL[p, q] = E_p[log(p(X)/q(X))]
             = -int_F p(x) log q(x) dr(x) + int_F p(x) log p(x) dr(x)
             = H[p, q] - H[p]
    ```

    where `F` denotes the support of the random variable `X ~ p`, `H[., .]`
    denotes (Shannon) cross entropy, and `H[.]` denotes (Shannon) entropy.

    Args:
      other: `tf.distributions.Distribution` instance.
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      kl_divergence: `self.dtype` `Tensor` with shape `[B1, ..., Bn]`
        representing `n` different calculations of the Kullback-Leibler
        divergence.
    """
    with self._name_scope(name):
      return self._kl_divergence(other)

  def __str__(self):
    return ("tf.distributions.{type_name}("
            "\"{self_name}\""
            "{maybe_batch_shape}"
            "{maybe_event_shape}"
            ", dtype={dtype})".format(
                type_name=type(self).__name__,
                self_name=self.name,
                maybe_batch_shape=(", batch_shape={}".format(self.batch_shape)
                                   if self.batch_shape.ndims is not None
                                   else ""),
                maybe_event_shape=(", event_shape={}".format(self.event_shape)
                                   if self.event_shape.ndims is not None
                                   else ""),
                dtype=self.dtype.name))

  def __repr__(self):
    return ("<tf.distributions.{type_name} "
            "'{self_name}'"
            " batch_shape={batch_shape}"
            " event_shape={event_shape}"
            " dtype={dtype}>".format(
                type_name=type(self).__name__,
                self_name=self.name,
                batch_shape=self.batch_shape,
                event_shape=self.event_shape,
                dtype=self.dtype.name))

  @contextlib.contextmanager
  def _name_scope(self, name=None, values=None):
    """Standardizes the op scope used by this distribution's methods.

    Opens `self.name` as an outer name scope, then `name` as an inner scope
    that is also given `values` (if any) followed by `self._graph_parents`.
    The inner scope string is yielded to the caller.
    """
    with ops.name_scope(self.name):
      scope_values = [] if values is None else values
      scope_values = scope_values + self._graph_parents
      with ops.name_scope(name, values=scope_values) as scope:
        yield scope

  def _expand_sample_shape_to_vector(self, x, name):
    """Helper to `sample` which ensures input is 1D.

    Args:
      x: integer `Tensor` holding the requested sample shape (converted to
        `int32` by `_call_sample_n`); expected to be rank 0 or rank 1.
      name: op name, used only when a statically-known scalar `x` is rebuilt
        as a length-1 constant.

    Returns:
      A `(x, prod)` tuple. `x` is reshaped to a vector, so a scalar becomes a
      length-1 vector. `prod` is the product of `x`'s elements, which
      `_call_sample_n` passes to `_sample_n` as the number of draws; an empty
      shape gives a product of 1. `prod` is a NumPy scalar when `x` has a
      statically-known value, and a `Tensor` otherwise.

    Raises:
      ValueError: if the static rank of `x` is known and is neither 0 nor 1.
    """
    x_static_val = tensor_util.constant_value(x)
    if x_static_val is None:
      # Value unknown at graph-construction time: compute the product as an op.
      prod = math_ops.reduce_prod(x)
    else:
      # Value known statically: compute the product eagerly in NumPy, keeping
      # x's dtype.
      prod = np.prod(x_static_val, dtype=x.dtype.as_numpy_dtype())

    ndims = x.get_shape().ndims  # Static rank of `x` itself, not sample rank.
    if ndims is None:
      # Maybe expand_dims.
      # Rank unknown statically, so decide at run time. Presumably
      # `util.pick_vector` selects [1] when rank(x) == 0 and shape(x) otherwise.
      # NOTE(review): unlike the static branch below, a rank > 1 `x` is not
      # rejected here. Confirm whether a runtime check is wanted.
      ndims = array_ops.rank(x)
      expanded_shape = util.pick_vector(
          math_ops.equal(ndims, 0),
          np.array([1], dtype=np.int32), array_ops.shape(x))
      x = array_ops.reshape(x, expanded_shape)
    elif ndims == 0:
      # Definitely expand_dims.
      if x_static_val is not None:
        # Rebuild as a constant so the static value stays visible downstream.
        x = ops.convert_to_tensor(
            np.array([x_static_val], dtype=x.dtype.as_numpy_dtype()),
            name=name)
      else:
        x = array_ops.reshape(x, [1])
    elif ndims != 1:
      raise ValueError("Input is neither scalar nor vector.")

    return x, prod

  def _set_sample_static_shape(self, x, sample_shape):
    """Helper to `sample`; sets static shape info.

    `x` is laid out as `sample_shape + batch_shape + event_shape` (see the
    reshape in `_call_sample_n`). This helper refines `x`'s static shape with
    whatever rank and dimension information is known statically, using
    `x.set_shape`.

    Args:
      x: `Tensor` of samples.
      sample_shape: 1-D integer `Tensor` of sample dimensions. Only its
        statically-known value, if any, is used.

    Returns:
      `x`, the same tensor object, with its static shape refined.
    """
    # Set shape hints.
    # If `sample_shape` has no constant value, this yields a TensorShape of
    # unknown rank (sample_ndims is None below).
    sample_shape = tensor_shape.TensorShape(
        tensor_util.constant_value(sample_shape))

    ndims = x.get_shape().ndims
    sample_ndims = sample_shape.ndims
    batch_ndims = self.batch_shape.ndims
    event_ndims = self.event_shape.ndims

    # Infer rank(x).
    # Only possible when all three component ranks are known.
    if (ndims is None and
        sample_ndims is not None and
        batch_ndims is not None and
        event_ndims is not None):
      ndims = sample_ndims + batch_ndims + event_ndims
      x.set_shape([None] * ndims)

    # Infer sample shape.
    # Leading `sample_ndims` dims come from `sample_shape`; the rest are left
    # unknown here.
    if ndims is not None and sample_ndims is not None:
      shape = sample_shape.concatenate([None]*(ndims - sample_ndims))
      x.set_shape(x.get_shape().merge_with(shape))

    # Infer event shape.
    # Trailing `event_ndims` dims come from `self.event_shape`.
    if ndims is not None and event_ndims is not None:
      shape = tensor_shape.TensorShape(
          [None]*(ndims - event_ndims)).concatenate(self.event_shape)
      x.set_shape(x.get_shape().merge_with(shape))

    # Infer batch shape.
    # Batch dims sit in the middle, so both the sample and event ranks must be
    # known. If one is missing but the total rank is known, derive it by
    # subtraction.
    if batch_ndims is not None:
      if ndims is not None:
        if sample_ndims is None and event_ndims is not None:
          sample_ndims = ndims - batch_ndims - event_ndims
        elif event_ndims is None and sample_ndims is not None:
          event_ndims = ndims - batch_ndims - sample_ndims
      if sample_ndims is not None and event_ndims is not None:
        shape = tensor_shape.TensorShape([None]*sample_ndims).concatenate(
            self.batch_shape).concatenate([None]*event_ndims)
        x.set_shape(x.get_shape().merge_with(shape))

    return x

  def _is_scalar_helper(self, static_shape, dynamic_shape_fn):
    """Implementation for `is_scalar_batch` and `is_scalar_event`."""
    if static_shape.ndims is not None:
      return static_shape.ndims == 0
    shape = dynamic_shape_fn()
    if (shape.get_shape().ndims is not None and
        shape.get_shape()[0].value is not None):
      # If the static_shape is correctly written then we should never execute
      # this branch. We keep it just in case there's some unimagined corner
      # case.
      return shape.get_shape().as_list() == [0]
    return math_ops.equal(array_ops.shape(shape)[0], 0)