author     Derek Murray <mrry@google.com>            2018-04-11 18:09:42 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-04-11 18:11:58 -0700
commit     70d99359fcb9aa9efa955fab06227373c734728b (patch)
tree       e5c9d1c6cbed02be0a352f85f64b525c3dddcbe9 /tensorflow/core
parent     1a721ecd9a9992d48c0deb3008b1fc8df297d300 (diff)
Add `tf.contrib.stateless.stateless_multinomial()`.
This is a starting point for Dataset-compatible weighted sampling across a list of datasets.

PiperOrigin-RevId: 192540412
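For orientation, the following is a minimal usage sketch of the new wrapper from Python. It assumes the `tf.contrib.stateless` wrapper simply forwards to the `StatelessMultinomial` op added below, with keyword names mirroring the op's inputs; the exact Python signature is not part of this diff.

    # Hypothetical usage sketch; wrapper signature assumed from the op definition.
    import tensorflow as tf

    logits = tf.log([[0.5, 0.25, 0.25]])        # shape [batch_size, num_classes]
    samples = tf.contrib.stateless.stateless_multinomial(
        logits, num_samples=10, seed=[7, 17])   # shape [batch_size, num_samples]

    with tf.Session() as sess:
        # Unlike tf.multinomial, the result depends only on the inputs and seed,
        # so repeated runs with the same seed return identical class labels.
        print(sess.run(samples))
        print(sess.run(samples))  # identical to the first run

Because the op is stateless, the same (logits, num_samples, seed) triple always produces the same samples, which is what makes it usable inside a Dataset pipeline for reproducible weighted sampling.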
Diffstat (limited to 'tensorflow/core')
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_StatelessMultinomial.pbtxt |  30
-rw-r--r--  tensorflow/core/kernels/BUILD                                        |   1
-rw-r--r--  tensorflow/core/kernels/multinomial_op.cc                            | 131
-rw-r--r--  tensorflow/core/kernels/stateless_random_ops.cc                      |  68
-rw-r--r--  tensorflow/core/kernels/stateless_random_ops.h                       |  34
-rw-r--r--  tensorflow/core/ops/stateless_random_ops.cc                          |  28
-rw-r--r--  tensorflow/core/util/guarded_philox_random.cc                        |   8
-rw-r--r--  tensorflow/core/util/guarded_philox_random.h                         |   2
8 files changed, 248 insertions, 54 deletions
diff --git a/tensorflow/core/api_def/base_api/api_def_StatelessMultinomial.pbtxt b/tensorflow/core/api_def/base_api/api_def_StatelessMultinomial.pbtxt
new file mode 100644
index 0000000000..c4e6c1fddd
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_StatelessMultinomial.pbtxt
@@ -0,0 +1,30 @@
+op {
+ graph_op_name: "StatelessMultinomial"
+ in_arg {
+ name: "logits"
+ description: <<END
+2-D Tensor with shape `[batch_size, num_classes]`. Each slice `[i, :]`
+represents the unnormalized log probabilities for all classes.
+END
+ }
+ in_arg {
+ name: "num_samples"
+ description: <<END
+0-D. Number of independent samples to draw for each row slice.
+END
+ }
+ in_arg {
+ name: "seed"
+ description: <<END
+2 seeds (shape [2]).
+END
+ }
+ out_arg {
+ name: "output"
+ description: <<END
+2-D Tensor with shape `[batch_size, num_samples]`. Each slice `[i, :]`
+contains the drawn class labels with range `[0, num_classes)`.
+END
+ }
+ summary: "Draws samples from a multinomial distribution."
+}
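To make the shape contract documented above concrete, here is a small sketch (using the same assumed contrib wrapper as in the earlier example; the numbers are illustrative only):

    import tensorflow as tf

    # logits: [batch_size=2, num_classes=4]; seed: shape [2]; num_samples: scalar.
    logits = tf.constant([[1.0, 1.0, 1.0, 1.0],
                          [0.0, 0.0, 10.0, 0.0]])
    samples = tf.contrib.stateless.stateless_multinomial(
        logits, num_samples=3, seed=[1, 2])
    # Static shape is (2, 3), i.e. [batch_size, num_samples], as computed by the
    # op's shape function; each entry is a class label in [0, num_classes).
    print(samples.shape)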
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 1018e8d25c..e2af540dac 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -4323,6 +4323,7 @@ tf_kernel_library(
deps = [
":random_op",
":random_ops",
+ ":stateless_random_ops",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
diff --git a/tensorflow/core/kernels/multinomial_op.cc b/tensorflow/core/kernels/multinomial_op.cc
index d086abb247..7a64788448 100644
--- a/tensorflow/core/kernels/multinomial_op.cc
+++ b/tensorflow/core/kernels/multinomial_op.cc
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/kernels/stateless_random_ops.h"
#include "tensorflow/core/lib/random/random_distributions.h"
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/util/guarded_philox_random.h"
@@ -127,18 +128,16 @@ struct MultinomialFunctor<CPUDevice, T, OutputType> {
} // namespace functor
+namespace {
+
// Samples from a multinomial distribution.
template <typename Device, typename T, typename OutputType>
class MultinomialOp : public OpKernel {
public:
- explicit MultinomialOp(OpKernelConstruction* context) : OpKernel(context) {
- OP_REQUIRES_OK(context, generator_.Init(context));
- }
-
- void Compute(OpKernelContext* ctx) override {
- const Tensor& logits_t = ctx->input(0);
- const Tensor& num_samples_t = ctx->input(1);
+ explicit MultinomialOp(OpKernelConstruction* context) : OpKernel(context) {}
+ void DoCompute(OpKernelContext* ctx, const Tensor& logits_t,
+ const Tensor& num_samples_t, GuardedPhiloxRandom* generator) {
OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(logits_t.shape()),
errors::InvalidArgument("logits should be a matrix, got shape ",
logits_t.shape().DebugString()));
@@ -194,7 +193,7 @@ class MultinomialOp : public OpKernel {
// CPU generates doubles = 2 samples per number.
if (std::is_same<Device, CPUDevice>::value) num_samples_ceil_4 *= 2;
auto rng =
- generator_.ReserveRandomOutputs(batch_size * num_samples_ceil_4, 256);
+ generator->ReserveRandomOutputs(batch_size * num_samples_ceil_4, 256);
functor::MultinomialFunctor<Device, T, OutputType>()(
ctx, ctx->eigen_device<Device>(), logits_t.matrix<T>(),
noises.flat<float>(), scores.flat<float>(), scratch.flat<float>(),
@@ -202,24 +201,38 @@ class MultinomialOp : public OpKernel {
samples_t->matrix<OutputType>());
}
}
+};
+
+template <typename Device, typename T, typename OutputType>
+class StatefulMultinomialOp : public MultinomialOp<Device, T, OutputType> {
+ public:
+ explicit StatefulMultinomialOp(OpKernelConstruction* ctx)
+ : MultinomialOp<Device, T, OutputType>(ctx) {
+ OP_REQUIRES_OK(ctx, generator_.Init(ctx));
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor& logits_t = ctx->input(0);
+ const Tensor& num_samples_t = ctx->input(1);
+ this->DoCompute(ctx, logits_t, num_samples_t, &generator_);
+ }
private:
GuardedPhiloxRandom generator_;
-
- TF_DISALLOW_COPY_AND_ASSIGN(MultinomialOp);
};
-#define REGISTER(TYPE) \
- REGISTER_KERNEL_BUILDER(Name("Multinomial") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<TYPE>("T") \
- .TypeConstraint("output_dtype", DT_INT32), \
- MultinomialOp<CPUDevice, TYPE, int32>); \
- REGISTER_KERNEL_BUILDER(Name("Multinomial") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<TYPE>("T") \
- .TypeConstraint("output_dtype", DT_INT64), \
- MultinomialOp<CPUDevice, TYPE, int64>);
+// TODO(b/77906027): Add a TPU implementation.
+#define REGISTER(TYPE) \
+ REGISTER_KERNEL_BUILDER(Name("Multinomial") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<TYPE>("T") \
+ .TypeConstraint("output_dtype", DT_INT32), \
+ StatefulMultinomialOp<CPUDevice, TYPE, int32>); \
+ REGISTER_KERNEL_BUILDER(Name("Multinomial") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<TYPE>("T") \
+ .TypeConstraint("output_dtype", DT_INT64), \
+ StatefulMultinomialOp<CPUDevice, TYPE, int64>);
TF_CALL_half(REGISTER);
TF_CALL_float(REGISTER);
@@ -233,13 +246,83 @@ TF_CALL_double(REGISTER);
.HostMemory("num_samples") \
.TypeConstraint<TYPE>("T") \
.TypeConstraint("output_dtype", DT_INT32), \
- MultinomialOp<GPUDevice, TYPE, int32>) \
+ StatefulMultinomialOp<GPUDevice, TYPE, int32>) \
REGISTER_KERNEL_BUILDER(Name("Multinomial") \
.Device(DEVICE_GPU) \
.HostMemory("num_samples") \
.TypeConstraint<TYPE>("T") \
.TypeConstraint("output_dtype", DT_INT64), \
- MultinomialOp<GPUDevice, TYPE, int64>)
+ StatefulMultinomialOp<GPUDevice, TYPE, int64>)
+
+TF_CALL_half(REGISTER);
+TF_CALL_float(REGISTER);
+TF_CALL_double(REGISTER);
+#undef REGISTER
+
+#endif // GOOGLE_CUDA
+
+template <typename Device, typename T, typename OutputType>
+class StatelessMultinomialOp : public MultinomialOp<Device, T, OutputType> {
+ public:
+ explicit StatelessMultinomialOp(OpKernelConstruction* ctx)
+ : MultinomialOp<Device, T, OutputType>(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor& logits_t = ctx->input(0);
+ const Tensor& num_samples_t = ctx->input(1);
+
+ const Tensor& seed_t = ctx->input(2);
+ OP_REQUIRES(ctx, seed_t.dims() == 1 && seed_t.dim_size(0) == 2,
+ errors::InvalidArgument("seed must have shape [2], not ",
+ seed_t.shape().DebugString()));
+
+ random::PhiloxRandom::Key key;
+ random::PhiloxRandom::ResultType counter;
+ OP_REQUIRES_OK(ctx, GenerateKey(seed_t, &key, &counter));
+
+ GuardedPhiloxRandom generator;
+ generator.Init(counter, key);
+
+ this->DoCompute(ctx, logits_t, num_samples_t, &generator);
+ }
+
+ private:
+ GuardedPhiloxRandom generator_;
+};
+
+#define REGISTER(TYPE) \
+ REGISTER_KERNEL_BUILDER(Name("StatelessMultinomial") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<TYPE>("T") \
+ .TypeConstraint("output_dtype", DT_INT32), \
+ StatelessMultinomialOp<CPUDevice, TYPE, int32>); \
+ REGISTER_KERNEL_BUILDER(Name("StatelessMultinomial") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<TYPE>("T") \
+ .TypeConstraint("output_dtype", DT_INT64), \
+ StatelessMultinomialOp<CPUDevice, TYPE, int64>);
+
+TF_CALL_half(REGISTER);
+TF_CALL_float(REGISTER);
+TF_CALL_double(REGISTER);
+#undef REGISTER
+
+#if GOOGLE_CUDA
+#define REGISTER(TYPE) \
+ REGISTER_KERNEL_BUILDER(Name("StatelessMultinomial") \
+ .Device(DEVICE_GPU) \
+ .HostMemory("num_samples") \
+ .HostMemory("seed") \
+ .TypeConstraint<TYPE>("T") \
+ .TypeConstraint("output_dtype", DT_INT32), \
+ StatelessMultinomialOp<GPUDevice, TYPE, int32>) \
+ REGISTER_KERNEL_BUILDER(Name("StatelessMultinomial") \
+ .Device(DEVICE_GPU) \
+ .HostMemory("num_samples") \
+ .HostMemory("seed") \
+ .TypeConstraint<TYPE>("T") \
+ .TypeConstraint("output_dtype", DT_INT64), \
+ StatelessMultinomialOp<GPUDevice, TYPE, int64>)
TF_CALL_half(REGISTER);
TF_CALL_float(REGISTER);
@@ -248,4 +331,6 @@ TF_CALL_double(REGISTER);
#endif // GOOGLE_CUDA
+} // end namespace
+
} // end namespace tensorflow
diff --git a/tensorflow/core/kernels/stateless_random_ops.cc b/tensorflow/core/kernels/stateless_random_ops.cc
index 88fcf542fb..eab176c7fb 100644
--- a/tensorflow/core/kernels/stateless_random_ops.cc
+++ b/tensorflow/core/kernels/stateless_random_ops.cc
@@ -27,6 +27,41 @@ namespace tensorflow {
using CPUDevice = Eigen::ThreadPoolDevice;
using GPUDevice = Eigen::GpuDevice;
+Status GenerateKey(Tensor seed, random::PhiloxRandom::Key* out_key,
+ random::PhiloxRandom::ResultType* out_counter) {
+ // Grab the two seeds
+ uint64 seed0;
+ uint64 seed1;
+ if (seed.dtype() == DT_INT32) {
+ const auto seed_vals = seed.flat<int32>();
+ seed0 = internal::SubtleMustCopy(seed_vals(0));
+ seed1 = internal::SubtleMustCopy(seed_vals(1));
+ } else if (seed.dtype() == DT_INT64) {
+ const auto seed_vals = seed.flat<int64>();
+ seed0 = internal::SubtleMustCopy(seed_vals(0));
+ seed1 = internal::SubtleMustCopy(seed_vals(1));
+ } else {
+ return errors::InvalidArgument("Invalid seed type: ",
+ DataTypeString(seed.dtype()));
+ }
+
+ // Scramble the seeds so that the user doesn't need to worry about which
+ // part of the seed needs to be strong.
+ (*out_key)[0] = 0x3ec8f720;
+ (*out_key)[1] = 0x02461e29;
+ (*out_counter)[0] = static_cast<uint32>(seed0);
+ (*out_counter)[1] = static_cast<uint32>(seed0 >> 32);
+ (*out_counter)[2] = static_cast<uint32>(seed1);
+ (*out_counter)[3] = static_cast<uint32>(seed1 >> 32);
+ const auto mix = random::PhiloxRandom(*out_counter, *out_key)();
+ (*out_key)[0] = mix[0];
+ (*out_key)[1] = mix[1];
+ (*out_counter)[0] = (*out_counter)[1] = 0;
+ (*out_counter)[2] = mix[2];
+ (*out_counter)[3] = mix[3];
+ return Status::OK();
+}
+
namespace {
class StatelessRandomOpBase : public OpKernel {
@@ -49,36 +84,9 @@ class StatelessRandomOpBase : public OpKernel {
OP_REQUIRES_OK(context, context->allocate_output(0, shape, &output));
if (shape.num_elements() == 0) return;
- // Grab the two seeds
- uint64 seed0;
- uint64 seed1;
- if (context->input_dtype(1) == DT_INT32) {
- const auto seed = seed_t.flat<int32>();
- seed0 = internal::SubtleMustCopy(seed(0));
- seed1 = internal::SubtleMustCopy(seed(1));
- } else {
- CHECK_EQ(DT_INT64, context->input_dtype(1));
- const auto seed = seed_t.flat<int64>();
- seed0 = internal::SubtleMustCopy(seed(0));
- seed1 = internal::SubtleMustCopy(seed(1));
- }
-
- // Scramble the seeds so that the user doesn't need to worry about which
- // part of the seed needs to be strong.
random::PhiloxRandom::Key key;
random::PhiloxRandom::ResultType counter;
- key[0] = 0x3ec8f720;
- key[1] = 0x02461e29;
- counter[0] = static_cast<uint32>(seed0);
- counter[1] = static_cast<uint32>(seed0 >> 32);
- counter[2] = static_cast<uint32>(seed1);
- counter[3] = static_cast<uint32>(seed1 >> 32);
- const auto mix = random::PhiloxRandom(counter, key)();
- key[0] = mix[0];
- key[1] = mix[1];
- counter[0] = counter[1] = 0;
- counter[2] = mix[2];
- counter[3] = mix[3];
+ OP_REQUIRES_OK(context, GenerateKey(seed_t, &key, &counter));
// Fill in the random numbers
Fill(context, random::PhiloxRandom(counter, key), output);
@@ -105,8 +113,6 @@ class StatelessRandomOp : public StatelessRandomOpBase {
}
};
-} // namespace
-
#define REGISTER(TYPE) \
REGISTER_KERNEL_BUILDER( \
Name("StatelessRandomUniform") \
@@ -176,4 +182,6 @@ TF_CALL_double(REGISTER);
#endif // GOOGLE_CUDA
+} // namespace
+
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/stateless_random_ops.h b/tensorflow/core/kernels/stateless_random_ops.h
new file mode 100644
index 0000000000..bcd29c4873
--- /dev/null
+++ b/tensorflow/core/kernels/stateless_random_ops.h
@@ -0,0 +1,34 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_KERNELS_STATELESS_RANDOM_OPS_H_
+#define TENSORFLOW_CORE_KERNELS_STATELESS_RANDOM_OPS_H_
+
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/random/random_distributions.h"
+
+namespace tensorflow {
+
+// Generates a key and counter that can be used to seed a PhiloxRandom
+// generator, based on the seed value in `seed_t`.
+//
+// REQUIRES: `seed_t` must be a length-2 vector of type DT_INT{32,64}.
+// `out_key` and `out_counter` must be non-null.
+Status GenerateKey(Tensor seed_t, random::PhiloxRandom::Key* out_key,
+ random::PhiloxRandom::ResultType* out_counter);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_KERNELS_STATELESS_RANDOM_OPS_H_
diff --git a/tensorflow/core/ops/stateless_random_ops.cc b/tensorflow/core/ops/stateless_random_ops.cc
index 553850610a..742709fb18 100644
--- a/tensorflow/core/ops/stateless_random_ops.cc
+++ b/tensorflow/core/ops/stateless_random_ops.cc
@@ -29,7 +29,7 @@ static Status StatelessShape(shape_inference::InferenceContext* context) {
TF_RETURN_IF_ERROR(context->WithValue(context->Dim(seed, 0), 2, &unused));
// Set output shape
- shape_inference::ShapeHandle out;
+ ShapeHandle out;
TF_RETURN_IF_ERROR(context->MakeShapeFromShapeTensor(0, &out));
context->set_output(0, out);
return Status::OK();
@@ -54,6 +54,32 @@ REGISTER_STATELESS_OP("StatelessRandomNormal");
// This op is exposed through contrib/stateless only. The interface may change.
REGISTER_STATELESS_OP("StatelessTruncatedNormal");
+// This op is exposed through contrib/stateless only. The interface may change.
+REGISTER_OP("StatelessMultinomial")
+ .Input("logits: T")
+ .Input("num_samples: int32")
+ .Input("seed: Tseed")
+ .Output("output: output_dtype")
+ .Attr("T: realnumbertype")
+ .Attr("Tseed: {int32, int64} = DT_INT64")
+ .Attr("output_dtype: {int32, int64} = DT_INT64")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ // Check seed shape
+ ShapeHandle seed;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &seed));
+ DimensionHandle unused_dim;
+ TF_RETURN_IF_ERROR(c->WithValue(c->Dim(seed, 0), 2, &unused_dim));
+
+ ShapeHandle logits_shape;
+ ShapeHandle unused;
+ DimensionHandle num_samples;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &logits_shape));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+ TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(1, &num_samples));
+ c->set_output(0, c->Matrix(c->Dim(logits_shape, 0), num_samples));
+ return Status::OK();
+ });
+
#undef REGISTER_STATELESS_OP
} // namespace tensorflow
diff --git a/tensorflow/core/util/guarded_philox_random.cc b/tensorflow/core/util/guarded_philox_random.cc
index 2d1e9a293e..7c7ba4cef6 100644
--- a/tensorflow/core/util/guarded_philox_random.cc
+++ b/tensorflow/core/util/guarded_philox_random.cc
@@ -43,6 +43,14 @@ void GuardedPhiloxRandom::Init(int64 seed, int64 seed2) {
initialized_ = true;
}
+void GuardedPhiloxRandom::Init(random::PhiloxRandom::ResultType counter,
+ random::PhiloxRandom::Key key) {
+ CHECK(!initialized_);
+ mutex_lock lock(mu_);
+ generator_ = random::PhiloxRandom(counter, key);
+ initialized_ = true;
+}
+
random::PhiloxRandom GuardedPhiloxRandom::ReserveSamples128(int64 samples) {
CHECK(initialized_);
mutex_lock lock(mu_);
diff --git a/tensorflow/core/util/guarded_philox_random.h b/tensorflow/core/util/guarded_philox_random.h
index 5b94a76777..44970eb949 100644
--- a/tensorflow/core/util/guarded_philox_random.h
+++ b/tensorflow/core/util/guarded_philox_random.h
@@ -49,6 +49,8 @@ class GuardedPhiloxRandom {
// Initialize with given seeds.
void Init(int64 seed, int64 seed2);
+ void Init(random::PhiloxRandom::ResultType counter,
+ random::PhiloxRandom::Key key);
// Reserve a certain number of 128-bit samples.
// This function is thread safe. The returned generator is valid for the