path: root/tensorflow/contrib/boosted_trees/python
author     A. Unique TensorFlower <gardener@tensorflow.org>  2018-01-26 12:52:54 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>   2018-01-26 12:56:43 -0800
commit     f6a53e7abd54afdff4d1377535d61dbc1efd174c (patch)
tree       2f0ccee95954074f47e1a01412971e150b6e9bd1 /tensorflow/contrib/boosted_trees/python
parent     12cfeb2c5291b1d2af55bf0905374043be599c5a (diff)
Make the graph generation of TFBT deterministic.
PiperOrigin-RevId: 183431139
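Background on the change (added context, not part of the original commit message): in Python 2, which this code targets, plain dict and collections.defaultdict iterate in hash order, so code that creates graph ops while walking such a container can emit them in a different order on each run, changing the auto-generated node names. Switching to collections.OrderedDict populated via setdefault makes iteration follow insertion order. A minimal, self-contained sketch of that grouping pattern (the function name group_by_key is illustrative):

import collections

def group_by_key(pairs):
  # OrderedDict + setdefault groups values while preserving the order in
  # which keys are first seen, so later iteration is reproducible.
  grouped = collections.OrderedDict()
  for key, value in pairs:
    grouped.setdefault(key, []).append(value)
  return grouped

print(group_by_key([("b", 1), ("a", 2), ("b", 3)]))
# OrderedDict([('b', [1, 3]), ('a', [2])])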
Diffstat (limited to 'tensorflow/contrib/boosted_trees/python')
-rw-r--r--  tensorflow/contrib/boosted_trees/python/ops/batch_ops_utils.py           | 18
-rw-r--r--  tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py |  7
2 files changed, 13 insertions, 12 deletions
diff --git a/tensorflow/contrib/boosted_trees/python/ops/batch_ops_utils.py b/tensorflow/contrib/boosted_trees/python/ops/batch_ops_utils.py
index b281a4c6d1..7a5f329b7a 100644
--- a/tensorflow/contrib/boosted_trees/python/ops/batch_ops_utils.py
+++ b/tensorflow/contrib/boosted_trees/python/ops/batch_ops_utils.py
@@ -81,32 +81,32 @@ def _scheduled_stamp_resource_op_runner(batch, stamp):
   if not batch:
     return
   arg_keys = set(batch[0].args.keys())
-  grouped_args = collections.defaultdict(list)
+  grouped_args = collections.OrderedDict()
   resource_handles = []
   # Check that the set of arguments is the same across all the scheduled ops.
   for op in batch:
     if set(op.args.keys()) != arg_keys:
       raise ValueError("Mismatching arguments: %s, %s.", op.args, arg_keys)
     for key in arg_keys:
-      grouped_args[key].append(op.args[key])
+      grouped_args.setdefault(key, []).append(op.args[key])
     resource_handles.append(op.resource_handle)
   # Move all the inputs to the op device in one RPC.
-  grouped_args = {
-      k: _move_tensors(v, resource_handles[0].device)
-      for k, v in grouped_args.items()
-  }
+  grouped_args = collections.OrderedDict(
+      (k, _move_tensors(v, resource_handles[0].device))
+      for k, v in sorted(grouped_args.items()))
   with ops.device(resource_handles[0].device):
     return batch[0].op(resource_handles, stamp, **grouped_args)


 def run_handler_scheduled_ops(per_handler_ops, stamp, worker_device):
   """Given a dictionary of ops for each handler, runs them in batch."""
-  batched_ops = collections.defaultdict(list)
+  batched_ops = collections.OrderedDict()
   # Group the ops by their batching_key. Ops that share the same batching key
   # can be executed together.
-  for handler in sorted(per_handler_ops.keys()):
+  for handler in per_handler_ops.keys():
     for op in per_handler_ops[handler]:
-      batched_ops[(op.batching_key(), op.batch_runner_fn())].append(op)
+      key = (op.batching_key(), op.batch_runner_fn())
+      batched_ops.setdefault(key, []).append(op)
   op_results = {}
   for batch in batched_ops.values():
     # Run each of the batched ops using its runner.
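The hunk above also replaces a dict comprehension with an OrderedDict built from sorted items, so the _move_tensors ops are created in a fixed key order. An illustrative helper showing that pattern (transform_values is a hypothetical name, not from the contrib module):

import collections

def transform_values(mapping, fn):
  # Sorting the items pins the visit order; OrderedDict pins the iteration
  # order of the result.
  return collections.OrderedDict(
      (k, fn(v)) for k, v in sorted(mapping.items()))

moved = transform_values({"b": [2], "a": [1]}, lambda v: [x * 10 for x in v])
print(list(moved.items()))  # [('a', [10]), ('b', [20])]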
diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
index b95956dae2..f0b66dcbbe 100644
--- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
+++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+import collections
 import copy

 from tensorflow.contrib import learn
@@ -163,7 +164,7 @@ def extract_features(features, feature_columns):
   scope = "gbdt"
   with variable_scope.variable_scope(scope):
     feature_columns = list(feature_columns)
-    transformed_features = {}
+    transformed_features = collections.OrderedDict()
     for fc in feature_columns:
       # pylint: disable=protected-access
       if isinstance(fc, feature_column_lib._EmbeddingColumn):
@@ -681,13 +682,13 @@ class GradientBoostedDecisionTreeModel(object):
             control_flow_ops.no_op))

     # Update handler stats.
-    handler_reads = {}
+    handler_reads = collections.OrderedDict()
     for handler in handlers:
       handler_reads[handler] = handler.scheduled_reads()

     handler_results = batch_ops_utils.run_handler_scheduled_ops(
         handler_reads, ensemble_stamp, worker_device)
-    per_handler_updates = {}
+    per_handler_updates = collections.OrderedDict()
     # Two values per handler. First one is if the handler is active for the
     # current layer. The second one is if the handler is going to be active
     # for the next layer.
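The gbdt_batch.py changes apply the same idea to transformed_features, handler_reads, and per_handler_updates: ops created while iterating these containers now appear in a stable order, so the auto-generated node names are reproducible across runs. A hypothetical check of that property, assuming TF 1.x graph mode as used by this codebase (the handler names below are illustrative):

import collections
import tensorflow as tf

def build_read_ops():
  # Iterating an OrderedDict creates the ops in insertion order, so the
  # resulting node names come out identical on every call.
  handler_reads = collections.OrderedDict()
  handler_reads["quantile_handler"] = lambda: tf.constant(1.0, name="read_quantiles")
  handler_reads["stats_handler"] = lambda: tf.constant(2.0, name="read_stats")
  graph = tf.Graph()
  with graph.as_default():
    for read_fn in handler_reads.values():
      read_fn()
  return [op.name for op in graph.get_operations()]

assert build_read_ops() == build_read_ops()  # ['read_quantiles', 'read_stats'] each time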