aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/contrib/tensor_forest/python/tensor_forest.py7
1 files changed, 4 insertions, 3 deletions
diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest.py b/tensorflow/contrib/tensor_forest/python/tensor_forest.py
index cd99b7ffd6..ab7b3c1761 100644
--- a/tensorflow/contrib/tensor_forest/python/tensor_forest.py
+++ b/tensorflow/contrib/tensor_forest/python/tensor_forest.py
@@ -34,7 +34,6 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
-from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.platform import tf_logging as logging
@@ -164,8 +163,10 @@ class TreeTrainingVariables(object):
name=self.get_tree_name('end_of_tree', tree_num),
dtype=dtypes.int32,
initializer=constant_op.constant([1]))
- self.start_epoch = tf_variables.Variable(
- [0] * (params.max_nodes), name='start_epoch')
+ self.start_epoch = variable_scope.get_variable(
+ name=self.get_tree_name('start_epoch', tree_num),
+ dtype=dtypes.int32, shape=[params.max_nodes],
+ initializer=init_ops.constant_initializer(0))
if training:
self.node_to_accumulator_map = variable_scope.get_variable(