aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tensor_forest
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-23 12:35:07 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-23 12:44:46 -0700
commitcb00181bf7e3417b2ab64756dc0c535b5d1d1332 (patch)
tree14e2981ef9ebcb5441c81fe9dff7b517b2da463f /tensorflow/contrib/tensor_forest
parent4dde158207d075ad65b3b52caa5b4fb90775edbf (diff)
Fixing thread-unsafe access to a class member.
PiperOrigin-RevId: 209980962
Diffstat (limited to 'tensorflow/contrib/tensor_forest')
-rw-r--r--tensorflow/contrib/tensor_forest/kernels/v4/input_data.cc6
-rw-r--r--tensorflow/contrib/tensor_forest/kernels/v4/input_data.h3
2 files changed, 8 insertions, 1 deletions
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/input_data.cc b/tensorflow/contrib/tensor_forest/kernels/v4/input_data.cc
index d43884481a..99c5800391 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/input_data.cc
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/input_data.cc
@@ -130,7 +130,11 @@ void TensorDataSet::RandomSample(int example,
num_total_features += num_sparse;
}
}
- int rand_feature = rng_->Uniform(num_total_features);
+ int rand_feature = 0;
+ {
+ mutex_lock lock(mu_);
+ rand_feature = rng_->Uniform(num_total_features);
+ }
if (rand_feature < available_features_.size()) { // it's dense.
*feature_id = available_features_[rand_feature];
*type = input_spec_.GetDenseFeatureType(rand_feature);
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/input_data.h b/tensorflow/contrib/tensor_forest/kernels/v4/input_data.h
index 95f75b4d7e..4945b53007 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/input_data.h
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/input_data.h
@@ -25,6 +25,7 @@
#include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/platform/mutex.h"
namespace tensorflow {
namespace tensorforest {
@@ -120,6 +121,8 @@ class TensorDataSet {
int32 split_sampling_random_seed_;
std::unique_ptr<random::PhiloxRandom> single_rand_;
std::unique_ptr<random::SimplePhilox> rng_;
+ // Mutex for using random number generator.
+ mutable mutex mu_;
};
} // namespace tensorforest
} // namespace tensorflow