aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/candidate_sampler_ops.cc
diff options
context:
space:
mode:
authorGravatar Yuefeng Zhou <yuefengz@google.com>2017-02-14 22:39:15 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-02-14 22:50:18 -0800
commit02703f9525696f4788496745f6756585c1c546a3 (patch)
tree980adb3e762d5d30c8e4b5b73ea0cf1a69242bae /tensorflow/core/kernels/candidate_sampler_ops.cc
parent07b87114d9a9791320373d6bc1b5b717a3f48e48 (diff)
Fix crash in range sampler by adding a range check in the sampler op.
Change: 147562709
Diffstat (limited to 'tensorflow/core/kernels/candidate_sampler_ops.cc')
-rw-r--r--tensorflow/core/kernels/candidate_sampler_ops.cc8
1 files changed, 6 insertions, 2 deletions
diff --git a/tensorflow/core/kernels/candidate_sampler_ops.cc b/tensorflow/core/kernels/candidate_sampler_ops.cc
index 6aa9059dc7..9e8b122801 100644
--- a/tensorflow/core/kernels/candidate_sampler_ops.cc
+++ b/tensorflow/core/kernels/candidate_sampler_ops.cc
@@ -47,6 +47,12 @@ class BaseCandidateSamplerOp : public OpKernel {
OP_REQUIRES(context, true_classes.dim_size(1) == num_true_,
errors::InvalidArgument("true_classes must have "
"num_true columns"));
+ CHECK(sampler_) << "CandidateSamplerOp did not set sampler_";
+
+ if (unique_) {
+ OP_REQUIRES(context, num_sampled_ <= sampler_->range(),
+ errors::InvalidArgument("Sampler's range is too small."));
+ }
// Output candidates and expected_count.
Tensor* out_sampled_candidates = nullptr;
@@ -73,8 +79,6 @@ class BaseCandidateSamplerOp : public OpKernel {
gtl::MutableArraySlice<float> sampled_expected_count(
out_sampled_expected_count->vec<float>().data(), num_sampled_);
- CHECK(sampler_) << "CandidateSamplerOp did not set sampler_";
-
// Approximately conservatively estimate the number of samples required.
// In cases where rejection sampling is used we may occasionally use more
// samples than expected, which will result in reused random bits.