aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/ops/random_ops.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-07-12 06:12:24 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-07-12 07:21:37 -0700
commit98bdd76ecb881f46b77fa57c85a023ec5799ac52 (patch)
tree258384d77a710c76d87f9567d48dd1c39791ddab /tensorflow/core/ops/random_ops.cc
parenta3f539fcad30f86f3354241925c07b77934702b8 (diff)
Add a ParameterizedTruncatedNormalOp to eventually replace the current TruncatedNormalOp. It takes a matrix of batched parameters (mean, stdev, minval, maxval).
Once the GPU functor is added, we can eventually use this op to implement tf.truncated_normal with an optional minval and maxval, and support for batches, and remove the existing TruncatedNormalOp. Change: 127196322
Diffstat (limited to 'tensorflow/core/ops/random_ops.cc')
-rw-r--r--tensorflow/core/ops/random_ops.cc33
1 files changed, 33 insertions, 0 deletions
diff --git a/tensorflow/core/ops/random_ops.cc b/tensorflow/core/ops/random_ops.cc
index 7151af2fb4..16dad57643 100644
--- a/tensorflow/core/ops/random_ops.cc
+++ b/tensorflow/core/ops/random_ops.cc
@@ -96,6 +96,39 @@ seed2: A second seed to avoid seed collision.
output: A tensor of the specified shape filled with random normal values.
)doc");
+REGISTER_OP("ParameterizedTruncatedNormal")
+ .Input("shape: T")
+ .Input("means: dtype")
+ .Input("stdevs: dtype")
+ .Input("minvals: dtype")
+ .Input("maxvals: dtype")
+ .SetIsStateful()
+ .Output("output: dtype")
+ .Attr("seed: int = 0")
+ .Attr("seed2: int = 0")
+ .Attr("dtype: {half,float,double}")
+ .Attr("T: {int32, int64}")
+ .Doc(R"doc(
+Outputs random values from a normal distribution. The parameters may each be a
+scalar which applies to the entire output, or a vector of length shape[0] which
+stores the parameters for each batch.
+
+shape: The shape of the output tensor. Batches are indexed by the 0th dimension.
+means: The mean parameter of each batch.
+stdevs: The standard deviation parameter of each batch. Must be greater than 0.
+minvals: The minimum cutoff. May be -infinity.
+maxvals: The maximum cutoff. May be +infinity, and must be more than the minval
+ for each batch.
+dtype: The type of the output.
+seed: If either `seed` or `seed2` are set to be non-zero, the random number
+ generator is seeded by the given seed. Otherwise, it is seeded by a
+ random seed.
+seed2: A second seed to avoid seed collision.
+
+output: A matrix of shape num_batches x samples_per_batch, filled with random
+ truncated normal values using the parameters for each row.
+)doc");
+
REGISTER_OP("TruncatedNormal")
.Input("shape: T")
.SetIsStateful()