diff options
author | Yuefeng Zhou <yuefengz@google.com> | 2017-02-14 22:39:15 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-02-14 22:50:18 -0800 |
commit | 02703f9525696f4788496745f6756585c1c546a3 (patch) | |
tree | 980adb3e762d5d30c8e4b5b73ea0cf1a69242bae /tensorflow/core/kernels/candidate_sampler_ops.cc | |
parent | 07b87114d9a9791320373d6bc1b5b717a3f48e48 (diff) |
Fix crash in range sampler by adding a range check in the sampler op.
Change: 147562709
Diffstat (limited to 'tensorflow/core/kernels/candidate_sampler_ops.cc')
-rw-r--r-- | tensorflow/core/kernels/candidate_sampler_ops.cc | 8 |
1 files changed, 6 insertions, 2 deletions
diff --git a/tensorflow/core/kernels/candidate_sampler_ops.cc b/tensorflow/core/kernels/candidate_sampler_ops.cc index 6aa9059dc7..9e8b122801 100644 --- a/tensorflow/core/kernels/candidate_sampler_ops.cc +++ b/tensorflow/core/kernels/candidate_sampler_ops.cc @@ -47,6 +47,12 @@ class BaseCandidateSamplerOp : public OpKernel { OP_REQUIRES(context, true_classes.dim_size(1) == num_true_, errors::InvalidArgument("true_classes must have " "num_true columns")); + CHECK(sampler_) << "CandidateSamplerOp did not set sampler_"; + + if (unique_) { + OP_REQUIRES(context, num_sampled_ <= sampler_->range(), + errors::InvalidArgument("Sampler's range is too small.")); + } // Output candidates and expected_count. Tensor* out_sampled_candidates = nullptr; @@ -73,8 +79,6 @@ class BaseCandidateSamplerOp : public OpKernel { gtl::MutableArraySlice<float> sampled_expected_count( out_sampled_expected_count->vec<float>().data(), num_sampled_); - CHECK(sampler_) << "CandidateSamplerOp did not set sampler_"; - // Approximately conservatively estimate the number of samples required. // In cases where rejection sampling is used we may occasionally use more // samples than expected, which will result in reused random bits. |