aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tensor_forest
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-10-04 03:57:59 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-04 04:02:24 -0700
commit727d6270f9d16b4f60ac35039abb161bd037812d (patch)
treec550cebd6daeff87de55e996083e2c50b7016cd4 /tensorflow/contrib/tensor_forest
parentd016cb020583b1ecbc260c1492e347c2731b1c29 (diff)
Fix race condition in TensorForest tree traversal.
PiperOrigin-RevId: 170990425
Diffstat (limited to 'tensorflow/contrib/tensor_forest')
-rw-r--r--tensorflow/contrib/tensor_forest/kernels/model_ops.cc17
1 files changed, 7 insertions, 10 deletions
diff --git a/tensorflow/contrib/tensor_forest/kernels/model_ops.cc b/tensorflow/contrib/tensor_forest/kernels/model_ops.cc
index 29e0d6af78..b9aad36f3d 100644
--- a/tensorflow/contrib/tensor_forest/kernels/model_ops.cc
+++ b/tensorflow/contrib/tensor_forest/kernels/model_ops.cc
@@ -271,9 +271,6 @@ class TraverseTreeV4Op : public OpKernel {
string serialized_proto;
OP_REQUIRES_OK(context, context->GetAttr("input_spec", &serialized_proto));
input_spec_.ParseFromString(serialized_proto);
-
- data_set_ =
- std::unique_ptr<TensorDataSet>(new TensorDataSet(input_spec_, 0));
}
void Compute(OpKernelContext* context) override {
@@ -282,8 +279,9 @@ class TraverseTreeV4Op : public OpKernel {
const Tensor& sparse_input_values = context->input(3);
const Tensor& sparse_input_shape = context->input(4);
- data_set_->set_input_tensors(input_data, sparse_input_indices,
- sparse_input_values, sparse_input_shape);
+ std::unique_ptr<TensorDataSet> data_set(new TensorDataSet(input_spec_, 0));
+ data_set->set_input_tensors(input_data, sparse_input_indices,
+ sparse_input_values, sparse_input_shape);
DecisionTreeResource* decision_tree_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
@@ -291,7 +289,7 @@ class TraverseTreeV4Op : public OpKernel {
mutex_lock l(*decision_tree_resource->get_mutex());
core::ScopedUnref unref_me(decision_tree_resource);
- const int num_data = data_set_->NumItems();
+ const int num_data = data_set->NumItems();
Tensor* output_predictions = nullptr;
TensorShape output_shape;
@@ -306,11 +304,11 @@ class TraverseTreeV4Op : public OpKernel {
auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
int num_threads = worker_threads->num_threads;
const int64 costPerTraverse = 500;
- auto traverse = [this, &set_leaf_ids, decision_tree_resource, num_data](
- int64 start, int64 end) {
+ auto traverse = [this, &set_leaf_ids, &data_set, decision_tree_resource,
+ num_data](int64 start, int64 end) {
CHECK(start <= end);
CHECK(end <= num_data);
- TraverseTree(decision_tree_resource, data_set_, static_cast<int32>(start),
+ TraverseTree(decision_tree_resource, data_set, static_cast<int32>(start),
static_cast<int32>(end), set_leaf_ids, nullptr);
};
Shard(num_threads, worker_threads->workers, num_data, costPerTraverse,
@@ -319,7 +317,6 @@ class TraverseTreeV4Op : public OpKernel {
private:
tensorforest::TensorForestDataSpec input_spec_;
- std::unique_ptr<TensorDataSet> data_set_;
TensorForestParams param_proto_;
};