diff options
author | 2018-09-19 09:33:19 -0700 | |
---|---|---|
committer | 2018-10-05 15:02:11 -0700 | |
commit | 1e104d80826fed95f9fad6f07f68e35cae3527b2 (patch) | |
tree | bb934f3e78f55f53fadbdba8408e550f1f61e171 /tensorflow/core/ops | |
parent | 80c9eec9b2475630f83a596f77a906c8075f8e6c (diff) |
Expand stateless random generators to match their stateful cousins
stateless_random_uniform now take minval+maxval and handles ints,
and stateless_normal/stateless_truncated_normal take mean+stddev.
Additionally, all of the stateless functions now have proper doc
strings.
This is step one of moving stateless random numbers out of contrib.
Diffstat (limited to 'tensorflow/core/ops')
-rw-r--r-- | tensorflow/core/ops/stateless_random_ops.cc | 53 |
1 files changed, 32 insertions, 21 deletions
diff --git a/tensorflow/core/ops/stateless_random_ops.cc b/tensorflow/core/ops/stateless_random_ops.cc index 742709fb18..f919a21d60 100644 --- a/tensorflow/core/ops/stateless_random_ops.cc +++ b/tensorflow/core/ops/stateless_random_ops.cc @@ -19,42 +19,55 @@ limitations under the License. namespace tensorflow { using shape_inference::DimensionHandle; +using shape_inference::InferenceContext; using shape_inference::ShapeHandle; -static Status StatelessShape(shape_inference::InferenceContext* context) { +static Status StatelessShape(InferenceContext* c) { // Check seed shape ShapeHandle seed; - TF_RETURN_IF_ERROR(context->WithRank(context->input(1), 1, &seed)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &seed)); DimensionHandle unused; - TF_RETURN_IF_ERROR(context->WithValue(context->Dim(seed, 0), 2, &unused)); + TF_RETURN_IF_ERROR(c->WithValue(c->Dim(seed, 0), 2, &unused)); // Set output shape ShapeHandle out; - TF_RETURN_IF_ERROR(context->MakeShapeFromShapeTensor(0, &out)); - context->set_output(0, out); + TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out)); + c->set_output(0, out); return Status::OK(); } -#define REGISTER_STATELESS_OP(name) \ - REGISTER_OP(name) \ - .Input("shape: T") \ - .Input("seed: Tseed") \ - .Output("output: dtype") \ - .Attr("dtype: {half,float,double} = DT_FLOAT") \ - .Attr("T: {int32, int64} = DT_INT32") \ - .Attr("Tseed: {int32, int64} = DT_INT64") \ +#define REGISTER_STATELESS_OP(name) \ + REGISTER_OP(name) \ + .Input("shape: T") \ + .Input("seed: Tseed") \ + .Output("output: dtype") \ + .Attr("dtype: {half,bfloat16,float,double} = DT_FLOAT") \ + .Attr("T: {int32, int64} = DT_INT32") \ + .Attr("Tseed: {int32, int64} = DT_INT64") \ .SetShapeFn(StatelessShape) -// This op is exposed through contrib/stateless only. The interface may change. REGISTER_STATELESS_OP("StatelessRandomUniform"); - -// This op is exposed through contrib/stateless only. The interface may change. REGISTER_STATELESS_OP("StatelessRandomNormal"); - -// This op is exposed through contrib/stateless only. The interface may change. REGISTER_STATELESS_OP("StatelessTruncatedNormal"); -// This op is exposed through contrib/stateless only. The interface may change. +#undef REGISTER_STATELESS_OP + +REGISTER_OP("StatelessRandomUniformInt") + .Input("shape: T") + .Input("seed: Tseed") + .Input("minval: dtype") + .Input("maxval: dtype") + .Output("output: dtype") + .Attr("dtype: {int32, int64}") + .Attr("T: {int32, int64}") + .Attr("Tseed: {int32, int64} = DT_INT64") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); + return StatelessShape(c); + }); + REGISTER_OP("StatelessMultinomial") .Input("logits: T") .Input("num_samples: int32") @@ -80,6 +93,4 @@ REGISTER_OP("StatelessMultinomial") return Status::OK(); }); -#undef REGISTER_STATELESS_OP - } // namespace tensorflow |