aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/random_ops.py
blob: 6bd8dd9e3d25c40c6da4ff2dfff7434a32d9d5f2 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
"""Operations for generating random numbers."""

from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import types
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import common_shapes
from tensorflow.python.ops import gen_random_ops
from tensorflow.python.ops import math_ops
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_random_ops import *
# pylint: enable=wildcard-import


def _ShapeTensor(shape):
  """Converts `shape` into a tensor suitable for a random op's shape input.

  An empty Python tuple or list carries no element type to infer from, so it
  is explicitly converted as int32; everything else is converted with the
  dtype left for `convert_to_tensor` to infer.

  Args:
    shape: A tensor-convertible value describing an output shape.

  Returns:
    A `Tensor` named "shape".
  """
  is_empty_sequence = isinstance(shape, (tuple, list)) and not shape
  target_dtype = types.int32 if is_empty_sequence else None
  return ops.convert_to_tensor(shape, dtype=target_dtype, name="shape")

# pylint: disable=protected-access
def random_normal(shape, mean=0.0, stddev=1.0, dtype=types.float32,
                  seed=None, name=None):
  """Draws samples from a normal (Gaussian) distribution.

  Samples are produced by the standard-normal op, then multiplied by `stddev`
  and shifted by `mean`.

  Args:
    shape: A 1-D integer Tensor or Python array giving the output shape.
    mean: A 0-D Tensor or Python value of type `dtype`; the distribution mean.
    stddev: A 0-D Tensor or Python value of type `dtype`; the distribution's
      standard deviation.
    dtype: The element type of the output.
    seed: A Python integer used to derive the op's random seed.
      See [`set_random_seed`](constant_op.md#set_random_seed) for behavior.
    name: Optional name for the operation.

  Returns:
    A tensor of the requested shape holding normally distributed values.
  """
  with ops.op_scope([shape, mean, stddev], name, "random_normal") as name:
    shape_t = _ShapeTensor(shape)
    mean_t = ops.convert_to_tensor(mean, dtype=dtype, name="mean")
    stddev_t = ops.convert_to_tensor(stddev, dtype=dtype, name="stddev")
    seed1, seed2 = random_seed.get_seed(seed)
    samples = gen_random_ops._random_standard_normal(
        shape_t, dtype, seed=seed1, seed2=seed2)
    # Scale first, then shift; the final op carries the scope's name.
    return math_ops.add(samples * stddev_t, mean_t, name=name)


# Declare that the RandomStandardNormal sampling op has no gradient. In
# random_normal, `mean` and `stddev` are applied by separate multiply/add ops
# outside this op.
ops.NoGradient("RandomStandardNormal")


def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=types.float32,
                     seed=None, name=None):
  """Draws samples from a truncated normal distribution.

  Values follow a normal distribution with the given mean and standard
  deviation. Any value more than 2 standard deviations away from the mean is
  discarded and drawn again.

  Args:
    shape: A 1-D integer Tensor or Python array giving the output shape.
    mean: A 0-D Tensor or Python value of type `dtype`; the mean of the
      truncated normal distribution.
    stddev: A 0-D Tensor or Python value of type `dtype`; the standard
      deviation of the truncated normal distribution.
    dtype: The element type of the output.
    seed: A Python integer used to derive the op's random seed.
      See [`set_random_seed`](constant_op.md#set_random_seed) for behavior.
    name: Optional name for the operation.

  Returns:
    A tensor of the requested shape holding truncated-normal values.
  """
  with ops.op_scope([shape, mean, stddev], name, "truncated_normal") as name:
    shape_t = _ShapeTensor(shape)
    mean_t = ops.convert_to_tensor(mean, dtype=dtype, name="mean")
    stddev_t = ops.convert_to_tensor(stddev, dtype=dtype, name="stddev")
    seed1, seed2 = random_seed.get_seed(seed)
    samples = gen_random_ops._truncated_normal(
        shape_t, dtype, seed=seed1, seed2=seed2)
    # Scale first, then shift; the final op carries the scope's name.
    return math_ops.add(samples * stddev_t, mean_t, name=name)


# Declare that the TruncatedNormal sampling op has no gradient. In
# truncated_normal, `mean` and `stddev` are applied by separate multiply/add
# ops outside this op.
ops.NoGradient("TruncatedNormal")


def random_uniform(shape, minval=0.0, maxval=1.0,
                   dtype=types.float32, seed=None,
                   name=None):
  """Outputs random values from a uniform distribution.

  The generated values follow a uniform distribution in the range
  `[minval, maxval)`. The lower bound `minval` is included in the range, while
  the upper bound `maxval` is excluded.

  Both bounds are converted to `dtype` individually before the width of the
  range is computed. As a result, each bound is checked against `dtype` and
  the width `maxval - minval` is computed in `dtype` arithmetic.

  Args:
    shape: A 1-D integer Tensor or Python array. The shape of the output tensor.
    minval: A 0-D Tensor or Python value of type `dtype`. The lower bound on the
      range of random values to generate.
    maxval: A 0-D Tensor or Python value of type `dtype`. The upper bound on
      the range of random values to generate.
    dtype: The type of the output.
    seed: A Python integer. Used to create a random seed for the distribution.
      See [`set_random_seed`](constant_op.md#set_random_seed) for behavior.
    name: A name for the operation (optional).

  Returns:
    A tensor of the specified shape filled with random uniform values.
  """
  with ops.op_scope([shape, minval, maxval], name, "random_uniform") as name:
    shape_tensor = _ShapeTensor(shape)
    min_tensor = ops.convert_to_tensor(minval, dtype=dtype, name="min")
    # Convert `maxval` on its own instead of converting `maxval - minval`.
    # Otherwise the subtraction would run on the caller's raw operands
    # (Python doubles, NumPy scalars, or a mix with a Tensor) before any dtype
    # check, and `maxval` itself would never be validated against `dtype`.
    max_tensor = ops.convert_to_tensor(maxval, dtype=dtype, name="max")
    range_tensor = max_tensor - min_tensor
    seed1, seed2 = random_seed.get_seed(seed)
    rnd = gen_random_ops._random_uniform(shape_tensor, dtype,
                                         seed=seed1,
                                         seed2=seed2)
    # Map the base op's samples onto [minval, maxval): scale by the width,
    # then shift by the lower bound.
    mul = rnd * range_tensor
    value = math_ops.add(mul, min_tensor, name=name)
    return value


def random_shuffle(value, seed=None, name=None):
  """Randomly permutes a tensor along its first dimension.

  The shuffle happens along dimension 0: every `value[j]` is mapped to exactly
  one `output[i]`. For instance, shuffling a 3x2 tensor might produce:

  ```python
  [[1, 2],       [[5, 6],
   [3, 4],  ==>   [1, 2],
   [5, 6]]        [3, 4]]
  ```

  Args:
    value: The Tensor to shuffle.
    seed: A Python integer used to derive the op's random seed.
      See [`set_random_seed`](constant_op.md#set_random_seed) for behavior.
    name: Optional name for the operation.

  Returns:
    A tensor with the same shape and type as `value`, permuted along its
    first dimension.
  """
  first_seed, second_seed = random_seed.get_seed(seed)
  shuffled = gen_random_ops._random_shuffle(
      value, seed=first_seed, seed2=second_seed, name=name)
  return shuffled


# Declare that the RandomUniform sampling op (used by random_uniform) has no
# gradient. In random_uniform, the bounds are applied by separate multiply/add
# ops outside this op.
ops.NoGradient("RandomUniform")


@ops.RegisterShape("TruncatedNormal")
@ops.RegisterShape("RandomStandardNormal")
@ops.RegisterShape("RandomUniform")
def _RandomShape(op):
  """Shape function for the random sampling ops registered above.

  When the shape input (`op.inputs[0]`) has a constant value, that value is
  the exact output shape. Otherwise only the output's rank can be inferred:
  it equals the number of elements in the (at most 1-D) shape input, which
  may itself be unknown.

  Args:
    op: The operation whose output shape is being inferred.

  Returns:
    A single-element list holding the inferred output `TensorShape`.
  """
  shape_input = op.inputs[0]
  constant_shape = tensor_util.ConstantValue(shape_input)
  if constant_shape is None:
    vector_shape = shape_input.get_shape().with_rank_at_most(1)
    return [tensor_shape.unknown_shape(ndims=vector_shape.num_elements())]
  return [tensor_shape.TensorShape(constant_shape.tolist())]


# RandomShuffle returns a tensor of the same shape as its input (see the
# random_shuffle docstring), so it reuses the generic unchanged-shape function.
ops.RegisterShape("RandomShuffle")(common_shapes.unchanged_shape)