diff options
author | 2017-10-04 03:57:59 -0700 | |
---|---|---|
committer | 2017-10-04 04:02:24 -0700 | |
commit | 727d6270f9d16b4f60ac35039abb161bd037812d (patch) | |
tree | c550cebd6daeff87de55e996083e2c50b7016cd4 /tensorflow/contrib/tensor_forest | |
parent | d016cb020583b1ecbc260c1492e347c2731b1c29 (diff) |
Fix race condition in TensorForest tree traversal.
PiperOrigin-RevId: 170990425
Diffstat (limited to 'tensorflow/contrib/tensor_forest')
-rw-r--r-- | tensorflow/contrib/tensor_forest/kernels/model_ops.cc | 17 |
1 files changed, 7 insertions, 10 deletions
diff --git a/tensorflow/contrib/tensor_forest/kernels/model_ops.cc b/tensorflow/contrib/tensor_forest/kernels/model_ops.cc index 29e0d6af78..b9aad36f3d 100644 --- a/tensorflow/contrib/tensor_forest/kernels/model_ops.cc +++ b/tensorflow/contrib/tensor_forest/kernels/model_ops.cc @@ -271,9 +271,6 @@ class TraverseTreeV4Op : public OpKernel { string serialized_proto; OP_REQUIRES_OK(context, context->GetAttr("input_spec", &serialized_proto)); input_spec_.ParseFromString(serialized_proto); - - data_set_ = - std::unique_ptr<TensorDataSet>(new TensorDataSet(input_spec_, 0)); } void Compute(OpKernelContext* context) override { @@ -282,8 +279,9 @@ class TraverseTreeV4Op : public OpKernel { const Tensor& sparse_input_values = context->input(3); const Tensor& sparse_input_shape = context->input(4); - data_set_->set_input_tensors(input_data, sparse_input_indices, - sparse_input_values, sparse_input_shape); + std::unique_ptr<TensorDataSet> data_set(new TensorDataSet(input_spec_, 0)); + data_set->set_input_tensors(input_data, sparse_input_indices, + sparse_input_values, sparse_input_shape); DecisionTreeResource* decision_tree_resource; OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), @@ -291,7 +289,7 @@ class TraverseTreeV4Op : public OpKernel { mutex_lock l(*decision_tree_resource->get_mutex()); core::ScopedUnref unref_me(decision_tree_resource); - const int num_data = data_set_->NumItems(); + const int num_data = data_set->NumItems(); Tensor* output_predictions = nullptr; TensorShape output_shape; @@ -306,11 +304,11 @@ class TraverseTreeV4Op : public OpKernel { auto worker_threads = context->device()->tensorflow_cpu_worker_threads(); int num_threads = worker_threads->num_threads; const int64 costPerTraverse = 500; - auto traverse = [this, &set_leaf_ids, decision_tree_resource, num_data]( - int64 start, int64 end) { + auto traverse = [this, &set_leaf_ids, &data_set, decision_tree_resource, + num_data](int64 start, int64 end) { CHECK(start <= end); CHECK(end <= num_data); - TraverseTree(decision_tree_resource, data_set_, static_cast<int32>(start), + TraverseTree(decision_tree_resource, data_set, static_cast<int32>(start), static_cast<int32>(end), set_leaf_ids, nullptr); }; Shard(num_threads, worker_threads->workers, num_data, costPerTraverse, @@ -319,7 +317,6 @@ class TraverseTreeV4Op : public OpKernel { private: tensorforest::TensorForestDataSpec input_spec_; - std::unique_ptr<TensorDataSet> data_set_; TensorForestParams param_proto_; }; |