diff options
author | 2018-04-11 18:09:42 -0700 | |
---|---|---|
committer | 2018-04-11 18:11:58 -0700 | |
commit | 70d99359fcb9aa9efa955fab06227373c734728b (patch) | |
tree | e5c9d1c6cbed02be0a352f85f64b525c3dddcbe9 /tensorflow/core/kernels | |
parent | 1a721ecd9a9992d48c0deb3008b1fc8df297d300 (diff) |
Add `tf.contrib.stateless.stateless_multinomial()`.
This is a starting point for Dataset-compatible weighted sampling across a list of datasets.
PiperOrigin-RevId: 192540412
Diffstat (limited to 'tensorflow/core/kernels')
-rw-r--r-- | tensorflow/core/kernels/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/core/kernels/multinomial_op.cc | 131 | ||||
-rw-r--r-- | tensorflow/core/kernels/stateless_random_ops.cc | 68 | ||||
-rw-r--r-- | tensorflow/core/kernels/stateless_random_ops.h | 34 |
4 files changed, 181 insertions, 53 deletions
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 1018e8d25c..e2af540dac 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -4323,6 +4323,7 @@ tf_kernel_library( deps = [ ":random_op", ":random_ops", + ":stateless_random_ops", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", diff --git a/tensorflow/core/kernels/multinomial_op.cc b/tensorflow/core/kernels/multinomial_op.cc index d086abb247..7a64788448 100644 --- a/tensorflow/core/kernels/multinomial_op.cc +++ b/tensorflow/core/kernels/multinomial_op.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/kernels/stateless_random_ops.h" #include "tensorflow/core/lib/random/random_distributions.h" #include "tensorflow/core/lib/random/simple_philox.h" #include "tensorflow/core/util/guarded_philox_random.h" @@ -127,18 +128,16 @@ struct MultinomialFunctor<CPUDevice, T, OutputType> { } // namespace functor +namespace { + // Samples from a multinomial distribution. template <typename Device, typename T, typename OutputType> class MultinomialOp : public OpKernel { public: - explicit MultinomialOp(OpKernelConstruction* context) : OpKernel(context) { - OP_REQUIRES_OK(context, generator_.Init(context)); - } - - void Compute(OpKernelContext* ctx) override { - const Tensor& logits_t = ctx->input(0); - const Tensor& num_samples_t = ctx->input(1); + explicit MultinomialOp(OpKernelConstruction* context) : OpKernel(context) {} + void DoCompute(OpKernelContext* ctx, const Tensor& logits_t, + const Tensor& num_samples_t, GuardedPhiloxRandom* generator) { OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(logits_t.shape()), errors::InvalidArgument("logits should be a matrix, got shape ", logits_t.shape().DebugString())); @@ -194,7 +193,7 @@ class MultinomialOp : public OpKernel { // CPU generates doubles = 2 samples per number. if (std::is_same<Device, CPUDevice>::value) num_samples_ceil_4 *= 2; auto rng = - generator_.ReserveRandomOutputs(batch_size * num_samples_ceil_4, 256); + generator->ReserveRandomOutputs(batch_size * num_samples_ceil_4, 256); functor::MultinomialFunctor<Device, T, OutputType>()( ctx, ctx->eigen_device<Device>(), logits_t.matrix<T>(), noises.flat<float>(), scores.flat<float>(), scratch.flat<float>(), @@ -202,24 +201,38 @@ class MultinomialOp : public OpKernel { samples_t->matrix<OutputType>()); } } +}; + +template <typename Device, typename T, typename OutputType> +class StatefulMultinomialOp : public MultinomialOp<Device, T, OutputType> { + public: + explicit StatefulMultinomialOp(OpKernelConstruction* ctx) + : MultinomialOp<Device, T, OutputType>(ctx) { + OP_REQUIRES_OK(ctx, generator_.Init(ctx)); + } + + void Compute(OpKernelContext* ctx) override { + const Tensor& logits_t = ctx->input(0); + const Tensor& num_samples_t = ctx->input(1); + this->DoCompute(ctx, logits_t, num_samples_t, &generator_); + } private: GuardedPhiloxRandom generator_; - - TF_DISALLOW_COPY_AND_ASSIGN(MultinomialOp); }; -#define REGISTER(TYPE) \ - REGISTER_KERNEL_BUILDER(Name("Multinomial") \ - .Device(DEVICE_CPU) \ - .TypeConstraint<TYPE>("T") \ - .TypeConstraint("output_dtype", DT_INT32), \ - MultinomialOp<CPUDevice, TYPE, int32>); \ - REGISTER_KERNEL_BUILDER(Name("Multinomial") \ - .Device(DEVICE_CPU) \ - .TypeConstraint<TYPE>("T") \ - .TypeConstraint("output_dtype", DT_INT64), \ - MultinomialOp<CPUDevice, TYPE, int64>); +// TODO(b/77906027): Add a TPU implementation. +#define REGISTER(TYPE) \ + REGISTER_KERNEL_BUILDER(Name("Multinomial") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<TYPE>("T") \ + .TypeConstraint("output_dtype", DT_INT32), \ + StatefulMultinomialOp<CPUDevice, TYPE, int32>); \ + REGISTER_KERNEL_BUILDER(Name("Multinomial") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<TYPE>("T") \ + .TypeConstraint("output_dtype", DT_INT64), \ + StatefulMultinomialOp<CPUDevice, TYPE, int64>); TF_CALL_half(REGISTER); TF_CALL_float(REGISTER); @@ -233,13 +246,83 @@ TF_CALL_double(REGISTER); .HostMemory("num_samples") \ .TypeConstraint<TYPE>("T") \ .TypeConstraint("output_dtype", DT_INT32), \ - MultinomialOp<GPUDevice, TYPE, int32>) \ + StatefulMultinomialOp<GPUDevice, TYPE, int32>) \ REGISTER_KERNEL_BUILDER(Name("Multinomial") \ .Device(DEVICE_GPU) \ .HostMemory("num_samples") \ .TypeConstraint<TYPE>("T") \ .TypeConstraint("output_dtype", DT_INT64), \ - MultinomialOp<GPUDevice, TYPE, int64>) + StatefulMultinomialOp<GPUDevice, TYPE, int64>) + +TF_CALL_half(REGISTER); +TF_CALL_float(REGISTER); +TF_CALL_double(REGISTER); +#undef REGISTER + +#endif // GOOGLE_CUDA + +template <typename Device, typename T, typename OutputType> +class StatelessMultinomialOp : public MultinomialOp<Device, T, OutputType> { + public: + explicit StatelessMultinomialOp(OpKernelConstruction* ctx) + : MultinomialOp<Device, T, OutputType>(ctx) {} + + void Compute(OpKernelContext* ctx) override { + const Tensor& logits_t = ctx->input(0); + const Tensor& num_samples_t = ctx->input(1); + + const Tensor& seed_t = ctx->input(2); + OP_REQUIRES(ctx, seed_t.dims() == 1 && seed_t.dim_size(0) == 2, + errors::InvalidArgument("seed must have shape [2], not ", + seed_t.shape().DebugString())); + + random::PhiloxRandom::Key key; + random::PhiloxRandom::ResultType counter; + OP_REQUIRES_OK(ctx, GenerateKey(seed_t, &key, &counter)); + + GuardedPhiloxRandom generator; + generator.Init(counter, key); + + this->DoCompute(ctx, logits_t, num_samples_t, &generator); + } + + private: + GuardedPhiloxRandom generator_; +}; + +#define REGISTER(TYPE) \ + REGISTER_KERNEL_BUILDER(Name("StatelessMultinomial") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<TYPE>("T") \ + .TypeConstraint("output_dtype", DT_INT32), \ + StatelessMultinomialOp<CPUDevice, TYPE, int32>); \ + REGISTER_KERNEL_BUILDER(Name("StatelessMultinomial") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<TYPE>("T") \ + .TypeConstraint("output_dtype", DT_INT64), \ + StatelessMultinomialOp<CPUDevice, TYPE, int64>); + +TF_CALL_half(REGISTER); +TF_CALL_float(REGISTER); +TF_CALL_double(REGISTER); +#undef REGISTER + +#if GOOGLE_CUDA +#define REGISTER(TYPE) \ + REGISTER_KERNEL_BUILDER(Name("StatelessMultinomial") \ + .Device(DEVICE_GPU) \ + .HostMemory("num_samples") \ + .HostMemory("seed") \ + .TypeConstraint<TYPE>("T") \ + .TypeConstraint("output_dtype", DT_INT32), \ + StatelessMultinomialOp<GPUDevice, TYPE, int32>) \ + REGISTER_KERNEL_BUILDER(Name("StatelessMultinomial") \ + .Device(DEVICE_GPU) \ + .HostMemory("num_samples") \ + .HostMemory("seed") \ + .TypeConstraint<TYPE>("T") \ + .TypeConstraint("output_dtype", DT_INT64), \ + StatelessMultinomialOp<GPUDevice, TYPE, int64>) TF_CALL_half(REGISTER); TF_CALL_float(REGISTER); @@ -248,4 +331,6 @@ TF_CALL_double(REGISTER); #endif // GOOGLE_CUDA +} // end namespace + } // end namespace tensorflow diff --git a/tensorflow/core/kernels/stateless_random_ops.cc b/tensorflow/core/kernels/stateless_random_ops.cc index 88fcf542fb..eab176c7fb 100644 --- a/tensorflow/core/kernels/stateless_random_ops.cc +++ b/tensorflow/core/kernels/stateless_random_ops.cc @@ -27,6 +27,41 @@ namespace tensorflow { using CPUDevice = Eigen::ThreadPoolDevice; using GPUDevice = Eigen::GpuDevice; +Status GenerateKey(Tensor seed, random::PhiloxRandom::Key* out_key, + random::PhiloxRandom::ResultType* out_counter) { + // Grab the two seeds + uint64 seed0; + uint64 seed1; + if (seed.dtype() == DT_INT32) { + const auto seed_vals = seed.flat<int32>(); + seed0 = internal::SubtleMustCopy(seed_vals(0)); + seed1 = internal::SubtleMustCopy(seed_vals(1)); + } else if (seed.dtype() == DT_INT64) { + const auto seed_vals = seed.flat<int64>(); + seed0 = internal::SubtleMustCopy(seed_vals(0)); + seed1 = internal::SubtleMustCopy(seed_vals(1)); + } else { + return errors::InvalidArgument("Invalid seed type: ", + DataTypeString(seed.dtype())); + } + + // Scramble the seeds so that the user doesn't need to worry about which + // part of the seed needs to be strong. + (*out_key)[0] = 0x3ec8f720; + (*out_key)[1] = 0x02461e29; + (*out_counter)[0] = static_cast<uint32>(seed0); + (*out_counter)[1] = static_cast<uint32>(seed0 >> 32); + (*out_counter)[2] = static_cast<uint32>(seed1); + (*out_counter)[3] = static_cast<uint32>(seed1 >> 32); + const auto mix = random::PhiloxRandom(*out_counter, *out_key)(); + (*out_key)[0] = mix[0]; + (*out_key)[1] = mix[1]; + (*out_counter)[0] = (*out_counter)[1] = 0; + (*out_counter)[2] = mix[2]; + (*out_counter)[3] = mix[3]; + return Status::OK(); +} + namespace { class StatelessRandomOpBase : public OpKernel { @@ -49,36 +84,9 @@ class StatelessRandomOpBase : public OpKernel { OP_REQUIRES_OK(context, context->allocate_output(0, shape, &output)); if (shape.num_elements() == 0) return; - // Grab the two seeds - uint64 seed0; - uint64 seed1; - if (context->input_dtype(1) == DT_INT32) { - const auto seed = seed_t.flat<int32>(); - seed0 = internal::SubtleMustCopy(seed(0)); - seed1 = internal::SubtleMustCopy(seed(1)); - } else { - CHECK_EQ(DT_INT64, context->input_dtype(1)); - const auto seed = seed_t.flat<int64>(); - seed0 = internal::SubtleMustCopy(seed(0)); - seed1 = internal::SubtleMustCopy(seed(1)); - } - - // Scramble the seeds so that the user doesn't need to worry about which - // part of the seed needs to be strong. random::PhiloxRandom::Key key; random::PhiloxRandom::ResultType counter; - key[0] = 0x3ec8f720; - key[1] = 0x02461e29; - counter[0] = static_cast<uint32>(seed0); - counter[1] = static_cast<uint32>(seed0 >> 32); - counter[2] = static_cast<uint32>(seed1); - counter[3] = static_cast<uint32>(seed1 >> 32); - const auto mix = random::PhiloxRandom(counter, key)(); - key[0] = mix[0]; - key[1] = mix[1]; - counter[0] = counter[1] = 0; - counter[2] = mix[2]; - counter[3] = mix[3]; + OP_REQUIRES_OK(context, GenerateKey(seed_t, &key, &counter)); // Fill in the random numbers Fill(context, random::PhiloxRandom(counter, key), output); @@ -105,8 +113,6 @@ class StatelessRandomOp : public StatelessRandomOpBase { } }; -} // namespace - #define REGISTER(TYPE) \ REGISTER_KERNEL_BUILDER( \ Name("StatelessRandomUniform") \ @@ -176,4 +182,6 @@ TF_CALL_double(REGISTER); #endif // GOOGLE_CUDA +} // namespace + } // namespace tensorflow diff --git a/tensorflow/core/kernels/stateless_random_ops.h b/tensorflow/core/kernels/stateless_random_ops.h new file mode 100644 index 0000000000..bcd29c4873 --- /dev/null +++ b/tensorflow/core/kernels/stateless_random_ops.h @@ -0,0 +1,34 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_STATELESS_RANDOM_OPS_H_ +#define TENSORFLOW_CORE_KERNELS_STATELESS_RANDOM_OPS_H_ + +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/random/random_distributions.h" + +namespace tensorflow { + +// Generates a key and counter that can be used to seed a PhiloxRandom, +// generator, based on the seed value in `seed_t`. +// +// REQUIRES: `seed_t` must be a length-2 vector of type DT_INT{32,64}. +// `out_key` and `out_counter` must be non-null. +Status GenerateKey(Tensor seed_t, random::PhiloxRandom::Key* out_key, + random::PhiloxRandom::ResultType* out_counter); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_STATELESS_RANDOM_OPS_H_ |