diff options
author | 2017-01-31 08:20:54 -0800 | |
---|---|---|
committer | 2017-01-31 08:28:24 -0800 | |
commit | 955feabb4cb1cddfea8383a1265ef91c5e0a0f2e (patch) | |
tree | 5e994aff527afa9ecfdf940c74bcdc17fc3485c6 /tensorflow/contrib/tensor_forest/client | |
parent | 67443722b26c3585d860d44e7069d997300a7187 (diff) |
Make random_forest_test TSAN-compliant by inserting some dependencies to avoid R/W hazards.
Change: 146121173
Diffstat (limited to 'tensorflow/contrib/tensor_forest/client')
-rw-r--r-- | tensorflow/contrib/tensor_forest/client/random_forest.py | 17 |
1 files changed, 10 insertions, 7 deletions
diff --git a/tensorflow/contrib/tensor_forest/client/random_forest.py b/tensorflow/contrib/tensor_forest/client/random_forest.py index 0fc855db8a..174394d67e 100644 --- a/tensorflow/contrib/tensor_forest/client/random_forest.py +++ b/tensorflow/contrib/tensor_forest/client/random_forest.py @@ -29,6 +29,7 @@ from tensorflow.contrib.tensor_forest.client import eval_metrics from tensorflow.contrib.tensor_forest.python import tensor_forest from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops @@ -123,23 +124,25 @@ def get_model_fn(params, graph_builder_class, device_assigner, if keys: inference[KEYS_NAME] = keys - training_loss = None - if (mode == model_fn_lib.ModeKeys.EVAL or - mode == model_fn_lib.ModeKeys.TRAIN): - training_loss = graph_builder.training_loss( - features, labels, name=LOSS_NAME) - # labels might be None if we're doing prediction (which brings up the # question of why we force everything to adhere to a single model_fn). + loss_deps = [] training_graph = None if labels is not None and mode == model_fn_lib.ModeKeys.TRAIN: - training_graph = control_flow_ops.group( graph_builder.training_graph( features, labels, input_weights=weights, num_trainers=num_trainers, trainer_id=trainer_id), state_ops.assign_add(contrib_framework.get_global_step(), 1)) + loss_deps.append(training_graph) + + training_loss = None + if (mode == model_fn_lib.ModeKeys.EVAL or + mode == model_fn_lib.ModeKeys.TRAIN): + with ops.control_dependencies(loss_deps): + training_loss = graph_builder.training_loss( + features, labels, name=LOSS_NAME) # Put weights back in if weights is not None: features[weights_name] = weights |