author     Geoffrey Irving <irving@naml.us>    2018-09-19 09:33:19 -0700
committer  Geoffrey Irving <irving@naml.us>    2018-10-05 15:02:11 -0700
commit  1e104d80826fed95f9fad6f07f68e35cae3527b2 (patch)
tree    bb934f3e78f55f53fadbdba8408e550f1f61e171 /tensorflow/core
parent  80c9eec9b2475630f83a596f77a906c8075f8e6c (diff)
Expand stateless random generators to match their stateful cousins
stateless_random_uniform now takes minval+maxval and handles ints, and stateless_normal/stateless_truncated_normal take mean+stddev. Additionally, all of the stateless functions now have proper doc strings. This is step one of moving stateless random numbers out of contrib.
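
As a rough illustration of the expanded Python-level interface described above, a minimal sketch (the tf.contrib.stateless module path and keyword names here are assumptions inferred from the commit message, not taken from this diff):

# Sketch only: wrapper names and signatures are assumptions based on the
# commit message; stateless ops take an explicit [2]-element seed.
import tensorflow as tf

seed = [1, 2]

# Uniform sampling now supports integer dtypes with minval/maxval bounds.
ints = tf.contrib.stateless.stateless_random_uniform(
    shape=[3], seed=seed, minval=0, maxval=10, dtype=tf.int32)

# Normal and truncated normal now accept mean and stddev.
floats = tf.contrib.stateless.stateless_random_normal(
    shape=[3], seed=seed, mean=1.0, stddev=2.0)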
Diffstat (limited to 'tensorflow/core')
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_StatelessRandomUniformInt.pbtxt |  46
-rw-r--r--  tensorflow/core/kernels/random_op.cc                                      |  34
-rw-r--r--  tensorflow/core/kernels/stateless_random_ops.cc                           | 155
-rw-r--r--  tensorflow/core/ops/stateless_random_ops.cc                               |  53
4 files changed, 192 insertions, 96 deletions
diff --git a/tensorflow/core/api_def/base_api/api_def_StatelessRandomUniformInt.pbtxt b/tensorflow/core/api_def/base_api/api_def_StatelessRandomUniformInt.pbtxt
new file mode 100644
index 0000000000..b6a6dbdf54
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_StatelessRandomUniformInt.pbtxt
@@ -0,0 +1,46 @@
+op {
+ graph_op_name: "StatelessRandomUniformInt"
+ visibility: HIDDEN
+ in_arg {
+ name: "shape"
+ description: <<END
+The shape of the output tensor.
+END
+ }
+ in_arg {
+ name: "seed"
+ description: <<END
+2 seeds (shape [2]).
+END
+ }
+ in_arg {
+ name: "minval"
+ description: <<END
+Minimum value (inclusive, scalar).
+END
+ }
+ in_arg {
+ name: "maxval"
+ description: <<END
+Maximum value (exclusive, scalar).
+END
+ }
+ out_arg {
+ name: "output"
+ description: <<END
+Random values with specified shape.
+END
+ }
+ attr {
+ name: "dtype"
+ description: <<END
+The type of the output.
+END
+ }
+ summary: "Outputs deterministic pseudorandom random integers from a uniform distribution."
+ description: <<END
+The generated values follow a uniform distribution in the range `[minval, maxval)`.
+
+The outputs are a deterministic function of `shape`, `seed`, `minval`, and `maxval`.
+END
+}
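
Since the op is registered HIDDEN, it is meant to be reached through a Python wrapper rather than called directly. A minimal sketch of the contract stated above, determinism in (shape, seed, minval, maxval) and values in the half-open range [minval, maxval); the generated gen_stateless_random_ops binding used here is an assumption:

# Sketch: the binding name and module path below are assumptions.
import tensorflow as tf
from tensorflow.python.ops import gen_stateless_random_ops

x = gen_stateless_random_ops.stateless_random_uniform_int(
    shape=[100], seed=[7, 11], minval=0, maxval=5)
y = gen_stateless_random_ops.stateless_random_uniform_int(
    shape=[100], seed=[7, 11], minval=0, maxval=5)

with tf.Session() as sess:
    a, b = sess.run([x, y])

assert (a == b).all()                  # same inputs => identical outputs
assert a.min() >= 0 and a.max() < 5    # values lie in [minval, maxval)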
diff --git a/tensorflow/core/kernels/random_op.cc b/tensorflow/core/kernels/random_op.cc
index 04a53697c0..3810d817ca 100644
--- a/tensorflow/core/kernels/random_op.cc
+++ b/tensorflow/core/kernels/random_op.cc
@@ -489,13 +489,15 @@ class RandomGammaOp : public OpKernel {
Name("RandomGamma").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"), \
RandomGammaOp<TYPE>)
-#define REGISTER_INT(IntType) \
- REGISTER_KERNEL_BUILDER(Name("RandomUniformInt") \
- .Device(DEVICE_CPU) \
- .HostMemory("shape") \
- .HostMemory("minval") \
- .HostMemory("maxval") \
- .TypeConstraint<IntType>("Tout"), \
+#define REGISTER_INT(IntType) \
+ template struct functor::FillPhiloxRandom< \
+ CPUDevice, random::UniformDistribution<random::PhiloxRandom, IntType>>; \
+ REGISTER_KERNEL_BUILDER(Name("RandomUniformInt") \
+ .Device(DEVICE_CPU) \
+ .HostMemory("shape") \
+ .HostMemory("minval") \
+ .HostMemory("maxval") \
+ .TypeConstraint<IntType>("Tout"), \
RandomUniformIntOp<CPUDevice, IntType>);
TF_CALL_half(REGISTER);
@@ -538,14 +540,16 @@ TF_CALL_int64(REGISTER_INT);
random::TruncatedNormalDistribution< \
random::SingleSampleAdapter<random::PhiloxRandom>, TYPE>>);
-#define REGISTER_INT(IntType) \
- REGISTER_KERNEL_BUILDER(Name("RandomUniformInt") \
- .Device(DEVICE_GPU) \
- .HostMemory("shape") \
- .HostMemory("minval") \
- .HostMemory("maxval") \
- .TypeConstraint<int32>("T") \
- .TypeConstraint<IntType>("Tout"), \
+#define REGISTER_INT(IntType) \
+ template struct functor::FillPhiloxRandom< \
+ GPUDevice, random::UniformDistribution<random::PhiloxRandom, IntType>>; \
+ REGISTER_KERNEL_BUILDER(Name("RandomUniformInt") \
+ .Device(DEVICE_GPU) \
+ .HostMemory("shape") \
+ .HostMemory("minval") \
+ .HostMemory("maxval") \
+ .TypeConstraint<int32>("T") \
+ .TypeConstraint<IntType>("Tout"), \
RandomUniformIntOp<GPUDevice, IntType>);
TF_CALL_half(REGISTER);
diff --git a/tensorflow/core/kernels/stateless_random_ops.cc b/tensorflow/core/kernels/stateless_random_ops.cc
index eab176c7fb..925f5291a6 100644
--- a/tensorflow/core/kernels/stateless_random_ops.cc
+++ b/tensorflow/core/kernels/stateless_random_ops.cc
@@ -113,74 +113,109 @@ class StatelessRandomOp : public StatelessRandomOpBase {
}
};
-#define REGISTER(TYPE) \
- REGISTER_KERNEL_BUILDER( \
- Name("StatelessRandomUniform") \
- .Device(DEVICE_CPU) \
- .HostMemory("shape") \
- .TypeConstraint<TYPE>("dtype"), \
- StatelessRandomOp<CPUDevice, random::UniformDistribution< \
- random::PhiloxRandom, TYPE> >); \
- REGISTER_KERNEL_BUILDER( \
- Name("StatelessRandomNormal") \
- .Device(DEVICE_CPU) \
- .HostMemory("shape") \
- .TypeConstraint<TYPE>("dtype"), \
- StatelessRandomOp<CPUDevice, random::NormalDistribution< \
- random::PhiloxRandom, TYPE> >); \
- REGISTER_KERNEL_BUILDER( \
- Name("StatelessTruncatedNormal") \
- .Device(DEVICE_CPU) \
- .HostMemory("shape") \
- .TypeConstraint<TYPE>("dtype"), \
- StatelessRandomOp< \
- CPUDevice, \
- random::TruncatedNormalDistribution< \
- random::SingleSampleAdapter<random::PhiloxRandom>, TYPE> >);
+template <typename Device, typename IntType>
+class StatelessRandomUniformIntOp : public StatelessRandomOpBase {
+ public:
+ using StatelessRandomOpBase::StatelessRandomOpBase;
-TF_CALL_half(REGISTER);
-TF_CALL_float(REGISTER);
-TF_CALL_double(REGISTER);
+ void Fill(OpKernelContext* context, random::PhiloxRandom random,
+ Tensor* output) override {
+ const Tensor& minval = context->input(2);
+ const Tensor& maxval = context->input(3);
+ OP_REQUIRES(context, TensorShapeUtils::IsScalar(minval.shape()),
+ errors::InvalidArgument("minval must be 0-D, got shape ",
+ minval.shape().DebugString()));
+ OP_REQUIRES(context, TensorShapeUtils::IsScalar(maxval.shape()),
+ errors::InvalidArgument("maxval must be 0-D, got shape ",
+ maxval.shape().DebugString()));
+
+ // Verify that minval < maxval. Note that we'll never reach this point for
+ // empty output. Zero impossible things are fine.
+ const auto lo = minval.scalar<IntType>()();
+ const auto hi = maxval.scalar<IntType>()();
+ OP_REQUIRES(
+ context, lo < hi,
+ errors::InvalidArgument("Need minval < maxval, got ", lo, " >= ", hi));
+
+ // Build distribution
+ typedef random::UniformDistribution<random::PhiloxRandom, IntType>
+ Distribution;
+ Distribution dist(lo, hi);
+
+ auto flat = output->flat<IntType>();
+ // Reuse the compute kernels from the stateful random ops
+ functor::FillPhiloxRandom<Device, Distribution>()(
+ context, context->eigen_device<Device>(), random, flat.data(),
+ flat.size(), dist);
+ }
+};
-#undef REGISTER
+#define REGISTER(DEVICE, TYPE) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("StatelessRandomUniform") \
+ .Device(DEVICE_##DEVICE) \
+ .HostMemory("shape") \
+ .HostMemory("seed") \
+ .TypeConstraint<TYPE>("dtype"), \
+ StatelessRandomOp<DEVICE##Device, random::UniformDistribution< \
+ random::PhiloxRandom, TYPE> >); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("StatelessRandomNormal") \
+ .Device(DEVICE_##DEVICE) \
+ .HostMemory("shape") \
+ .HostMemory("seed") \
+ .TypeConstraint<TYPE>("dtype"), \
+ StatelessRandomOp<DEVICE##Device, random::NormalDistribution< \
+ random::PhiloxRandom, TYPE> >); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("StatelessTruncatedNormal") \
+ .Device(DEVICE_##DEVICE) \
+ .HostMemory("shape") \
+ .HostMemory("seed") \
+ .TypeConstraint<TYPE>("dtype"), \
+ StatelessRandomOp< \
+ DEVICE##Device, \
+ random::TruncatedNormalDistribution< \
+ random::SingleSampleAdapter<random::PhiloxRandom>, TYPE> >);
+
+#define REGISTER_INT(DEVICE, TYPE) \
+ REGISTER_KERNEL_BUILDER(Name("StatelessRandomUniformInt") \
+ .Device(DEVICE_##DEVICE) \
+ .HostMemory("shape") \
+ .HostMemory("seed") \
+ .HostMemory("minval") \
+ .HostMemory("maxval") \
+ .TypeConstraint<TYPE>("dtype"), \
+ StatelessRandomUniformIntOp<DEVICE##Device, TYPE>);
+
+#define REGISTER_CPU(TYPE) REGISTER(CPU, TYPE)
+#define REGISTER_GPU(TYPE) REGISTER(GPU, TYPE)
+#define REGISTER_INT_CPU(TYPE) REGISTER_INT(CPU, TYPE)
+#define REGISTER_INT_GPU(TYPE) REGISTER_INT(GPU, TYPE)
+
+TF_CALL_half(REGISTER_CPU);
+TF_CALL_bfloat16(REGISTER_CPU);
+TF_CALL_float(REGISTER_CPU);
+TF_CALL_double(REGISTER_CPU);
+TF_CALL_int32(REGISTER_INT_CPU);
+TF_CALL_int64(REGISTER_INT_CPU);
#if GOOGLE_CUDA
-#define REGISTER(TYPE) \
- REGISTER_KERNEL_BUILDER( \
- Name("StatelessRandomUniform") \
- .Device(DEVICE_GPU) \
- .HostMemory("shape") \
- .HostMemory("seed") \
- .TypeConstraint<TYPE>("dtype"), \
- StatelessRandomOp<GPUDevice, random::UniformDistribution< \
- random::PhiloxRandom, TYPE> >); \
- REGISTER_KERNEL_BUILDER( \
- Name("StatelessRandomNormal") \
- .Device(DEVICE_GPU) \
- .HostMemory("shape") \
- .HostMemory("seed") \
- .TypeConstraint<TYPE>("dtype"), \
- StatelessRandomOp<GPUDevice, random::NormalDistribution< \
- random::PhiloxRandom, TYPE> >); \
- REGISTER_KERNEL_BUILDER( \
- Name("StatelessTruncatedNormal") \
- .Device(DEVICE_GPU) \
- .HostMemory("shape") \
- .HostMemory("seed") \
- .TypeConstraint<TYPE>("dtype"), \
- StatelessRandomOp< \
- GPUDevice, \
- random::TruncatedNormalDistribution< \
- random::SingleSampleAdapter<random::PhiloxRandom>, TYPE> >);
+TF_CALL_half(REGISTER_GPU);
+TF_CALL_float(REGISTER_GPU);
+TF_CALL_double(REGISTER_GPU);
+TF_CALL_int32(REGISTER_INT_GPU);
+TF_CALL_int64(REGISTER_INT_GPU);
-TF_CALL_half(REGISTER);
-TF_CALL_float(REGISTER);
-TF_CALL_double(REGISTER);
+#endif // GOOGLE_CUDA
#undef REGISTER
-
-#endif // GOOGLE_CUDA
+#undef REGISTER_INT
+#undef REGISTER_CPU
+#undef REGISTER_GPU
+#undef REGISTER_INT_CPU
+#undef REGISTER_INT_GPU
} // namespace
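
The new StatelessRandomUniformIntOp above validates its scalar bounds before sampling (skipping the check for empty outputs) and reuses the stateful ops' FillPhiloxRandom path to fill the result. A small sketch of the resulting error behavior, reusing the assumed generated binding from the previous example:

# Sketch: minval must be strictly less than maxval for any non-empty output;
# otherwise the kernel fails with InvalidArgument ("Need minval < maxval, ...").
import tensorflow as tf
from tensorflow.python.ops import gen_stateless_random_ops

bad = gen_stateless_random_ops.stateless_random_uniform_int(
    shape=[4], seed=[0, 0], minval=3, maxval=3)

with tf.Session() as sess:
    try:
        sess.run(bad)
    except tf.errors.InvalidArgumentError as e:
        print(e.message)  # e.g. "Need minval < maxval, got 3 >= 3"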
diff --git a/tensorflow/core/ops/stateless_random_ops.cc b/tensorflow/core/ops/stateless_random_ops.cc
index 742709fb18..f919a21d60 100644
--- a/tensorflow/core/ops/stateless_random_ops.cc
+++ b/tensorflow/core/ops/stateless_random_ops.cc
@@ -19,42 +19,55 @@ limitations under the License.
namespace tensorflow {
using shape_inference::DimensionHandle;
+using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;
-static Status StatelessShape(shape_inference::InferenceContext* context) {
+static Status StatelessShape(InferenceContext* c) {
// Check seed shape
ShapeHandle seed;
- TF_RETURN_IF_ERROR(context->WithRank(context->input(1), 1, &seed));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &seed));
DimensionHandle unused;
- TF_RETURN_IF_ERROR(context->WithValue(context->Dim(seed, 0), 2, &unused));
+ TF_RETURN_IF_ERROR(c->WithValue(c->Dim(seed, 0), 2, &unused));
// Set output shape
ShapeHandle out;
- TF_RETURN_IF_ERROR(context->MakeShapeFromShapeTensor(0, &out));
- context->set_output(0, out);
+ TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
+ c->set_output(0, out);
return Status::OK();
}
-#define REGISTER_STATELESS_OP(name) \
- REGISTER_OP(name) \
- .Input("shape: T") \
- .Input("seed: Tseed") \
- .Output("output: dtype") \
- .Attr("dtype: {half,float,double} = DT_FLOAT") \
- .Attr("T: {int32, int64} = DT_INT32") \
- .Attr("Tseed: {int32, int64} = DT_INT64") \
+#define REGISTER_STATELESS_OP(name) \
+ REGISTER_OP(name) \
+ .Input("shape: T") \
+ .Input("seed: Tseed") \
+ .Output("output: dtype") \
+ .Attr("dtype: {half,bfloat16,float,double} = DT_FLOAT") \
+ .Attr("T: {int32, int64} = DT_INT32") \
+ .Attr("Tseed: {int32, int64} = DT_INT64") \
.SetShapeFn(StatelessShape)
-// This op is exposed through contrib/stateless only. The interface may change.
REGISTER_STATELESS_OP("StatelessRandomUniform");
-
-// This op is exposed through contrib/stateless only. The interface may change.
REGISTER_STATELESS_OP("StatelessRandomNormal");
-
-// This op is exposed through contrib/stateless only. The interface may change.
REGISTER_STATELESS_OP("StatelessTruncatedNormal");
-// This op is exposed through contrib/stateless only. The interface may change.
+#undef REGISTER_STATELESS_OP
+
+REGISTER_OP("StatelessRandomUniformInt")
+ .Input("shape: T")
+ .Input("seed: Tseed")
+ .Input("minval: dtype")
+ .Input("maxval: dtype")
+ .Output("output: dtype")
+ .Attr("dtype: {int32, int64}")
+ .Attr("T: {int32, int64}")
+ .Attr("Tseed: {int32, int64} = DT_INT64")
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle unused;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
+ return StatelessShape(c);
+ });
+
REGISTER_OP("StatelessMultinomial")
.Input("logits: T")
.Input("num_samples: int32")
@@ -80,6 +93,4 @@ REGISTER_OP("StatelessMultinomial")
return Status::OK();
});
-#undef REGISTER_STATELESS_OP
-
} // namespace tensorflow