diff options
Diffstat (limited to 'tensorflow/core/kernels/candidate_sampler_ops.cc')
-rw-r--r-- | tensorflow/core/kernels/candidate_sampler_ops.cc | 243 |
1 files changed, 243 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/candidate_sampler_ops.cc b/tensorflow/core/kernels/candidate_sampler_ops.cc new file mode 100644 index 0000000000..cd5fde37a6 --- /dev/null +++ b/tensorflow/core/kernels/candidate_sampler_ops.cc @@ -0,0 +1,243 @@ +// See docs in ../ops/candidate_sampling_ops.cc. + +#define EIGEN_USE_THREADS + +#include <cfloat> +#include <unordered_map> + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/kernels/range_sampler.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/public/tensor_shape.h" +#include "tensorflow/core/util/guarded_philox_random.h" + +namespace tensorflow { + +class BaseCandidateSamplerOp : public OpKernel { + public: + explicit BaseCandidateSamplerOp(OpKernelConstruction* context) + : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("num_sampled", &num_sampled_)); + OP_REQUIRES_OK(context, context->GetAttr("num_true", &num_true_)); + OP_REQUIRES_OK(context, context->GetAttr("unique", &unique_)); + OP_REQUIRES_OK(context, generator_.Init(context)); + } + + void Compute(OpKernelContext* context) override { + const Tensor& true_classes = context->input(0); + OP_REQUIRES(context, true_classes.dims() == 2, + errors::InvalidArgument("true_classes must be a matrix")); + const int32 batch_size = true_classes.dim_size(0); + OP_REQUIRES(context, true_classes.dim_size(1) == num_true_, + errors::InvalidArgument("true_classes must have " + "num_true columns")); + + // Output candidates and expected_count. + Tensor* out_sampled_candidates = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(0, TensorShape({num_sampled_}), + &out_sampled_candidates)); + + Tensor* out_true_expected_count = nullptr; + OP_REQUIRES_OK(context, context->allocate_output( + 1, TensorShape({batch_size, num_true_}), + &out_true_expected_count)); + Tensor* out_sampled_expected_count = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(2, TensorShape({num_sampled_}), + &out_sampled_expected_count)); + + gtl::ArraySlice<int64> true_candidate(true_classes.matrix<int64>().data(), + batch_size * num_true_); + gtl::MutableArraySlice<int64> sampled_candidate( + out_sampled_candidates->vec<int64>().data(), num_sampled_); + gtl::MutableArraySlice<float> true_expected_count( + out_true_expected_count->matrix<float>().data(), + batch_size * num_true_); + gtl::MutableArraySlice<float> sampled_expected_count( + out_sampled_expected_count->vec<float>().data(), num_sampled_); + + CHECK(sampler_) << "CandidateSamplerOp did not set sampler_"; + + // Approximately conservatively estimate the number of samples required. + // In cases where rejection sampling is used we may occasionally use more + // samples than expected, which will result in reused random bits. + const int64 samples32 = 2048 * num_sampled_; + + // Pick sampled candidates. + auto local_gen = generator_.ReserveSamples32(samples32); + random::SimplePhilox random(&local_gen); + sampler_->SampleBatchGetExpectedCount(&random, unique_, &sampled_candidate, + &sampled_expected_count, + true_candidate, &true_expected_count); + + if (sampler_->NeedsUpdates()) { + sampler_->Update(true_candidate); + } + } + + protected: + void set_sampler(RangeSampler* sampler) { sampler_.reset(sampler); } + + private: + int32 num_true_; + int32 num_sampled_; + bool unique_; + std::unique_ptr<RangeSampler> sampler_; + GuardedPhiloxRandom generator_; +}; + +template <class RangeSamplerType> +class SimpleCandidateSamplerOp : public BaseCandidateSamplerOp { + public: + explicit SimpleCandidateSamplerOp(OpKernelConstruction* context) + : BaseCandidateSamplerOp(context) { + int64 range_max; + OP_REQUIRES_OK(context, context->GetAttr("range_max", &range_max)); + set_sampler(new RangeSamplerType(range_max)); + } +}; + +REGISTER_KERNEL_BUILDER(Name("UniformCandidateSampler").Device(DEVICE_CPU), + SimpleCandidateSamplerOp<UniformSampler>); + +REGISTER_KERNEL_BUILDER(Name("LogUniformCandidateSampler").Device(DEVICE_CPU), + SimpleCandidateSamplerOp<LogUniformSampler>); + +REGISTER_KERNEL_BUILDER(Name("LearnedUnigramCandidateSampler") + .Device(DEVICE_CPU), + SimpleCandidateSamplerOp<UnigramSampler>); + +REGISTER_KERNEL_BUILDER(Name("ThreadUnsafeUnigramCandidateSampler") + .Device(DEVICE_CPU), + SimpleCandidateSamplerOp<ThreadUnsafeUnigramSampler>); + +class AllCandidateSamplerOp : public BaseCandidateSamplerOp { + public: + explicit AllCandidateSamplerOp(OpKernelConstruction* context) + : BaseCandidateSamplerOp(context) { + int64 range_max; + OP_REQUIRES_OK(context, context->GetAttr("num_sampled", &range_max)); + set_sampler(new AllSampler(range_max)); + } +}; + +REGISTER_KERNEL_BUILDER(Name("AllCandidateSampler").Device(DEVICE_CPU), + AllCandidateSamplerOp); + +class FixedUnigramCandidateSamplerOp : public BaseCandidateSamplerOp { + public: + explicit FixedUnigramCandidateSamplerOp(OpKernelConstruction* context) + : BaseCandidateSamplerOp(context) { + int64 range_max; + OP_REQUIRES_OK(context, context->GetAttr("range_max", &range_max)); + string vocab_file; + OP_REQUIRES_OK(context, context->GetAttr("vocab_file", &vocab_file)); + std::vector<float> unigrams; + OP_REQUIRES_OK(context, context->GetAttr("unigrams", &unigrams)); + OP_REQUIRES( + context, !vocab_file.empty() || !unigrams.empty(), + errors::InvalidArgument("Must provide either vocab_file or unigrams.")); + OP_REQUIRES(context, vocab_file.empty() || unigrams.empty(), + errors::InvalidArgument( + "Must only provide one of vocab_file and unigrams.")); + float distortion; + OP_REQUIRES_OK(context, context->GetAttr("distortion", &distortion)); + int64 num_reserved_ids; + OP_REQUIRES_OK(context, + context->GetAttr("num_reserved_ids", &num_reserved_ids)); + int64 num_shards; + OP_REQUIRES_OK(context, context->GetAttr("num_shards", &num_shards)); + int64 shard; + OP_REQUIRES_OK(context, context->GetAttr("shard", &shard)); + + if (!vocab_file.empty()) { + set_sampler(new FixedUnigramSampler(context->env(), range_max, vocab_file, + distortion, num_reserved_ids, + num_shards, shard)); + } else { + set_sampler(new FixedUnigramSampler(range_max, unigrams, distortion, + num_reserved_ids, num_shards, shard)); + } + } +}; + +REGISTER_KERNEL_BUILDER(Name("FixedUnigramCandidateSampler").Device(DEVICE_CPU), + FixedUnigramCandidateSamplerOp); + +class ComputeAccidentalHitsOp : public OpKernel { + public: + explicit ComputeAccidentalHitsOp(OpKernelConstruction* context) + : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("num_true", &num_true_)); + } + + void Compute(OpKernelContext* context) override { + const Tensor& in_true_candidates = context->input(0); + TensorShape in_true_candidates_shape = in_true_candidates.shape(); + OP_REQUIRES(context, TensorShapeUtils::IsMatrix(in_true_candidates_shape) && + in_true_candidates_shape.dim_size(1) == num_true_, + errors::InvalidArgument( + "true_candidates must be a batch_size * num_true matrix")); + + const int64 batch_size = in_true_candidates_shape.dim_size(0); + + const Tensor& in_sampled_candidates = context->input(1); + OP_REQUIRES(context, + TensorShapeUtils::IsVector(in_sampled_candidates.shape()), + errors::InvalidArgument( + "sampled_candidates must be a vector, which is typically " + "an output from CandidateSampler")); + + std::unordered_map<int64, int> sampled_candidate_to_pos; + for (int64 i = 0; i < in_sampled_candidates.dim_size(0); ++i) { + sampled_candidate_to_pos[in_sampled_candidates.vec<int64>()(i)] = i; + } + + // Produce output in the same format as UnpackSparseFeatures. + std::vector<int> indices; + std::vector<int64> ids; + std::vector<float> weights; + + for (int64 i = 0; i < batch_size; ++i) { + for (int64 j = 0; j < num_true_; ++j) { + const int64 true_candidate = in_true_candidates.matrix<int64>()(i, j); + const auto look = sampled_candidate_to_pos.find(true_candidate); + if (look != sampled_candidate_to_pos.end()) { + indices.push_back(i); + ids.push_back(look->second); + weights.push_back(-FLT_MAX); + } + } + } + + Tensor* out_indices = nullptr; + OP_REQUIRES_OK( + context, + context->allocate_output( + 0, TensorShape({static_cast<int>(indices.size())}), &out_indices)); + Tensor* out_ids = nullptr; + OP_REQUIRES_OK( + context, context->allocate_output( + 1, TensorShape({static_cast<int>(ids.size())}), &out_ids)); + Tensor* out_weights = nullptr; + OP_REQUIRES_OK( + context, + context->allocate_output( + 2, TensorShape({static_cast<int>(weights.size())}), &out_weights)); + + for (size_t i = 0; i < indices.size(); ++i) { + out_indices->vec<int32>()(i) = indices[i]; + out_ids->vec<int64>()(i) = ids[i]; + out_weights->vec<float>()(i) = weights[i]; + } + } + + private: + int64 num_true_; +}; + +REGISTER_KERNEL_BUILDER(Name("ComputeAccidentalHits").Device(DEVICE_CPU), + ComputeAccidentalHitsOp); + +} // namespace tensorflow |