aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-07-18 12:02:27 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-07-18 12:06:18 -0700
commite87e8199a88938da2084796c29e66a89e8708f2d (patch)
tree64fcb43fa75a8aaf2fbcc060015c0a01fa4848fa
parent5bdfcfbcbab3580f51531c9ac832302016775f63 (diff)
Destroy resources at end of session in TensorForestEstimator to avoid memory leak.
PiperOrigin-RevId: 162383623
-rw-r--r--tensorflow/contrib/tensor_forest/client/random_forest.py19
-rw-r--r--tensorflow/contrib/tensor_forest/python/tensor_forest.py3
-rw-r--r--tensorflow/contrib/tensor_forest/python/tensor_forest_v4.py5
3 files changed, 23 insertions, 4 deletions
diff --git a/tensorflow/contrib/tensor_forest/client/random_forest.py b/tensorflow/contrib/tensor_forest/client/random_forest.py
index d067b4b779..2eb37d3e53 100644
--- a/tensorflow/contrib/tensor_forest/client/random_forest.py
+++ b/tensorflow/contrib/tensor_forest/client/random_forest.py
@@ -33,6 +33,7 @@ from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
@@ -73,8 +74,8 @@ class TensorForestRunOpAtEndHook(session_run_hook.SessionRunHook):
self._ops = op_dict
def end(self, session):
- for name, op in self._ops.iteritems():
- logging.info('{0}: {1}'.format(name, session.run(op)))
+ for name in sorted(self._ops.keys()):
+ logging.info('{0}: {1}'.format(name, session.run(self._ops[name])))
class TensorForestLossHook(session_run_hook.SessionRunHook):
@@ -234,9 +235,19 @@ def get_model_fn(params,
logits=logits,
scope=head_scope)
+ # Ops are run in lexigraphical order of their keys. Run the resource
+ # clean-up op last.
+ all_handles = graph_builder.get_all_resource_handles()
+ ops_at_end = {
+ '9: clean up resources': control_flow_ops.group(
+ *[resource_variable_ops.destroy_resource_op(handle)
+ for handle in all_handles])}
+
if report_feature_importances:
- training_hooks.append(TensorForestRunOpAtEndHook(
- {'feature_importances': graph_builder.feature_importances()}))
+ ops_at_end['1: feature_importances'] = (
+ graph_builder.feature_importances())
+
+ training_hooks.append(TensorForestRunOpAtEndHook(ops_at_end))
if early_stopping_rounds:
training_hooks.append(
diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest.py b/tensorflow/contrib/tensor_forest/python/tensor_forest.py
index b756b0129d..844ac00b0a 100644
--- a/tensorflow/contrib/tensor_forest/python/tensor_forest.py
+++ b/tensorflow/contrib/tensor_forest/python/tensor_forest.py
@@ -328,6 +328,9 @@ class RandomForestGraphs(object):
return array_ops.concat(
[split_data[ind] for ind in self.params.bagged_features[tree_num]], 1)
+ def get_all_resource_handles(self):
+ return []
+
def training_graph(self,
input_data,
input_labels,
diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest_v4.py b/tensorflow/contrib/tensor_forest/python/tensor_forest_v4.py
index 8198c228dd..fefc19d3bd 100644
--- a/tensorflow/contrib/tensor_forest/python/tensor_forest_v4.py
+++ b/tensorflow/contrib/tensor_forest/python/tensor_forest_v4.py
@@ -315,6 +315,7 @@ class RandomTreeGraphsV4(tensor_forest.RandomTreeGraphs):
class RandomForestGraphsV4(tensor_forest.RandomForestGraphs):
+ """Constructs training/inference graphs for TensorForest v4."""
def __init__(self, params, tree_graphs=None, tree_variables_class=None,
**kwargs):
@@ -324,3 +325,7 @@ class RandomForestGraphsV4(tensor_forest.RandomForestGraphs):
params, tree_graphs=tree_graphs or RandomTreeGraphsV4,
tree_variables_class=(tree_variables_class or TreeTrainingVariablesV4),
**kwargs)
+
+ def get_all_resource_handles(self):
+ return ([self.variables[i].tree for i in range(len(self.trees))] +
+ [self.variables[i].stats for i in range(len(self.trees))])