diff options
author | 2016-07-12 06:12:24 -0800 | |
---|---|---|
committer | 2016-07-12 07:21:37 -0700 | |
commit | 98bdd76ecb881f46b77fa57c85a023ec5799ac52 (patch) | |
tree | 258384d77a710c76d87f9567d48dd1c39791ddab /tensorflow/core/ops/random_ops.cc | |
parent | a3f539fcad30f86f3354241925c07b77934702b8 (diff) |
Add a ParameterizedTruncatedNormalOp to eventually replace the current TruncatedNormalOp. It takes a matrix of batched parameters (mean, stdev, minval, maxval).
Once the GPU functor is added, we can eventually use this op to implement tf.truncated_normal with an optional minval and maxval, and support for batches,
and remove the existing TruncatedNormalOp.
Change: 127196322
Diffstat (limited to 'tensorflow/core/ops/random_ops.cc')
-rw-r--r-- | tensorflow/core/ops/random_ops.cc | 33 |
1 files changed, 33 insertions, 0 deletions
diff --git a/tensorflow/core/ops/random_ops.cc b/tensorflow/core/ops/random_ops.cc index 7151af2fb4..16dad57643 100644 --- a/tensorflow/core/ops/random_ops.cc +++ b/tensorflow/core/ops/random_ops.cc @@ -96,6 +96,39 @@ seed2: A second seed to avoid seed collision. output: A tensor of the specified shape filled with random normal values. )doc"); +REGISTER_OP("ParameterizedTruncatedNormal") + .Input("shape: T") + .Input("means: dtype") + .Input("stdevs: dtype") + .Input("minvals: dtype") + .Input("maxvals: dtype") + .SetIsStateful() + .Output("output: dtype") + .Attr("seed: int = 0") + .Attr("seed2: int = 0") + .Attr("dtype: {half,float,double}") + .Attr("T: {int32, int64}") + .Doc(R"doc( +Outputs random values from a normal distribution. The parameters may each be a +scalar which applies to the entire output, or a vector of length shape[0] which +stores the parameters for each batch. + +shape: The shape of the output tensor. Batches are indexed by the 0th dimension. +means: The mean parameter of each batch. +stdevs: The standard deviation parameter of each batch. Must be greater than 0. +minvals: The minimum cutoff. May be -infinity. +maxvals: The maximum cutoff. May be +infinity, and must be more than the minval + for each batch. +dtype: The type of the output. +seed: If either `seed` or `seed2` are set to be non-zero, the random number + generator is seeded by the given seed. Otherwise, it is seeded by a + random seed. +seed2: A second seed to avoid seed collision. + +output: A matrix of shape num_batches x samples_per_batch, filled with random + truncated normal values using the parameters for each row. +)doc"); + REGISTER_OP("TruncatedNormal") .Input("shape: T") .SetIsStateful() |