diff options
author | 2017-07-18 12:02:27 -0700 | |
---|---|---|
committer | 2017-07-18 12:06:18 -0700 | |
commit | e87e8199a88938da2084796c29e66a89e8708f2d (patch) | |
tree | 64fcb43fa75a8aaf2fbcc060015c0a01fa4848fa | |
parent | 5bdfcfbcbab3580f51531c9ac832302016775f63 (diff) |
Destroy resources at end of session in TensorForestEstimator to avoid memory leak.
PiperOrigin-RevId: 162383623
3 files changed, 23 insertions, 4 deletions
diff --git a/tensorflow/contrib/tensor_forest/client/random_forest.py b/tensorflow/contrib/tensor_forest/client/random_forest.py index d067b4b779..2eb37d3e53 100644 --- a/tensorflow/contrib/tensor_forest/client/random_forest.py +++ b/tensorflow/contrib/tensor_forest/client/random_forest.py @@ -33,6 +33,7 @@ from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.platform import tf_logging as logging @@ -73,8 +74,8 @@ class TensorForestRunOpAtEndHook(session_run_hook.SessionRunHook): self._ops = op_dict def end(self, session): - for name, op in self._ops.iteritems(): - logging.info('{0}: {1}'.format(name, session.run(op))) + for name in sorted(self._ops.keys()): + logging.info('{0}: {1}'.format(name, session.run(self._ops[name]))) class TensorForestLossHook(session_run_hook.SessionRunHook): @@ -234,9 +235,19 @@ def get_model_fn(params, logits=logits, scope=head_scope) + # Ops are run in lexigraphical order of their keys. Run the resource + # clean-up op last. + all_handles = graph_builder.get_all_resource_handles() + ops_at_end = { + '9: clean up resources': control_flow_ops.group( + *[resource_variable_ops.destroy_resource_op(handle) + for handle in all_handles])} + if report_feature_importances: - training_hooks.append(TensorForestRunOpAtEndHook( - {'feature_importances': graph_builder.feature_importances()})) + ops_at_end['1: feature_importances'] = ( + graph_builder.feature_importances()) + + training_hooks.append(TensorForestRunOpAtEndHook(ops_at_end)) if early_stopping_rounds: training_hooks.append( diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest.py b/tensorflow/contrib/tensor_forest/python/tensor_forest.py index b756b0129d..844ac00b0a 100644 --- a/tensorflow/contrib/tensor_forest/python/tensor_forest.py +++ b/tensorflow/contrib/tensor_forest/python/tensor_forest.py @@ -328,6 +328,9 @@ class RandomForestGraphs(object): return array_ops.concat( [split_data[ind] for ind in self.params.bagged_features[tree_num]], 1) + def get_all_resource_handles(self): + return [] + def training_graph(self, input_data, input_labels, diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest_v4.py b/tensorflow/contrib/tensor_forest/python/tensor_forest_v4.py index 8198c228dd..fefc19d3bd 100644 --- a/tensorflow/contrib/tensor_forest/python/tensor_forest_v4.py +++ b/tensorflow/contrib/tensor_forest/python/tensor_forest_v4.py @@ -315,6 +315,7 @@ class RandomTreeGraphsV4(tensor_forest.RandomTreeGraphs): class RandomForestGraphsV4(tensor_forest.RandomForestGraphs): + """Constructs training/inference graphs for TensorForest v4.""" def __init__(self, params, tree_graphs=None, tree_variables_class=None, **kwargs): @@ -324,3 +325,7 @@ class RandomForestGraphsV4(tensor_forest.RandomForestGraphs): params, tree_graphs=tree_graphs or RandomTreeGraphsV4, tree_variables_class=(tree_variables_class or TreeTrainingVariablesV4), **kwargs) + + def get_all_resource_handles(self): + return ([self.variables[i].tree for i in range(len(self.trees))] + + [self.variables[i].stats for i in range(len(self.trees))]) |