aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tensor_forest/client
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-01-31 08:20:54 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-01-31 08:28:24 -0800
commit955feabb4cb1cddfea8383a1265ef91c5e0a0f2e (patch)
tree5e994aff527afa9ecfdf940c74bcdc17fc3485c6 /tensorflow/contrib/tensor_forest/client
parent67443722b26c3585d860d44e7069d997300a7187 (diff)
Make random_forest_test TSAN-compliant by inserting some dependencies to avoid R/W hazards.
Change: 146121173
Diffstat (limited to 'tensorflow/contrib/tensor_forest/client')
-rw-r--r--tensorflow/contrib/tensor_forest/client/random_forest.py17
1 files changed, 10 insertions, 7 deletions
diff --git a/tensorflow/contrib/tensor_forest/client/random_forest.py b/tensorflow/contrib/tensor_forest/client/random_forest.py
index 0fc855db8a..174394d67e 100644
--- a/tensorflow/contrib/tensor_forest/client/random_forest.py
+++ b/tensorflow/contrib/tensor_forest/client/random_forest.py
@@ -29,6 +29,7 @@ from tensorflow.contrib.tensor_forest.client import eval_metrics
from tensorflow.contrib.tensor_forest.python import tensor_forest
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
@@ -123,23 +124,25 @@ def get_model_fn(params, graph_builder_class, device_assigner,
if keys:
inference[KEYS_NAME] = keys
- training_loss = None
- if (mode == model_fn_lib.ModeKeys.EVAL or
- mode == model_fn_lib.ModeKeys.TRAIN):
- training_loss = graph_builder.training_loss(
- features, labels, name=LOSS_NAME)
-
# labels might be None if we're doing prediction (which brings up the
# question of why we force everything to adhere to a single model_fn).
+ loss_deps = []
training_graph = None
if labels is not None and mode == model_fn_lib.ModeKeys.TRAIN:
-
training_graph = control_flow_ops.group(
graph_builder.training_graph(
features, labels, input_weights=weights,
num_trainers=num_trainers,
trainer_id=trainer_id),
state_ops.assign_add(contrib_framework.get_global_step(), 1))
+ loss_deps.append(training_graph)
+
+ training_loss = None
+ if (mode == model_fn_lib.ModeKeys.EVAL or
+ mode == model_fn_lib.ModeKeys.TRAIN):
+ with ops.control_dependencies(loss_deps):
+ training_loss = graph_builder.training_loss(
+ features, labels, name=LOSS_NAME)
# Put weights back in
if weights is not None:
features[weights_name] = weights