diff options
author | 2016-07-18 16:13:19 -0800 | |
---|---|---|
committer | 2016-07-18 17:18:28 -0700 | |
commit | 9c834267160ca4944647cd9bc2bb22dbb298dbd8 (patch) | |
tree | 43e064e3e19969aed0a5b8e02331f72dfacfc1ff | |
parent | e6320fa54433a9c202c908418819dc6b98b0aeca (diff) |
Allow initializing non-partitioned variable with non-Tensor values
Change: 127777937
-rw-r--r-- | tensorflow/python/kernel_tests/variable_scope_test.py | 16 | ||||
-rw-r--r-- | tensorflow/python/ops/variable_scope.py | 2 |
2 files changed, 17 insertions, 1 deletions
diff --git a/tensorflow/python/kernel_tests/variable_scope_test.py b/tensorflow/python/kernel_tests/variable_scope_test.py index aefd42f00d..632e196b78 100644 --- a/tensorflow/python/kernel_tests/variable_scope_test.py +++ b/tensorflow/python/kernel_tests/variable_scope_test.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy import tensorflow as tf from tensorflow.python.ops import control_flow_ops @@ -68,6 +69,21 @@ class VariableScopeTest(tf.test.TestCase): sess.run(tf.initialize_variables([w])) self.assertAllClose(w.eval(), 0.3) + def testInitFromNonTensorValue(self): + with self.test_session() as sess: + v = tf.get_variable("v", initializer=4, dtype=tf.int32) + sess.run(tf.initialize_variables([v])) + self.assertAllClose(v.eval(), 4) + + w = tf.get_variable("w", + initializer=numpy.array([1, 2, 3]), + dtype=tf.int32) + sess.run(tf.initialize_variables([w])) + self.assertAllClose(w.eval(), [1, 2, 3]) + + with self.assertRaises(TypeError): + tf.get_variable("x", initializer={}) + def testVarScopeCachingDevice(self): with self.test_session(): caching_device = "/job:moo" diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py index 5e16fe4d98..cc1111a355 100644 --- a/tensorflow/python/ops/variable_scope.py +++ b/tensorflow/python/ops/variable_scope.py @@ -432,7 +432,7 @@ class _VariableStore(object): # Set to true if initializer is a constant. initializing_from_value = False - if initializer is not None and isinstance(initializer, ops.Tensor): + if initializer is not None and not callable(initializer): initializing_from_value = True if shape is not None and initializing_from_value: raise ValueError("If initializer is a constant, do not specify shape.") |