aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Wei Ho <weiho@google.com>2016-07-18 16:13:19 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-07-18 17:18:28 -0700
commit9c834267160ca4944647cd9bc2bb22dbb298dbd8 (patch)
tree43e064e3e19969aed0a5b8e02331f72dfacfc1ff
parente6320fa54433a9c202c908418819dc6b98b0aeca (diff)
Allow initializing non-partitioned variable with non-Tensor values
Change: 127777937
-rw-r--r--tensorflow/python/kernel_tests/variable_scope_test.py16
-rw-r--r--tensorflow/python/ops/variable_scope.py2
2 files changed, 17 insertions, 1 deletions
diff --git a/tensorflow/python/kernel_tests/variable_scope_test.py b/tensorflow/python/kernel_tests/variable_scope_test.py
index aefd42f00d..632e196b78 100644
--- a/tensorflow/python/kernel_tests/variable_scope_test.py
+++ b/tensorflow/python/kernel_tests/variable_scope_test.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import numpy
import tensorflow as tf
from tensorflow.python.ops import control_flow_ops
@@ -68,6 +69,21 @@ class VariableScopeTest(tf.test.TestCase):
sess.run(tf.initialize_variables([w]))
self.assertAllClose(w.eval(), 0.3)
+ def testInitFromNonTensorValue(self):
+ with self.test_session() as sess:
+ v = tf.get_variable("v", initializer=4, dtype=tf.int32)
+ sess.run(tf.initialize_variables([v]))
+ self.assertAllClose(v.eval(), 4)
+
+ w = tf.get_variable("w",
+ initializer=numpy.array([1, 2, 3]),
+ dtype=tf.int32)
+ sess.run(tf.initialize_variables([w]))
+ self.assertAllClose(w.eval(), [1, 2, 3])
+
+ with self.assertRaises(TypeError):
+ tf.get_variable("x", initializer={})
+
def testVarScopeCachingDevice(self):
with self.test_session():
caching_device = "/job:moo"
diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py
index 5e16fe4d98..cc1111a355 100644
--- a/tensorflow/python/ops/variable_scope.py
+++ b/tensorflow/python/ops/variable_scope.py
@@ -432,7 +432,7 @@ class _VariableStore(object):
# Set to true if initializer is a constant.
initializing_from_value = False
- if initializer is not None and isinstance(initializer, ops.Tensor):
+ if initializer is not None and not callable(initializer):
initializing_from_value = True
if shape is not None and initializing_from_value:
raise ValueError("If initializer is a constant, do not specify shape.")