diff options
author | 2016-11-17 15:12:47 -0800 | |
---|---|---|
committer | 2016-11-17 15:26:38 -0800 | |
commit | 7ed1fb97cd9c6c8606264123d6c0222a852a98c9 (patch) | |
tree | 2757f386b6e03c317ef38ab4ed5bbaa325eb2d6a | |
parent | 4daf1fcfbfdc50d304cc2c527736af47035cc167 (diff) |
Extract epoch via its limit_epochs variable, rather than as a passed
parameter.
Change: 139513507
-rw-r--r-- | tensorflow/contrib/tensor_forest/python/tensor_forest.py | 25 |
1 files changed, 18 insertions, 7 deletions
diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest.py b/tensorflow/contrib/tensor_forest/python/tensor_forest.py index 9d4e97cc7d..096e77b925 100644 --- a/tensorflow/contrib/tensor_forest/python/tensor_forest.py +++ b/tensorflow/contrib/tensor_forest/python/tensor_forest.py @@ -37,6 +37,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables as tf_variables from tensorflow.python.platform import tf_logging as logging @@ -137,6 +138,17 @@ class ForestHParams(object): return self +def get_epoch_variable(): + """Returns the epoch variable, or [0] if not defined.""" + # Grab epoch variable defined in + # //third_party/tensorflow/python/training/input.py::limit_epochs + for v in tf_variables.local_variables(): + if 'limit_epochs/epoch' in v.op.name: + return array_ops.reshape(v, [1]) + # TODO(thomaswc): Access epoch from the data feeder. + return [0] + + # A simple container to hold the training variables for a single tree. class TreeTrainingVariables(object): """Stores tf.Variables for training a single random tree. @@ -328,8 +340,11 @@ class RandomForestGraphs(object): return array_ops.concat( 1, [split_data[ind] for ind in self.params.bagged_features[tree_num]]) - def training_graph(self, input_data, input_labels, data_spec=None, - epoch=None, **tree_kwargs): + def training_graph(self, + input_data, + input_labels, + data_spec=None, + **tree_kwargs): """Constructs a TF graph for training a random forest. Args: @@ -338,7 +353,6 @@ class RandomForestGraphs(object): input_data. data_spec: A list of tf.dtype values specifying the original types of each column. - epoch: A tensor or placeholder for the epoch the training data comes from. **tree_kwargs: Keyword arguments passed to each tree's training_graph. Returns: @@ -376,7 +390,6 @@ class RandomForestGraphs(object): tree_graphs.append( self.trees[i].training_graph( tree_data, tree_labels, seed, data_spec=data_spec, - epoch=([0] if epoch is None else epoch), **tree_kwargs)) return control_flow_ops.group(*tree_graphs, name='train') @@ -591,7 +604,6 @@ class RandomTreeGraphs(object): input_labels, random_seed, data_spec, - epoch=None, input_weights=None): """Constructs a TF graph for training a random tree. @@ -604,14 +616,13 @@ class RandomTreeGraphs(object): means use the current time as the seed. data_spec: A list of tf.dtype values specifying the original types of each column. - epoch: A tensor or placeholder for the epoch the training data comes from. input_weights: A float tensor or placeholder holding per-input weights, or None if all inputs are to be weighted equally. Returns: The last op in the random tree training graph. """ - epoch = [0] if epoch is None else epoch + epoch = math_ops.to_int32(get_epoch_variable()) if input_weights is None: input_weights = [] |