author    Lukasz Kaiser <lukaszkaiser@google.com>  2016-06-15 19:18:52 -0800
committer Martin Wicke <wicke@google.com>          2016-06-19 14:07:25 -0700
commit    1bea99abc531bf2ce47c1f6767f4796be7168f02 (patch)
tree      5f9ab00562059a43766304f90f1f47ab1971ab0f
parent    e13cd0530218e3f8c963b94967c924a8abe3650c (diff)
Use only op_scope, not variable_op_scope, in functional ops since they do not
create variables. Also add missing output_size in EmbeddingWrapper (#2852). Change: 125022470
-rw-r--r--  tensorflow/python/kernel_tests/functional_ops_test.py | 93
-rw-r--r--  tensorflow/python/kernel_tests/rnn_cell_test.py       |  6
-rw-r--r--  tensorflow/python/ops/functional_ops.py               | 45
-rw-r--r--  tensorflow/python/ops/rnn.py                          |  2
-rw-r--r--  tensorflow/python/ops/rnn_cell.py                     |  4
5 files changed, 138 insertions, 12 deletions
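For context before the diff itself: because the functional ops now open only an op_scope, a get_variable call inside fn lands directly under the caller's variable scope (root/body/two below) rather than under an extra foldl/ scope, and the same variable can be reused across calls. A minimal sketch against the TF 0.9-era API this commit targets; the function and scope names are illustrative, and the expected numbers follow the tests in the diff below.

import tensorflow as tf

def scoped_double(a, x):
  # (a, x) -> 2*(a + x), with the "2" created via get_variable on first use.
  with tf.variable_scope("body"):
    two = tf.get_variable("two", [], dtype=tf.int32,
                          initializer=tf.constant_initializer(2))
    return tf.mul(tf.add(a, x), two)

with tf.Session() as sess:
  with tf.variable_scope("root") as varscope:
    elems = tf.constant([1, 2, 3, 4, 5, 6], name="data")

    r = tf.foldl(scoped_double, elems)
    # The single variable sits under the caller's scope, not under "foldl/".
    print(tf.trainable_variables()[0].name)   # root/body/two:0
    sess.run(tf.initialize_all_variables())
    print(r.eval())                           # 208

    # Reuse the same variable on a second call.
    varscope.reuse_variables()
    print(tf.foldl(scoped_double, elems, initializer=10).eval())  # 880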
diff --git a/tensorflow/python/kernel_tests/functional_ops_test.py b/tensorflow/python/kernel_tests/functional_ops_test.py
index ec20c9e18f..d0fc17c059 100644
--- a/tensorflow/python/kernel_tests/functional_ops_test.py
+++ b/tensorflow/python/kernel_tests/functional_ops_test.py
@@ -23,6 +23,15 @@ import numpy as np
import tensorflow as tf
+def simple_scoped_fn(a, x):
+ """Simple function: (a, x) -> 2(x+a), but with "2" as a variable in scope."""
+ with tf.variable_scope("body"):
+ # Dummy variable, just to check that scoping works as intended.
+ two = tf.get_variable("two", [], dtype=tf.int32,
+ initializer=tf.constant_initializer(2))
+ return tf.mul(tf.add(a, x), two)
+
+
class FunctionalOpsTest(tf.test.TestCase):
def testFoldl_Simple(self):
@@ -36,6 +45,24 @@ class FunctionalOpsTest(tf.test.TestCase):
lambda a, x: tf.mul(tf.add(a, x), 2), elems, initializer=10)
self.assertAllEqual(880, r.eval())
+ def testFoldl_Scoped(self):
+ with self.test_session() as sess:
+ with tf.variable_scope("root") as varscope:
+ elems = tf.constant([1, 2, 3, 4, 5, 6], name="data")
+
+ r = tf.foldl(simple_scoped_fn, elems)
+ # Check that we have the one variable we asked for here.
+ self.assertEqual(len(tf.trainable_variables()), 1)
+ self.assertEqual(tf.trainable_variables()[0].name, "root/body/two:0")
+ sess.run([tf.initialize_all_variables()])
+ self.assertAllEqual(208, r.eval())
+
+ # Now let's reuse our single variable.
+ varscope.reuse_variables()
+ r = tf.foldl(simple_scoped_fn, elems, initializer=10)
+ self.assertEqual(len(tf.trainable_variables()), 1)
+ self.assertAllEqual(880, r.eval())
+
def testFoldr_Simple(self):
with self.test_session():
elems = tf.constant([1, 2, 3, 4, 5, 6], name="data")
@@ -47,6 +74,24 @@ class FunctionalOpsTest(tf.test.TestCase):
lambda a, x: tf.mul(tf.add(a, x), 2), elems, initializer=10)
self.assertAllEqual(1282, r.eval())
+ def testFoldr_Scoped(self):
+ with self.test_session() as sess:
+ with tf.variable_scope("root") as varscope:
+ elems = tf.constant([1, 2, 3, 4, 5, 6], name="data")
+
+ r = tf.foldr(simple_scoped_fn, elems)
+ # Check that we have the one variable we asked for here.
+ self.assertEqual(len(tf.trainable_variables()), 1)
+ self.assertEqual(tf.trainable_variables()[0].name, "root/body/two:0")
+ sess.run([tf.initialize_all_variables()])
+ self.assertAllEqual(450, r.eval())
+
+ # Now let's reuse our single variable.
+ varscope.reuse_variables()
+ r = tf.foldr(simple_scoped_fn, elems, initializer=10)
+ self.assertEqual(len(tf.trainable_variables()), 1)
+ self.assertAllEqual(1282, r.eval())
+
def testFold_Grad(self):
with self.test_session():
elems = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data")
@@ -69,6 +114,34 @@ class FunctionalOpsTest(tf.test.TestCase):
r = tf.map_fn(lambda x: tf.mul(tf.add(x, 3), 2), elems)
self.assertAllEqual(np.array([(x + 3) * 2 for x in nums]), r.eval())
+ def testMap_Scoped(self):
+ with self.test_session() as sess:
+
+ def double_scoped(x):
+ """2x with a dummy 2 that is scoped."""
+ with tf.variable_scope("body"):
+ # Dummy variable, just to check that scoping works as intended.
+ two = tf.get_variable("two", [], dtype=tf.int32,
+ initializer=tf.constant_initializer(2))
+ return tf.mul(x, two)
+
+ with tf.variable_scope("root") as varscope:
+ elems = tf.constant([1, 2, 3, 4, 5, 6], name="data")
+ doubles = np.array([2*x for x in [1, 2, 3, 4, 5, 6]])
+
+ r = tf.map_fn(double_scoped, elems)
+ # Check that we have the one variable we asked for here.
+ self.assertEqual(len(tf.trainable_variables()), 1)
+ self.assertEqual(tf.trainable_variables()[0].name, "root/body/two:0")
+ sess.run([tf.initialize_all_variables()])
+ self.assertAllEqual(doubles, r.eval())
+
+ # Now let's reuse our single variable.
+ varscope.reuse_variables()
+ r = tf.map_fn(double_scoped, elems)
+ self.assertEqual(len(tf.trainable_variables()), 1)
+ self.assertAllEqual(doubles, r.eval())
+
def testMap_SimpleNotTensor(self):
with self.test_session():
nums = [1, 2, 3, 4, 5, 6]
@@ -87,6 +160,26 @@ class FunctionalOpsTest(tf.test.TestCase):
lambda a, x: tf.mul(a, x), elems, initializer=v)
self.assertAllEqual([2., 4., 12., 48., 240., 1440.], r.eval())
+ def testScan_Scoped(self):
+ with self.test_session() as sess:
+ with tf.variable_scope("root") as varscope:
+ elems = tf.constant([1, 2, 3, 4, 5, 6], name="data")
+
+ r = tf.scan(simple_scoped_fn, elems)
+ # Check that we have the one variable we asked for here.
+ self.assertEqual(len(tf.trainable_variables()), 1)
+ self.assertEqual(tf.trainable_variables()[0].name, "root/body/two:0")
+ sess.run([tf.initialize_all_variables()])
+ results = np.array([1, 6, 18, 44, 98, 208])
+ self.assertAllEqual(results, r.eval())
+
+ # Now let's reuse our single variable.
+ varscope.reuse_variables()
+ r = tf.scan(simple_scoped_fn, elems, initializer=2)
+ self.assertEqual(len(tf.trainable_variables()), 1)
+ results = np.array([6, 16, 38, 84, 178, 368])
+ self.assertAllEqual(results, r.eval())
+
def testScan_Control(self):
with self.test_session() as sess:
s = tf.placeholder(tf.float32, shape=[None])
diff --git a/tensorflow/python/kernel_tests/rnn_cell_test.py b/tensorflow/python/kernel_tests/rnn_cell_test.py
index 9bfe8cee91..4543aaade0 100644
--- a/tensorflow/python/kernel_tests/rnn_cell_test.py
+++ b/tensorflow/python/kernel_tests/rnn_cell_test.py
@@ -251,9 +251,11 @@ class RNNCellTest(tf.test.TestCase):
with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
x = tf.zeros([1, 1], dtype=tf.int32)
m = tf.zeros([1, 2])
- g, new_m = tf.nn.rnn_cell.EmbeddingWrapper(
+ embedding_cell = tf.nn.rnn_cell.EmbeddingWrapper(
tf.nn.rnn_cell.GRUCell(2),
- embedding_classes=3, embedding_size=2)(x, m)
+ embedding_classes=3, embedding_size=2)
+ self.assertEqual(embedding_cell.output_size, 2)
+ g, new_m = embedding_cell(x, m)
sess.run([tf.initialize_all_variables()])
res = sess.run([g, new_m], {x.name: np.array([[1]]),
m.name: np.array([[0.1, 0.1]])})
diff --git a/tensorflow/python/ops/functional_ops.py b/tensorflow/python/ops/functional_ops.py
index 4fd5caa902..6986c9cf26 100644
--- a/tensorflow/python/ops/functional_ops.py
+++ b/tensorflow/python/ops/functional_ops.py
@@ -87,12 +87,15 @@ def foldl(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
if not callable(fn):
raise TypeError("fn must be callable.")
- # TODO(ebrevdo): Change to using colocate_with here and in other methods.
- with vs.variable_op_scope([elems], name, "foldl") as varscope:
- # Any get_variable calls fn will cache the first call locally
+ with ops.op_scope([elems], name, "foldl"):
+ # Any get_variable calls in fn will cache the first call locally
# and not issue repeated network I/O requests for each iteration.
+ varscope = vs.get_variable_scope()
+ varscope_caching_device_was_none = False
if varscope.caching_device is None:
+ # TODO(ebrevdo): Change to using colocate_with here and in other methods.
varscope.set_caching_device(lambda op: op.device)
+ varscope_caching_device_was_none = True
# Convert elems to tensor array.
elems = ops.convert_to_tensor(elems, name="elems")
@@ -117,6 +120,9 @@ def foldl(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
parallel_iterations=parallel_iterations,
back_prop=back_prop,
swap_memory=swap_memory)
+
+ if varscope_caching_device_was_none:
+ varscope.set_caching_device(None)
return r_a
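For readers skimming the diff: the bookkeeping added above (and repeated verbatim in foldr, map_fn, and scan below) amounts to reusing the caller's variable scope and only temporarily installing a caching device on it. The following is a condensed, hypothetical helper showing that pattern in isolation, not the actual library code; the helper name and the try/finally framing of the restore step are my own.

from tensorflow.python.framework import ops
from tensorflow.python.ops import variable_scope as vs

def _run_in_callers_scope(values, name, default_name, body_fn):
  # Hypothetical helper distilling the foldl/foldr/map_fn/scan pattern:
  # name the ops, but do NOT open a new variable scope; instead, borrow the
  # caller's scope and give it a caching device only if it has none.
  with ops.op_scope(values, name, default_name):
    varscope = vs.get_variable_scope()
    caching_device_was_none = False
    if varscope.caching_device is None:
      # Cache each variable read on the device of the consuming op, so fn
      # does not re-fetch the variable on every loop iteration.
      varscope.set_caching_device(lambda op: op.device)
      caching_device_was_none = True
    try:
      return body_fn()
    finally:
      # Put the caller's scope back the way we found it.
      if caching_device_was_none:
        varscope.set_caching_device(None)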
@@ -161,11 +167,15 @@ def foldr(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
if not callable(fn):
raise TypeError("fn must be callable.")
- with vs.variable_op_scope([elems], name, "foldr") as varscope:
- # Any get_variable calls fn will cache the first call locally
+ with ops.op_scope([elems], name, "foldr"):
+ # Any get_variable calls in fn will cache the first call locally
# and not issue repeated network I/O requests for each iteration.
+ varscope = vs.get_variable_scope()
+ varscope_caching_device_was_none = False
if varscope.caching_device is None:
+ # TODO(ebrevdo): Change to using colocate_with here and in other methods.
varscope.set_caching_device(lambda op: op.device)
+ varscope_caching_device_was_none = True
# Convert elems to tensor array.
elems = ops.convert_to_tensor(elems, name="elems")
@@ -190,6 +200,9 @@ def foldr(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
parallel_iterations=parallel_iterations,
back_prop=back_prop,
swap_memory=swap_memory)
+
+ if varscope_caching_device_was_none:
+ varscope.set_caching_device(None)
return r_a
@@ -232,11 +245,15 @@ def map_fn(fn, elems, dtype=None, parallel_iterations=10, back_prop=True,
if not callable(fn):
raise TypeError("fn must be callable.")
- with vs.variable_op_scope([elems], name, "map") as varscope:
- # Any get_variable calls fn will cache the first call locally
+ with ops.op_scope([elems], name, "map"):
+ # Any get_variable calls in fn will cache the first call locally
# and not issue repeated network I/O requests for each iteration.
+ varscope = vs.get_variable_scope()
+ varscope_caching_device_was_none = False
if varscope.caching_device is None:
+ # TODO(ebrevdo): Change to using colocate_with here and in other methods.
varscope.set_caching_device(lambda op: op.device)
+ varscope_caching_device_was_none = True
elems = ops.convert_to_tensor(elems, name="elems")
dtype = dtype if dtype else elems.dtype
@@ -263,6 +280,9 @@ def map_fn(fn, elems, dtype=None, parallel_iterations=10, back_prop=True,
result = r_a.pack()
result.set_shape(elems.get_shape().with_rank_at_least(1)[0:1].concatenate(
result.get_shape()[1:]))
+
+ if varscope_caching_device_was_none:
+ varscope.set_caching_device(None)
return result
@@ -307,11 +327,15 @@ def scan(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
if not callable(fn):
raise TypeError("fn must be callable.")
- with vs.variable_op_scope([elems], name, "scan") as varscope:
- # Any get_variable calls fn will cache the first call locally
+ with ops.op_scope([elems], name, "scan"):
+ # Any get_variable calls in fn will cache the first call locally
# and not issue repeated network I/O requests for each iteration.
+ varscope = vs.get_variable_scope()
+ varscope_caching_device_was_none = False
if varscope.caching_device is None:
+ # TODO(ebrevdo): Change to using colocate_with here and in other methods.
varscope.set_caching_device(lambda op: op.device)
+ varscope_caching_device_was_none = True
# Convert elems to tensor array.
elems = ops.convert_to_tensor(elems, name="elems")
@@ -346,6 +370,9 @@ def scan(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
result = r_a.pack()
result.set_shape(elems.get_shape().with_rank_at_least(1)[0:1].concatenate(
result.get_shape()[1:]))
+
+ if varscope_caching_device_was_none:
+ varscope.set_caching_device(None)
return result
diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py
index 1c967198de..c65a4bc23d 100644
--- a/tensorflow/python/ops/rnn.py
+++ b/tensorflow/python/ops/rnn.py
@@ -140,7 +140,7 @@ def rnn(cell, inputs, initial_state=None, dtype=None,
max_sequence_length = math_ops.reduce_max(sequence_length)
for time, input_ in enumerate(inputs):
- if time > 0: vs.get_variable_scope().reuse_variables()
+ if time > 0: varscope.reuse_variables()
# pylint: disable=cell-var-from-loop
call_cell = lambda: cell(input_, state)
# pylint: enable=cell-var-from-loop
diff --git a/tensorflow/python/ops/rnn_cell.py b/tensorflow/python/ops/rnn_cell.py
index 11cb3f788e..2e08166fec 100644
--- a/tensorflow/python/ops/rnn_cell.py
+++ b/tensorflow/python/ops/rnn_cell.py
@@ -745,6 +745,10 @@ class EmbeddingWrapper(RNNCell):
def state_size(self):
return self._cell.state_size
+ @property
+ def output_size(self):
+ return self._cell.output_size
+
def __call__(self, inputs, state, scope=None):
"""Run the cell on embedded inputs."""
with vs.variable_scope(scope or type(self).__name__): # "EmbeddingWrapper"
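Finally, the EmbeddingWrapper change lets callers query the wrapped cell's output width directly, mirroring the state_size delegation that already existed. A small usage sketch in the style of the test above, assuming the TF 0.9-era tf.nn.rnn_cell API; the shapes and feed values are illustrative only.

import numpy as np
import tensorflow as tf

with tf.Session() as sess:
  with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
    cell = tf.nn.rnn_cell.EmbeddingWrapper(
        tf.nn.rnn_cell.GRUCell(2), embedding_classes=3, embedding_size=2)
    # output_size now delegates to the wrapped GRUCell (2 units here), so
    # downstream code no longer has to reach into the underlying cell.
    assert cell.output_size == 2

    x = tf.zeros([1, 1], dtype=tf.int32)  # one token id per batch row
    m = tf.zeros([1, 2])                  # initial GRU state
    g, new_m = cell(x, m)
    sess.run(tf.initialize_all_variables())
    out, state = sess.run([g, new_m], {x.name: np.array([[1]]),
                                       m.name: np.array([[0.1, 0.1]])})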