aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-10-21 05:26:29 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-10-21 06:44:54 -0700
commit3a5b605d8eb745e2b9d0b2751b9077c3dce9e2dc (patch)
tree3696cb8a9060701f954b350fbb656b765c1dd34c
parent00526b6f8682cc535d8fbd69cc00a396e6549825 (diff)
Remove potentially large python constant from TensorForest graph.
Change: 136825477
-rw-r--r--tensorflow/contrib/tensor_forest/python/tensor_forest.py7
1 files changed, 4 insertions, 3 deletions
diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest.py b/tensorflow/contrib/tensor_forest/python/tensor_forest.py
index cd99b7ffd6..ab7b3c1761 100644
--- a/tensorflow/contrib/tensor_forest/python/tensor_forest.py
+++ b/tensorflow/contrib/tensor_forest/python/tensor_forest.py
@@ -34,7 +34,6 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
-from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.platform import tf_logging as logging
@@ -164,8 +163,10 @@ class TreeTrainingVariables(object):
name=self.get_tree_name('end_of_tree', tree_num),
dtype=dtypes.int32,
initializer=constant_op.constant([1]))
- self.start_epoch = tf_variables.Variable(
- [0] * (params.max_nodes), name='start_epoch')
+ self.start_epoch = variable_scope.get_variable(
+ name=self.get_tree_name('start_epoch', tree_num),
+ dtype=dtypes.int32, shape=[params.max_nodes],
+ initializer=init_ops.constant_initializer(0))
if training:
self.node_to_accumulator_map = variable_scope.get_variable(