aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels
diff options
context:
space:
mode:
authorGravatar Derek Murray <mrry@google.com>2018-04-11 18:09:42 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-11 18:11:58 -0700
commit70d99359fcb9aa9efa955fab06227373c734728b (patch)
treee5c9d1c6cbed02be0a352f85f64b525c3dddcbe9 /tensorflow/core/kernels
parent1a721ecd9a9992d48c0deb3008b1fc8df297d300 (diff)
Add `tf.contrib.stateless.stateless_multinomial()`.
This is a starting point for Dataset-compatible weighted sampling across a list of datasets. PiperOrigin-RevId: 192540412
Diffstat (limited to 'tensorflow/core/kernels')
-rw-r--r--tensorflow/core/kernels/BUILD1
-rw-r--r--tensorflow/core/kernels/multinomial_op.cc131
-rw-r--r--tensorflow/core/kernels/stateless_random_ops.cc68
-rw-r--r--tensorflow/core/kernels/stateless_random_ops.h34
4 files changed, 181 insertions, 53 deletions
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 1018e8d25c..e2af540dac 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -4323,6 +4323,7 @@ tf_kernel_library(
deps = [
":random_op",
":random_ops",
+ ":stateless_random_ops",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
diff --git a/tensorflow/core/kernels/multinomial_op.cc b/tensorflow/core/kernels/multinomial_op.cc
index d086abb247..7a64788448 100644
--- a/tensorflow/core/kernels/multinomial_op.cc
+++ b/tensorflow/core/kernels/multinomial_op.cc
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/kernels/stateless_random_ops.h"
#include "tensorflow/core/lib/random/random_distributions.h"
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/util/guarded_philox_random.h"
@@ -127,18 +128,16 @@ struct MultinomialFunctor<CPUDevice, T, OutputType> {
} // namespace functor
+namespace {
+
// Samples from a multinomial distribution.
template <typename Device, typename T, typename OutputType>
class MultinomialOp : public OpKernel {
public:
- explicit MultinomialOp(OpKernelConstruction* context) : OpKernel(context) {
- OP_REQUIRES_OK(context, generator_.Init(context));
- }
-
- void Compute(OpKernelContext* ctx) override {
- const Tensor& logits_t = ctx->input(0);
- const Tensor& num_samples_t = ctx->input(1);
+ explicit MultinomialOp(OpKernelConstruction* context) : OpKernel(context) {}
+ void DoCompute(OpKernelContext* ctx, const Tensor& logits_t,
+ const Tensor& num_samples_t, GuardedPhiloxRandom* generator) {
OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(logits_t.shape()),
errors::InvalidArgument("logits should be a matrix, got shape ",
logits_t.shape().DebugString()));
@@ -194,7 +193,7 @@ class MultinomialOp : public OpKernel {
// CPU generates doubles = 2 samples per number.
if (std::is_same<Device, CPUDevice>::value) num_samples_ceil_4 *= 2;
auto rng =
- generator_.ReserveRandomOutputs(batch_size * num_samples_ceil_4, 256);
+ generator->ReserveRandomOutputs(batch_size * num_samples_ceil_4, 256);
functor::MultinomialFunctor<Device, T, OutputType>()(
ctx, ctx->eigen_device<Device>(), logits_t.matrix<T>(),
noises.flat<float>(), scores.flat<float>(), scratch.flat<float>(),
@@ -202,24 +201,38 @@ class MultinomialOp : public OpKernel {
samples_t->matrix<OutputType>());
}
}
+};
+
+template <typename Device, typename T, typename OutputType>
+class StatefulMultinomialOp : public MultinomialOp<Device, T, OutputType> {
+ public:
+ explicit StatefulMultinomialOp(OpKernelConstruction* ctx)
+ : MultinomialOp<Device, T, OutputType>(ctx) {
+ OP_REQUIRES_OK(ctx, generator_.Init(ctx));
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor& logits_t = ctx->input(0);
+ const Tensor& num_samples_t = ctx->input(1);
+ this->DoCompute(ctx, logits_t, num_samples_t, &generator_);
+ }
private:
GuardedPhiloxRandom generator_;
-
- TF_DISALLOW_COPY_AND_ASSIGN(MultinomialOp);
};
-#define REGISTER(TYPE) \
- REGISTER_KERNEL_BUILDER(Name("Multinomial") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<TYPE>("T") \
- .TypeConstraint("output_dtype", DT_INT32), \
- MultinomialOp<CPUDevice, TYPE, int32>); \
- REGISTER_KERNEL_BUILDER(Name("Multinomial") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<TYPE>("T") \
- .TypeConstraint("output_dtype", DT_INT64), \
- MultinomialOp<CPUDevice, TYPE, int64>);
+// TODO(b/77906027): Add a TPU implementation.
+#define REGISTER(TYPE) \
+ REGISTER_KERNEL_BUILDER(Name("Multinomial") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<TYPE>("T") \
+ .TypeConstraint("output_dtype", DT_INT32), \
+ StatefulMultinomialOp<CPUDevice, TYPE, int32>); \
+ REGISTER_KERNEL_BUILDER(Name("Multinomial") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<TYPE>("T") \
+ .TypeConstraint("output_dtype", DT_INT64), \
+ StatefulMultinomialOp<CPUDevice, TYPE, int64>);
TF_CALL_half(REGISTER);
TF_CALL_float(REGISTER);
@@ -233,13 +246,83 @@ TF_CALL_double(REGISTER);
.HostMemory("num_samples") \
.TypeConstraint<TYPE>("T") \
.TypeConstraint("output_dtype", DT_INT32), \
- MultinomialOp<GPUDevice, TYPE, int32>) \
+ StatefulMultinomialOp<GPUDevice, TYPE, int32>) \
REGISTER_KERNEL_BUILDER(Name("Multinomial") \
.Device(DEVICE_GPU) \
.HostMemory("num_samples") \
.TypeConstraint<TYPE>("T") \
.TypeConstraint("output_dtype", DT_INT64), \
- MultinomialOp<GPUDevice, TYPE, int64>)
+ StatefulMultinomialOp<GPUDevice, TYPE, int64>)
+
+TF_CALL_half(REGISTER);
+TF_CALL_float(REGISTER);
+TF_CALL_double(REGISTER);
+#undef REGISTER
+
+#endif // GOOGLE_CUDA
+
+template <typename Device, typename T, typename OutputType>
+class StatelessMultinomialOp : public MultinomialOp<Device, T, OutputType> {
+ public:
+ explicit StatelessMultinomialOp(OpKernelConstruction* ctx)
+ : MultinomialOp<Device, T, OutputType>(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor& logits_t = ctx->input(0);
+ const Tensor& num_samples_t = ctx->input(1);
+
+ const Tensor& seed_t = ctx->input(2);
+ OP_REQUIRES(ctx, seed_t.dims() == 1 && seed_t.dim_size(0) == 2,
+ errors::InvalidArgument("seed must have shape [2], not ",
+ seed_t.shape().DebugString()));
+
+ random::PhiloxRandom::Key key;
+ random::PhiloxRandom::ResultType counter;
+ OP_REQUIRES_OK(ctx, GenerateKey(seed_t, &key, &counter));
+
+ GuardedPhiloxRandom generator;
+ generator.Init(counter, key);
+
+ this->DoCompute(ctx, logits_t, num_samples_t, &generator);
+ }
+
+ private:
+ GuardedPhiloxRandom generator_;
+};
+
+#define REGISTER(TYPE) \
+ REGISTER_KERNEL_BUILDER(Name("StatelessMultinomial") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<TYPE>("T") \
+ .TypeConstraint("output_dtype", DT_INT32), \
+ StatelessMultinomialOp<CPUDevice, TYPE, int32>); \
+ REGISTER_KERNEL_BUILDER(Name("StatelessMultinomial") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<TYPE>("T") \
+ .TypeConstraint("output_dtype", DT_INT64), \
+ StatelessMultinomialOp<CPUDevice, TYPE, int64>);
+
+TF_CALL_half(REGISTER);
+TF_CALL_float(REGISTER);
+TF_CALL_double(REGISTER);
+#undef REGISTER
+
+#if GOOGLE_CUDA
+#define REGISTER(TYPE) \
+ REGISTER_KERNEL_BUILDER(Name("StatelessMultinomial") \
+ .Device(DEVICE_GPU) \
+ .HostMemory("num_samples") \
+ .HostMemory("seed") \
+ .TypeConstraint<TYPE>("T") \
+ .TypeConstraint("output_dtype", DT_INT32), \
+ StatelessMultinomialOp<GPUDevice, TYPE, int32>) \
+ REGISTER_KERNEL_BUILDER(Name("StatelessMultinomial") \
+ .Device(DEVICE_GPU) \
+ .HostMemory("num_samples") \
+ .HostMemory("seed") \
+ .TypeConstraint<TYPE>("T") \
+ .TypeConstraint("output_dtype", DT_INT64), \
+ StatelessMultinomialOp<GPUDevice, TYPE, int64>)
TF_CALL_half(REGISTER);
TF_CALL_float(REGISTER);
@@ -248,4 +331,6 @@ TF_CALL_double(REGISTER);
#endif // GOOGLE_CUDA
+} // end namespace
+
} // end namespace tensorflow
diff --git a/tensorflow/core/kernels/stateless_random_ops.cc b/tensorflow/core/kernels/stateless_random_ops.cc
index 88fcf542fb..eab176c7fb 100644
--- a/tensorflow/core/kernels/stateless_random_ops.cc
+++ b/tensorflow/core/kernels/stateless_random_ops.cc
@@ -27,6 +27,41 @@ namespace tensorflow {
using CPUDevice = Eigen::ThreadPoolDevice;
using GPUDevice = Eigen::GpuDevice;
+Status GenerateKey(Tensor seed, random::PhiloxRandom::Key* out_key,
+ random::PhiloxRandom::ResultType* out_counter) {
+ // Grab the two seeds
+ uint64 seed0;
+ uint64 seed1;
+ if (seed.dtype() == DT_INT32) {
+ const auto seed_vals = seed.flat<int32>();
+ seed0 = internal::SubtleMustCopy(seed_vals(0));
+ seed1 = internal::SubtleMustCopy(seed_vals(1));
+ } else if (seed.dtype() == DT_INT64) {
+ const auto seed_vals = seed.flat<int64>();
+ seed0 = internal::SubtleMustCopy(seed_vals(0));
+ seed1 = internal::SubtleMustCopy(seed_vals(1));
+ } else {
+ return errors::InvalidArgument("Invalid seed type: ",
+ DataTypeString(seed.dtype()));
+ }
+
+ // Scramble the seeds so that the user doesn't need to worry about which
+ // part of the seed needs to be strong.
+ (*out_key)[0] = 0x3ec8f720;
+ (*out_key)[1] = 0x02461e29;
+ (*out_counter)[0] = static_cast<uint32>(seed0);
+ (*out_counter)[1] = static_cast<uint32>(seed0 >> 32);
+ (*out_counter)[2] = static_cast<uint32>(seed1);
+ (*out_counter)[3] = static_cast<uint32>(seed1 >> 32);
+ const auto mix = random::PhiloxRandom(*out_counter, *out_key)();
+ (*out_key)[0] = mix[0];
+ (*out_key)[1] = mix[1];
+ (*out_counter)[0] = (*out_counter)[1] = 0;
+ (*out_counter)[2] = mix[2];
+ (*out_counter)[3] = mix[3];
+ return Status::OK();
+}
+
namespace {
class StatelessRandomOpBase : public OpKernel {
@@ -49,36 +84,9 @@ class StatelessRandomOpBase : public OpKernel {
OP_REQUIRES_OK(context, context->allocate_output(0, shape, &output));
if (shape.num_elements() == 0) return;
- // Grab the two seeds
- uint64 seed0;
- uint64 seed1;
- if (context->input_dtype(1) == DT_INT32) {
- const auto seed = seed_t.flat<int32>();
- seed0 = internal::SubtleMustCopy(seed(0));
- seed1 = internal::SubtleMustCopy(seed(1));
- } else {
- CHECK_EQ(DT_INT64, context->input_dtype(1));
- const auto seed = seed_t.flat<int64>();
- seed0 = internal::SubtleMustCopy(seed(0));
- seed1 = internal::SubtleMustCopy(seed(1));
- }
-
- // Scramble the seeds so that the user doesn't need to worry about which
- // part of the seed needs to be strong.
random::PhiloxRandom::Key key;
random::PhiloxRandom::ResultType counter;
- key[0] = 0x3ec8f720;
- key[1] = 0x02461e29;
- counter[0] = static_cast<uint32>(seed0);
- counter[1] = static_cast<uint32>(seed0 >> 32);
- counter[2] = static_cast<uint32>(seed1);
- counter[3] = static_cast<uint32>(seed1 >> 32);
- const auto mix = random::PhiloxRandom(counter, key)();
- key[0] = mix[0];
- key[1] = mix[1];
- counter[0] = counter[1] = 0;
- counter[2] = mix[2];
- counter[3] = mix[3];
+ OP_REQUIRES_OK(context, GenerateKey(seed_t, &key, &counter));
// Fill in the random numbers
Fill(context, random::PhiloxRandom(counter, key), output);
@@ -105,8 +113,6 @@ class StatelessRandomOp : public StatelessRandomOpBase {
}
};
-} // namespace
-
#define REGISTER(TYPE) \
REGISTER_KERNEL_BUILDER( \
Name("StatelessRandomUniform") \
@@ -176,4 +182,6 @@ TF_CALL_double(REGISTER);
#endif // GOOGLE_CUDA
+} // namespace
+
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/stateless_random_ops.h b/tensorflow/core/kernels/stateless_random_ops.h
new file mode 100644
index 0000000000..bcd29c4873
--- /dev/null
+++ b/tensorflow/core/kernels/stateless_random_ops.h
@@ -0,0 +1,34 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_KERNELS_STATELESS_RANDOM_OPS_H_
+#define TENSORFLOW_CORE_KERNELS_STATELESS_RANDOM_OPS_H_
+
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/random/random_distributions.h"
+
+namespace tensorflow {
+
+// Generates a key and counter that can be used to seed a PhiloxRandom,
+// generator, based on the seed value in `seed_t`.
+//
+// REQUIRES: `seed_t` must be a length-2 vector of type DT_INT{32,64}.
+// `out_key` and `out_counter` must be non-null.
+Status GenerateKey(Tensor seed_t, random::PhiloxRandom::Key* out_key,
+ random::PhiloxRandom::ResultType* out_counter);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_KERNELS_STATELESS_RANDOM_OPS_H_