diff options
-rw-r--r-- | tensorflow/contrib/tensor_forest/python/tensor_forest.py | 7 |
1 files changed, 4 insertions, 3 deletions
diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest.py b/tensorflow/contrib/tensor_forest/python/tensor_forest.py index cd99b7ffd6..ab7b3c1761 100644 --- a/tensorflow/contrib/tensor_forest/python/tensor_forest.py +++ b/tensorflow/contrib/tensor_forest/python/tensor_forest.py @@ -34,7 +34,6 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope -from tensorflow.python.ops import variables as tf_variables from tensorflow.python.platform import tf_logging as logging @@ -164,8 +163,10 @@ class TreeTrainingVariables(object): name=self.get_tree_name('end_of_tree', tree_num), dtype=dtypes.int32, initializer=constant_op.constant([1])) - self.start_epoch = tf_variables.Variable( - [0] * (params.max_nodes), name='start_epoch') + self.start_epoch = variable_scope.get_variable( + name=self.get_tree_name('start_epoch', tree_num), + dtype=dtypes.int32, shape=[params.max_nodes], + initializer=init_ops.constant_initializer(0)) if training: self.node_to_accumulator_map = variable_scope.get_variable( |