aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/ops/random_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/ops/random_ops.cc')
-rw-r--r--tensorflow/core/ops/random_ops.cc21
1 files changed, 5 insertions, 16 deletions
diff --git a/tensorflow/core/ops/random_ops.cc b/tensorflow/core/ops/random_ops.cc
index 7b2da9d8e6..392ac32010 100644
--- a/tensorflow/core/ops/random_ops.cc
+++ b/tensorflow/core/ops/random_ops.cc
@@ -23,17 +23,6 @@ using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;
-namespace {
-
-Status RandomShape(InferenceContext* c) {
- ShapeHandle out;
- TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
- c->set_output(0, out);
- return Status::OK();
-}
-
-} // namepsace
-
REGISTER_OP("RandomUniform")
.Input("shape: T")
.SetIsStateful()
@@ -42,7 +31,7 @@ REGISTER_OP("RandomUniform")
.Attr("seed2: int = 0")
.Attr("dtype: {half,float,double}")
.Attr("T: {int32, int64}")
- .SetShapeFn(RandomShape)
+ .SetShapeFn(shape_inference::RandomShape)
.Doc(R"doc(
Outputs random values from a uniform distribution.
@@ -69,7 +58,7 @@ REGISTER_OP("RandomUniformInt")
.Attr("seed2: int = 0")
.Attr("Tout: {int32, int64}")
.Attr("T: {int32, int64}")
- .SetShapeFn(RandomShape)
+ .SetShapeFn(shape_inference::RandomShape)
.Doc(R"doc(
Outputs random integers from a uniform distribution.
@@ -100,7 +89,7 @@ REGISTER_OP("RandomStandardNormal")
.Attr("seed2: int = 0")
.Attr("dtype: {half,float,double}")
.Attr("T: {int32, int64}")
- .SetShapeFn(RandomShape)
+ .SetShapeFn(shape_inference::RandomShape)
.Doc(R"doc(
Outputs random values from a normal distribution.
@@ -128,7 +117,7 @@ REGISTER_OP("ParameterizedTruncatedNormal")
.Attr("seed2: int = 0")
.Attr("dtype: {half,float,double}")
.Attr("T: {int32, int64}")
- .SetShapeFn(RandomShape)
+ .SetShapeFn(shape_inference::RandomShape)
.Doc(R"doc(
Outputs random values from a normal distribution. The parameters may each be a
scalar which applies to the entire output, or a vector of length shape[0] which
@@ -158,7 +147,7 @@ REGISTER_OP("TruncatedNormal")
.Attr("seed2: int = 0")
.Attr("dtype: {half,float,double}")
.Attr("T: {int32, int64}")
- .SetShapeFn(RandomShape)
+ .SetShapeFn(shape_inference::RandomShape)
.Doc(R"doc(
Outputs random values from a truncated normal distribution.