about summary refs log tree commit diff homepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-11-17 15:12:47 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-11-17 15:26:38 -0800
commit7ed1fb97cd9c6c8606264123d6c0222a852a98c9 (patch)
tree2757f386b6e03c317ef38ab4ed5bbaa325eb2d6a
parent4daf1fcfbfdc50d304cc2c527736af47035cc167 (diff)
Extract epoch via its limit_epochs variable, rather than as a passed
parameter. Change: 139513507
-rw-r--r--tensorflow/contrib/tensor_forest/python/tensor_forest.py25
1 files changed, 18 insertions, 7 deletions
diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest.py b/tensorflow/contrib/tensor_forest/python/tensor_forest.py
index 9d4e97cc7d..096e77b925 100644
--- a/tensorflow/contrib/tensor_forest/python/tensor_forest.py
+++ b/tensorflow/contrib/tensor_forest/python/tensor_forest.py
@@ -37,6 +37,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.platform import tf_logging as logging
@@ -137,6 +138,17 @@ class ForestHParams(object):
return self
+def get_epoch_variable():
+ """Returns the epoch variable, or [0] if not defined."""
+ # Grab epoch variable defined in
+ # //third_party/tensorflow/python/training/input.py::limit_epochs
+ for v in tf_variables.local_variables():
+ if 'limit_epochs/epoch' in v.op.name:
+ return array_ops.reshape(v, [1])
+ # TODO(thomaswc): Access epoch from the data feeder.
+ return [0]
+
+
# A simple container to hold the training variables for a single tree.
class TreeTrainingVariables(object):
"""Stores tf.Variables for training a single random tree.
@@ -328,8 +340,11 @@ class RandomForestGraphs(object):
return array_ops.concat(
1, [split_data[ind] for ind in self.params.bagged_features[tree_num]])
- def training_graph(self, input_data, input_labels, data_spec=None,
- epoch=None, **tree_kwargs):
+ def training_graph(self,
+ input_data,
+ input_labels,
+ data_spec=None,
+ **tree_kwargs):
"""Constructs a TF graph for training a random forest.
Args:
@@ -338,7 +353,6 @@ class RandomForestGraphs(object):
input_data.
data_spec: A list of tf.dtype values specifying the original types of
each column.
- epoch: A tensor or placeholder for the epoch the training data comes from.
**tree_kwargs: Keyword arguments passed to each tree's training_graph.
Returns:
@@ -376,7 +390,6 @@ class RandomForestGraphs(object):
tree_graphs.append(
self.trees[i].training_graph(
tree_data, tree_labels, seed, data_spec=data_spec,
- epoch=([0] if epoch is None else epoch),
**tree_kwargs))
return control_flow_ops.group(*tree_graphs, name='train')
@@ -591,7 +604,6 @@ class RandomTreeGraphs(object):
input_labels,
random_seed,
data_spec,
- epoch=None,
input_weights=None):
"""Constructs a TF graph for training a random tree.
@@ -604,14 +616,13 @@ class RandomTreeGraphs(object):
means use the current time as the seed.
data_spec: A list of tf.dtype values specifying the original types of
each column.
- epoch: A tensor or placeholder for the epoch the training data comes from.
input_weights: A float tensor or placeholder holding per-input weights,
or None if all inputs are to be weighted equally.
Returns:
The last op in the random tree training graph.
"""
- epoch = [0] if epoch is None else epoch
+ epoch = math_ops.to_int32(get_epoch_variable())
if input_weights is None:
input_weights = []