From 70d99359fcb9aa9efa955fab06227373c734728b Mon Sep 17 00:00:00 2001 From: Derek Murray Date: Wed, 11 Apr 2018 18:09:42 -0700 Subject: Add `tf.contrib.stateless.stateless_multinomial()`. This is a starting point for Dataset-compatible weighted sampling across a list of datasets. PiperOrigin-RevId: 192540412 --- .../base_api/api_def_StatelessMultinomial.pbtxt | 30 +++++ tensorflow/core/kernels/BUILD | 1 + tensorflow/core/kernels/multinomial_op.cc | 131 +++++++++++++++++---- tensorflow/core/kernels/stateless_random_ops.cc | 68 ++++++----- tensorflow/core/kernels/stateless_random_ops.h | 34 ++++++ tensorflow/core/ops/stateless_random_ops.cc | 28 ++++- tensorflow/core/util/guarded_philox_random.cc | 8 ++ tensorflow/core/util/guarded_philox_random.h | 2 + 8 files changed, 248 insertions(+), 54 deletions(-) create mode 100644 tensorflow/core/api_def/base_api/api_def_StatelessMultinomial.pbtxt create mode 100644 tensorflow/core/kernels/stateless_random_ops.h (limited to 'tensorflow/core') diff --git a/tensorflow/core/api_def/base_api/api_def_StatelessMultinomial.pbtxt b/tensorflow/core/api_def/base_api/api_def_StatelessMultinomial.pbtxt new file mode 100644 index 0000000000..c4e6c1fddd --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_StatelessMultinomial.pbtxt @@ -0,0 +1,30 @@ +op { + graph_op_name: "StatelessMultinomial" + in_arg { + name: "logits" + description: < { } // namespace functor +namespace { + // Samples from a multinomial distribution. 
template class MultinomialOp : public OpKernel { public: - explicit MultinomialOp(OpKernelConstruction* context) : OpKernel(context) { - OP_REQUIRES_OK(context, generator_.Init(context)); - } - - void Compute(OpKernelContext* ctx) override { - const Tensor& logits_t = ctx->input(0); - const Tensor& num_samples_t = ctx->input(1); + explicit MultinomialOp(OpKernelConstruction* context) : OpKernel(context) {} + void DoCompute(OpKernelContext* ctx, const Tensor& logits_t, + const Tensor& num_samples_t, GuardedPhiloxRandom* generator) { OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(logits_t.shape()), errors::InvalidArgument("logits should be a matrix, got shape ", logits_t.shape().DebugString())); @@ -194,7 +193,7 @@ class MultinomialOp : public OpKernel { // CPU generates doubles = 2 samples per number. if (std::is_same::value) num_samples_ceil_4 *= 2; auto rng = - generator_.ReserveRandomOutputs(batch_size * num_samples_ceil_4, 256); + generator->ReserveRandomOutputs(batch_size * num_samples_ceil_4, 256); functor::MultinomialFunctor()( ctx, ctx->eigen_device(), logits_t.matrix(), noises.flat(), scores.flat(), scratch.flat(), @@ -202,24 +201,38 @@ class MultinomialOp : public OpKernel { samples_t->matrix()); } } +}; + +template +class StatefulMultinomialOp : public MultinomialOp { + public: + explicit StatefulMultinomialOp(OpKernelConstruction* ctx) + : MultinomialOp(ctx) { + OP_REQUIRES_OK(ctx, generator_.Init(ctx)); + } + + void Compute(OpKernelContext* ctx) override { + const Tensor& logits_t = ctx->input(0); + const Tensor& num_samples_t = ctx->input(1); + this->DoCompute(ctx, logits_t, num_samples_t, &generator_); + } private: GuardedPhiloxRandom generator_; - - TF_DISALLOW_COPY_AND_ASSIGN(MultinomialOp); }; -#define REGISTER(TYPE) \ - REGISTER_KERNEL_BUILDER(Name("Multinomial") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T") \ - .TypeConstraint("output_dtype", DT_INT32), \ - MultinomialOp); \ - REGISTER_KERNEL_BUILDER(Name("Multinomial") \ - .Device(DEVICE_CPU) \ 
- .TypeConstraint("T") \ - .TypeConstraint("output_dtype", DT_INT64), \ - MultinomialOp); +// TODO(b/77906027): Add a TPU implementation. +#define REGISTER(TYPE) \ + REGISTER_KERNEL_BUILDER(Name("Multinomial") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .TypeConstraint("output_dtype", DT_INT32), \ + StatefulMultinomialOp); \ + REGISTER_KERNEL_BUILDER(Name("Multinomial") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .TypeConstraint("output_dtype", DT_INT64), \ + StatefulMultinomialOp); TF_CALL_half(REGISTER); TF_CALL_float(REGISTER); @@ -233,13 +246,83 @@ TF_CALL_double(REGISTER); .HostMemory("num_samples") \ .TypeConstraint("T") \ .TypeConstraint("output_dtype", DT_INT32), \ - MultinomialOp) \ + StatefulMultinomialOp) \ REGISTER_KERNEL_BUILDER(Name("Multinomial") \ .Device(DEVICE_GPU) \ .HostMemory("num_samples") \ .TypeConstraint("T") \ .TypeConstraint("output_dtype", DT_INT64), \ - MultinomialOp) + StatefulMultinomialOp) + +TF_CALL_half(REGISTER); +TF_CALL_float(REGISTER); +TF_CALL_double(REGISTER); +#undef REGISTER + +#endif // GOOGLE_CUDA + +template +class StatelessMultinomialOp : public MultinomialOp { + public: + explicit StatelessMultinomialOp(OpKernelConstruction* ctx) + : MultinomialOp(ctx) {} + + void Compute(OpKernelContext* ctx) override { + const Tensor& logits_t = ctx->input(0); + const Tensor& num_samples_t = ctx->input(1); + + const Tensor& seed_t = ctx->input(2); + OP_REQUIRES(ctx, seed_t.dims() == 1 && seed_t.dim_size(0) == 2, + errors::InvalidArgument("seed must have shape [2], not ", + seed_t.shape().DebugString())); + + random::PhiloxRandom::Key key; + random::PhiloxRandom::ResultType counter; + OP_REQUIRES_OK(ctx, GenerateKey(seed_t, &key, &counter)); + + GuardedPhiloxRandom generator; + generator.Init(counter, key); + + this->DoCompute(ctx, logits_t, num_samples_t, &generator); + } + + private: + GuardedPhiloxRandom generator_; +}; + +#define REGISTER(TYPE) \ + REGISTER_KERNEL_BUILDER(Name("StatelessMultinomial") \ + 
.Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .TypeConstraint("output_dtype", DT_INT32), \ + StatelessMultinomialOp); \ + REGISTER_KERNEL_BUILDER(Name("StatelessMultinomial") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .TypeConstraint("output_dtype", DT_INT64), \ + StatelessMultinomialOp); + +TF_CALL_half(REGISTER); +TF_CALL_float(REGISTER); +TF_CALL_double(REGISTER); +#undef REGISTER + +#if GOOGLE_CUDA +#define REGISTER(TYPE) \ + REGISTER_KERNEL_BUILDER(Name("StatelessMultinomial") \ + .Device(DEVICE_GPU) \ + .HostMemory("num_samples") \ + .HostMemory("seed") \ + .TypeConstraint("T") \ + .TypeConstraint("output_dtype", DT_INT32), \ + StatelessMultinomialOp) \ + REGISTER_KERNEL_BUILDER(Name("StatelessMultinomial") \ + .Device(DEVICE_GPU) \ + .HostMemory("num_samples") \ + .HostMemory("seed") \ + .TypeConstraint("T") \ + .TypeConstraint("output_dtype", DT_INT64), \ + StatelessMultinomialOp) TF_CALL_half(REGISTER); TF_CALL_float(REGISTER); @@ -248,4 +331,6 @@ TF_CALL_double(REGISTER); #endif // GOOGLE_CUDA +} // end namespace + } // end namespace tensorflow diff --git a/tensorflow/core/kernels/stateless_random_ops.cc b/tensorflow/core/kernels/stateless_random_ops.cc index 88fcf542fb..eab176c7fb 100644 --- a/tensorflow/core/kernels/stateless_random_ops.cc +++ b/tensorflow/core/kernels/stateless_random_ops.cc @@ -27,6 +27,41 @@ namespace tensorflow { using CPUDevice = Eigen::ThreadPoolDevice; using GPUDevice = Eigen::GpuDevice; +Status GenerateKey(Tensor seed, random::PhiloxRandom::Key* out_key, + random::PhiloxRandom::ResultType* out_counter) { + // Grab the two seeds + uint64 seed0; + uint64 seed1; + if (seed.dtype() == DT_INT32) { + const auto seed_vals = seed.flat(); + seed0 = internal::SubtleMustCopy(seed_vals(0)); + seed1 = internal::SubtleMustCopy(seed_vals(1)); + } else if (seed.dtype() == DT_INT64) { + const auto seed_vals = seed.flat(); + seed0 = internal::SubtleMustCopy(seed_vals(0)); + seed1 = internal::SubtleMustCopy(seed_vals(1)); + } else { + 
return errors::InvalidArgument("Invalid seed type: ", + DataTypeString(seed.dtype())); + } + + // Scramble the seeds so that the user doesn't need to worry about which + // part of the seed needs to be strong. + (*out_key)[0] = 0x3ec8f720; + (*out_key)[1] = 0x02461e29; + (*out_counter)[0] = static_cast(seed0); + (*out_counter)[1] = static_cast(seed0 >> 32); + (*out_counter)[2] = static_cast(seed1); + (*out_counter)[3] = static_cast(seed1 >> 32); + const auto mix = random::PhiloxRandom(*out_counter, *out_key)(); + (*out_key)[0] = mix[0]; + (*out_key)[1] = mix[1]; + (*out_counter)[0] = (*out_counter)[1] = 0; + (*out_counter)[2] = mix[2]; + (*out_counter)[3] = mix[3]; + return Status::OK(); +} + namespace { class StatelessRandomOpBase : public OpKernel { @@ -49,36 +84,9 @@ class StatelessRandomOpBase : public OpKernel { OP_REQUIRES_OK(context, context->allocate_output(0, shape, &output)); if (shape.num_elements() == 0) return; - // Grab the two seeds - uint64 seed0; - uint64 seed1; - if (context->input_dtype(1) == DT_INT32) { - const auto seed = seed_t.flat(); - seed0 = internal::SubtleMustCopy(seed(0)); - seed1 = internal::SubtleMustCopy(seed(1)); - } else { - CHECK_EQ(DT_INT64, context->input_dtype(1)); - const auto seed = seed_t.flat(); - seed0 = internal::SubtleMustCopy(seed(0)); - seed1 = internal::SubtleMustCopy(seed(1)); - } - - // Scramble the seeds so that the user doesn't need to worry about which - // part of the seed needs to be strong. 
random::PhiloxRandom::Key key; random::PhiloxRandom::ResultType counter; - key[0] = 0x3ec8f720; - key[1] = 0x02461e29; - counter[0] = static_cast(seed0); - counter[1] = static_cast(seed0 >> 32); - counter[2] = static_cast(seed1); - counter[3] = static_cast(seed1 >> 32); - const auto mix = random::PhiloxRandom(counter, key)(); - key[0] = mix[0]; - key[1] = mix[1]; - counter[0] = counter[1] = 0; - counter[2] = mix[2]; - counter[3] = mix[3]; + OP_REQUIRES_OK(context, GenerateKey(seed_t, &key, &counter)); // Fill in the random numbers Fill(context, random::PhiloxRandom(counter, key), output); @@ -105,8 +113,6 @@ class StatelessRandomOp : public StatelessRandomOpBase { } }; -} // namespace - #define REGISTER(TYPE) \ REGISTER_KERNEL_BUILDER( \ Name("StatelessRandomUniform") \ @@ -176,4 +182,6 @@ TF_CALL_double(REGISTER); #endif // GOOGLE_CUDA +} // namespace + } // namespace tensorflow diff --git a/tensorflow/core/kernels/stateless_random_ops.h b/tensorflow/core/kernels/stateless_random_ops.h new file mode 100644 index 0000000000..bcd29c4873 --- /dev/null +++ b/tensorflow/core/kernels/stateless_random_ops.h @@ -0,0 +1,34 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_STATELESS_RANDOM_OPS_H_ +#define TENSORFLOW_CORE_KERNELS_STATELESS_RANDOM_OPS_H_ + +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/random/random_distributions.h" + +namespace tensorflow { + +// Generates a key and counter that can be used to seed a PhiloxRandom +// generator, based on the seed value in `seed_t`. +// +// REQUIRES: `seed_t` must be a length-2 vector of type DT_INT{32,64}. +// `out_key` and `out_counter` must be non-null. +Status GenerateKey(Tensor seed_t, random::PhiloxRandom::Key* out_key, + random::PhiloxRandom::ResultType* out_counter); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_STATELESS_RANDOM_OPS_H_ diff --git a/tensorflow/core/ops/stateless_random_ops.cc b/tensorflow/core/ops/stateless_random_ops.cc index 553850610a..742709fb18 100644 --- a/tensorflow/core/ops/stateless_random_ops.cc +++ b/tensorflow/core/ops/stateless_random_ops.cc @@ -29,7 +29,7 @@ static Status StatelessShape(shape_inference::InferenceContext* context) { TF_RETURN_IF_ERROR(context->WithValue(context->Dim(seed, 0), 2, &unused)); // Set output shape - shape_inference::ShapeHandle out; + ShapeHandle out; TF_RETURN_IF_ERROR(context->MakeShapeFromShapeTensor(0, &out)); context->set_output(0, out); return Status::OK(); @@ -54,6 +54,32 @@ REGISTER_STATELESS_OP("StatelessRandomNormal"); // This op is exposed through contrib/stateless only. The interface may change. REGISTER_STATELESS_OP("StatelessTruncatedNormal"); +// This op is exposed through contrib/stateless only. The interface may change. 
+REGISTER_OP("StatelessMultinomial") + .Input("logits: T") + .Input("num_samples: int32") + .Input("seed: Tseed") + .Output("output: output_dtype") + .Attr("T: realnumbertype") + .Attr("Tseed: {int32, int64} = DT_INT64") + .Attr("output_dtype: {int32, int64} = DT_INT64") + .SetShapeFn([](shape_inference::InferenceContext* c) { + // Check seed shape + ShapeHandle seed; + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &seed)); + DimensionHandle unused_dim; + TF_RETURN_IF_ERROR(c->WithValue(c->Dim(seed, 0), 2, &unused_dim)); + + ShapeHandle logits_shape; + ShapeHandle unused; + DimensionHandle num_samples; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &logits_shape)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(1, &num_samples)); + c->set_output(0, c->Matrix(c->Dim(logits_shape, 0), num_samples)); + return Status::OK(); + }); + #undef REGISTER_STATELESS_OP } // namespace tensorflow diff --git a/tensorflow/core/util/guarded_philox_random.cc b/tensorflow/core/util/guarded_philox_random.cc index 2d1e9a293e..7c7ba4cef6 100644 --- a/tensorflow/core/util/guarded_philox_random.cc +++ b/tensorflow/core/util/guarded_philox_random.cc @@ -43,6 +43,14 @@ void GuardedPhiloxRandom::Init(int64 seed, int64 seed2) { initialized_ = true; } +void GuardedPhiloxRandom::Init(random::PhiloxRandom::ResultType counter, + random::PhiloxRandom::Key key) { + CHECK(!initialized_); + mutex_lock lock(mu_); + generator_ = random::PhiloxRandom(counter, key); + initialized_ = true; +} + random::PhiloxRandom GuardedPhiloxRandom::ReserveSamples128(int64 samples) { CHECK(initialized_); mutex_lock lock(mu_); diff --git a/tensorflow/core/util/guarded_philox_random.h b/tensorflow/core/util/guarded_philox_random.h index 5b94a76777..44970eb949 100644 --- a/tensorflow/core/util/guarded_philox_random.h +++ b/tensorflow/core/util/guarded_philox_random.h @@ -49,6 +49,8 @@ class GuardedPhiloxRandom { // Initialize with given seeds. 
void Init(int64 seed, int64 seed2); + void Init(random::PhiloxRandom::ResultType counter, + random::PhiloxRandom::Key key); // Reserve a certain number of 128-bit samples. // This function is thread safe. The returned generator is valid for the -- cgit v1.2.3