aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/ops
diff options
context:
space:
mode:
authorGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-05 18:02:42 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-05 18:02:42 -0700
commit5dcca3baca11de0687747e9b5ad8854b77fd097d (patch)
tree4eb40a90582b78285963bff749953dafd2feed03 /tensorflow/core/ops
parent213d76a6ed77a696883502c53a3a4f81d2ee4042 (diff)
parent1e104d80826fed95f9fad6f07f68e35cae3527b2 (diff)
Merge pull request #22386 from girving:stateless
PiperOrigin-RevId: 215995215
Diffstat (limited to 'tensorflow/core/ops')
-rw-r--r--tensorflow/core/ops/stateless_random_ops.cc53
1 files changed, 32 insertions, 21 deletions
diff --git a/tensorflow/core/ops/stateless_random_ops.cc b/tensorflow/core/ops/stateless_random_ops.cc
index 742709fb18..f919a21d60 100644
--- a/tensorflow/core/ops/stateless_random_ops.cc
+++ b/tensorflow/core/ops/stateless_random_ops.cc
@@ -19,42 +19,55 @@ limitations under the License.
namespace tensorflow {
using shape_inference::DimensionHandle;
+using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;
-static Status StatelessShape(shape_inference::InferenceContext* context) {
+static Status StatelessShape(InferenceContext* c) {
// Check seed shape
ShapeHandle seed;
- TF_RETURN_IF_ERROR(context->WithRank(context->input(1), 1, &seed));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &seed));
DimensionHandle unused;
- TF_RETURN_IF_ERROR(context->WithValue(context->Dim(seed, 0), 2, &unused));
+ TF_RETURN_IF_ERROR(c->WithValue(c->Dim(seed, 0), 2, &unused));
// Set output shape
ShapeHandle out;
- TF_RETURN_IF_ERROR(context->MakeShapeFromShapeTensor(0, &out));
- context->set_output(0, out);
+ TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
+ c->set_output(0, out);
return Status::OK();
}
-#define REGISTER_STATELESS_OP(name) \
- REGISTER_OP(name) \
- .Input("shape: T") \
- .Input("seed: Tseed") \
- .Output("output: dtype") \
- .Attr("dtype: {half,float,double} = DT_FLOAT") \
- .Attr("T: {int32, int64} = DT_INT32") \
- .Attr("Tseed: {int32, int64} = DT_INT64") \
+#define REGISTER_STATELESS_OP(name) \
+ REGISTER_OP(name) \
+ .Input("shape: T") \
+ .Input("seed: Tseed") \
+ .Output("output: dtype") \
+ .Attr("dtype: {half,bfloat16,float,double} = DT_FLOAT") \
+ .Attr("T: {int32, int64} = DT_INT32") \
+ .Attr("Tseed: {int32, int64} = DT_INT64") \
.SetShapeFn(StatelessShape)
-// This op is exposed through contrib/stateless only. The interface may change.
REGISTER_STATELESS_OP("StatelessRandomUniform");
-
-// This op is exposed through contrib/stateless only. The interface may change.
REGISTER_STATELESS_OP("StatelessRandomNormal");
-
-// This op is exposed through contrib/stateless only. The interface may change.
REGISTER_STATELESS_OP("StatelessTruncatedNormal");
-// This op is exposed through contrib/stateless only. The interface may change.
+#undef REGISTER_STATELESS_OP
+
+REGISTER_OP("StatelessRandomUniformInt")
+ .Input("shape: T")
+ .Input("seed: Tseed")
+ .Input("minval: dtype")
+ .Input("maxval: dtype")
+ .Output("output: dtype")
+ .Attr("dtype: {int32, int64}")
+ .Attr("T: {int32, int64}")
+ .Attr("Tseed: {int32, int64} = DT_INT64")
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle unused;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
+ return StatelessShape(c);
+ });
+
REGISTER_OP("StatelessMultinomial")
.Input("logits: T")
.Input("num_samples: int32")
@@ -80,6 +93,4 @@ REGISTER_OP("StatelessMultinomial")
return Status::OK();
});
-#undef REGISTER_STATELESS_OP
-
} // namespace tensorflow