author    Alexandre Passos <apassos@google.com>            2017-10-24 17:00:23 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2017-10-24 17:03:38 -0700
commit    56ceca431454635e8ea456cb35f9aeb7f62a8948 (patch)
tree      4940ef8066e5881375e9d0f559469975284cfdec /tensorflow/python
parent    8e7390ff4e0d9d173df5e193bf90af934e42f193 (diff)
Disables storing variables in the default variable store for eager.
Also disables all functional layers until a non-default store is implemented.

PiperOrigin-RevId: 173333446
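As an illustration of the behavior this change introduces: the functional layer entry points now fail fast under eager execution, while the class-based layers, which own their variables, stay usable. A minimal sketch against the TF 1.x-era API follows; it is not part of this commit, and the exact call for enabling eager execution (here via tf.contrib.eager) varied across releases and is an assumption.

    # Illustrative sketch only, not code from this commit. Assumes eager
    # execution is enabled through tf.contrib.eager, as in the TF 1.4/1.5 era.
    import tensorflow as tf
    import tensorflow.contrib.eager as tfe

    tfe.enable_eager_execution()

    x = tf.random_uniform((5, 3))

    # Functional layers now raise instead of silently registering variables
    # in the process-wide default variable store.
    try:
      tf.layers.dense(x, 2)
    except ValueError as e:
      print(e)  # 'Functional layers are currently not compatible with eager execution...'

    # The class-based layers keep working: the Dense object holds its own
    # kernel and bias, so no global store is involved.
    dense_layer = tf.layers.Dense(2, activation=tf.nn.relu)
    y = dense_layer(x)
    print(y.shape)  # (5, 2)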
Diffstat (limited to 'tensorflow/python')
-rw-r--r--  tensorflow/python/kernel_tests/variable_scope_test.py    1
-rw-r--r--  tensorflow/python/layers/convolutional.py                42
-rw-r--r--  tensorflow/python/layers/core.py                         14
-rw-r--r--  tensorflow/python/layers/core_test.py                    87
-rw-r--r--  tensorflow/python/layers/maxout.py                        5
-rw-r--r--  tensorflow/python/layers/normalization.py                 8
-rw-r--r--  tensorflow/python/layers/pooling.py                      43
-rw-r--r--  tensorflow/python/ops/variable_scope.py                   8
8 files changed, 153 insertions, 55 deletions
diff --git a/tensorflow/python/kernel_tests/variable_scope_test.py b/tensorflow/python/kernel_tests/variable_scope_test.py
index 29f583d5ba..efeb25d095 100644
--- a/tensorflow/python/kernel_tests/variable_scope_test.py
+++ b/tensorflow/python/kernel_tests/variable_scope_test.py
@@ -67,7 +67,6 @@ class VariableScopeTest(test.TestCase):
with self.assertRaises(ValueError):
vs.get_variable("u", [1], reuse=True) # That fails.
- @test_util.run_in_graph_and_eager_modes()
def testNamelessStore(self):
vs = variable_scope._get_default_variable_store()
vs.get_variable("v1", [2])
diff --git a/tensorflow/python/layers/convolutional.py b/tensorflow/python/layers/convolutional.py
index c983d3803b..6b371c618f 100644
--- a/tensorflow/python/layers/convolutional.py
+++ b/tensorflow/python/layers/convolutional.py
@@ -383,7 +383,14 @@ def conv1d(inputs,
Returns:
Output tensor.
+
+ Raises:
+ ValueError: if eager execution is enabled.
"""
+ if context.in_eager_mode():
+ raise ValueError(
+ 'Functional layers are currently not compatible with eager execution. '
+ 'Use tf.layers.Conv1D instead.')
layer = Conv1D(
filters=filters,
kernel_size=kernel_size,
@@ -583,7 +590,14 @@ def conv2d(inputs,
Returns:
Output tensor.
+
+ Raises:
+ ValueError: if eager execution is enabled.
"""
+ if context.in_eager_mode():
+ raise ValueError(
+ 'Functional layers are currently not compatible with eager execution. '
+ 'Use tf.layers.Conv2D instead.')
layer = Conv2D(
filters=filters,
kernel_size=kernel_size,
@@ -785,7 +799,14 @@ def conv3d(inputs,
Returns:
Output tensor.
+
+ Raises:
+ ValueError: if eager execution is enabled.
"""
+ if context.in_eager_mode():
+ raise ValueError(
+ 'Functional layers are currently not compatible with eager execution. '
+ 'Use tf.layers.Conv3D instead.')
layer = Conv3D(
filters=filters,
kernel_size=kernel_size,
@@ -1104,7 +1125,14 @@ def separable_conv2d(inputs,
Returns:
Output tensor.
+
+ Raises:
+ ValueError: if eager execution is enabled.
"""
+ if context.in_eager_mode():
+ raise ValueError(
+ 'Functional layers are currently not compatible with eager execution. '
+ 'Use tf.layers.SeparableConv2D instead.')
layer = SeparableConv2D(
filters=filters,
kernel_size=kernel_size,
@@ -1399,7 +1427,14 @@ def conv2d_transpose(inputs,
Returns:
Output tensor.
+
+ Raises:
+ ValueError: if eager execution is enabled.
"""
+ if context.in_eager_mode():
+ raise ValueError(
+ 'Functional layers are currently not compatible with eager execution. '
+ 'Use tf.layers.Conv2DTranspose instead.')
layer = Conv2DTranspose(
filters=filters,
kernel_size=kernel_size,
@@ -1710,7 +1745,14 @@ def conv3d_transpose(inputs,
Returns:
Output tensor.
+
+ Raises:
+ ValueError: if eager execution is enabled.
"""
+ if context.in_eager_mode():
+ raise ValueError(
+ 'Functional layers are currently not compatible with eager execution. '
+ 'Use tf.layers.Conv3DTranspose instead.')
layer = Conv3DTranspose(
filters=filters,
kernel_size=kernel_size,
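The error messages added above all point at the class-based replacements. A hedged sketch (not from this commit) of swapping the functional conv2d call for a Conv2D layer object, again assuming eager execution has been enabled as in the earlier example:

    import tensorflow as tf

    images = tf.random_uniform((8, 28, 28, 3))

    # Build the layer once; the Conv2D object owns its kernel and bias, so it
    # does not depend on the default variable store this commit disables.
    conv = tf.layers.Conv2D(filters=32, kernel_size=3, padding='same',
                            activation=tf.nn.relu)
    features = conv(images)        # shape (8, 28, 28, 32)

    # Calling the same object again reuses the same variables, which replaces
    # the name-based reuse the functional form relied on in graph mode.
    features_again = conv(images)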
diff --git a/tensorflow/python/layers/core.py b/tensorflow/python/layers/core.py
index ef9ff5790c..457bee5cff 100644
--- a/tensorflow/python/layers/core.py
+++ b/tensorflow/python/layers/core.py
@@ -231,7 +231,14 @@ def dense(
Returns:
Output tensor.
+
+ Raises:
+ ValueError: if eager execution is enabled.
"""
+ if context.in_eager_mode():
+ raise ValueError(
+ 'Functional layers are currently not compatible with eager execution. '
+ 'Use tf.layers.Dense instead.')
layer = Dense(units,
activation=activation,
use_bias=use_bias,
@@ -333,7 +340,14 @@ def dropout(inputs,
Returns:
Output tensor.
+
+ Raises:
+ ValueError: if eager execution is enabled.
"""
+ if context.in_eager_mode():
+ raise ValueError(
+ 'Functional layers are currently not compatible with eager execution. '
+ 'Use tf.layers.Dropout instead.')
layer = Dropout(rate, noise_shape=noise_shape, seed=seed, name=name)
return layer.apply(inputs, training=training)
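The dropout wrapper above forwards the training flag through layer.apply; the eager-compatible equivalent is to call a Dropout instance directly. A small sketch, assuming eager execution is enabled as in the earlier examples:

    import tensorflow as tf

    x = tf.ones((5, 5))
    drop = tf.layers.Dropout(rate=0.5, seed=1)

    y_train = drop(x, training=True)    # roughly half the entries zeroed, survivors scaled by 1/(1 - 0.5)
    y_infer = drop(x, training=False)   # identity: dropout is a no-op at inference time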
diff --git a/tensorflow/python/layers/core_test.py b/tensorflow/python/layers/core_test.py
index d917dcb69c..5184b372ff 100644
--- a/tensorflow/python/layers/core_test.py
+++ b/tensorflow/python/layers/core_test.py
@@ -203,21 +203,15 @@ class DenseTest(test.TestCase):
self.assertEqual(len(loss_keys), 1)
self.assertListEqual(dense.losses, loss_keys)
- @test_util.run_in_graph_and_eager_modes()
def testFunctionalDense(self):
- inputs = random_ops.random_uniform((5, 3), seed=1)
- outputs = core_layers.dense(
- inputs, 2, activation=nn_ops.relu, name='my_dense')
- if context.in_graph_mode():
+ with self.test_session():
+ inputs = random_ops.random_uniform((5, 3), seed=1)
+ outputs = core_layers.dense(
+ inputs, 2, activation=nn_ops.relu, name='my_dense')
self.assertEqual(
len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 2)
self.assertEqual(outputs.op.name, 'my_dense/Relu')
- else:
- self.assertEqual(
- len(_get_variable_dict_from_varstore().values()), 2)
- self.assertEqual(outputs.get_shape().as_list(), [5, 2])
- @test_util.run_in_graph_and_eager_modes()
def testFunctionalDenseTwice(self):
inputs = random_ops.random_uniform((5, 3), seed=1)
core_layers.dense(inputs, 2)
@@ -249,25 +243,21 @@ class DenseTest(test.TestCase):
vars2 = variables.trainable_variables()
self.assertEqual(vars1, vars2)
- @test_util.run_in_graph_and_eager_modes()
def testFunctionalDenseInitializerFromScope(self):
with variable_scope.variable_scope(
- 'scope', initializer=init_ops.ones_initializer()):
+ 'scope', initializer=init_ops.ones_initializer()), self.test_session():
inputs = random_ops.random_uniform((5, 3), seed=1)
core_layers.dense(inputs, 2)
- self.evaluate(variables.global_variables_initializer())
+ variables.global_variables_initializer().run()
weights = _get_variable_dict_from_varstore()
self.assertEqual(len(weights), 2)
# Check that the matrix weights got initialized to ones (from scope).
- self.assertAllClose(
- self.evaluate(weights['scope/dense/kernel'].read_value()),
- np.ones((3, 2)))
+ self.assertAllClose(weights['scope/dense/kernel'].read_value().eval(),
+ np.ones((3, 2)))
# Check that the bias still got initialized to zeros.
- self.assertAllClose(
- self.evaluate(weights['scope/dense/bias'].read_value()),
- np.zeros((2)))
+ self.assertAllClose(weights['scope/dense/bias'].read_value().eval(),
+ np.zeros((2)))
- @test_util.run_in_graph_and_eager_modes()
def testFunctionalDenseWithCustomGetter(self):
called = [0]
@@ -280,26 +270,26 @@ class DenseTest(test.TestCase):
core_layers.dense(inputs, 2)
self.assertEqual(called[0], 2)
- @test_util.run_in_graph_and_eager_modes()
def testFunctionalDenseInScope(self):
- with variable_scope.variable_scope('test'):
- inputs = random_ops.random_uniform((5, 3), seed=1)
- core_layers.dense(inputs, 2, name='my_dense')
- var_dict = _get_variable_dict_from_varstore()
- var_key = 'test/my_dense/kernel'
- self.assertEqual(var_dict[var_key].name, '%s:0' % var_key)
- with variable_scope.variable_scope('test1') as scope:
- inputs = random_ops.random_uniform((5, 3), seed=1)
- core_layers.dense(inputs, 2, name=scope)
- var_dict = _get_variable_dict_from_varstore()
- var_key = 'test1/kernel'
- self.assertEqual(var_dict[var_key].name, '%s:0' % var_key)
- with variable_scope.variable_scope('test2'):
- inputs = random_ops.random_uniform((5, 3), seed=1)
- core_layers.dense(inputs, 2)
- var_dict = _get_variable_dict_from_varstore()
- var_key = 'test2/dense/kernel'
- self.assertEqual(var_dict[var_key].name, '%s:0' % var_key)
+ with self.test_session():
+ with variable_scope.variable_scope('test'):
+ inputs = random_ops.random_uniform((5, 3), seed=1)
+ core_layers.dense(inputs, 2, name='my_dense')
+ var_dict = _get_variable_dict_from_varstore()
+ var_key = 'test/my_dense/kernel'
+ self.assertEqual(var_dict[var_key].name, '%s:0' % var_key)
+ with variable_scope.variable_scope('test1') as scope:
+ inputs = random_ops.random_uniform((5, 3), seed=1)
+ core_layers.dense(inputs, 2, name=scope)
+ var_dict = _get_variable_dict_from_varstore()
+ var_key = 'test1/kernel'
+ self.assertEqual(var_dict[var_key].name, '%s:0' % var_key)
+ with variable_scope.variable_scope('test2'):
+ inputs = random_ops.random_uniform((5, 3), seed=1)
+ core_layers.dense(inputs, 2)
+ var_dict = _get_variable_dict_from_varstore()
+ var_key = 'test2/dense/kernel'
+ self.assertEqual(var_dict[var_key].name, '%s:0' % var_key)
@test_util.run_in_graph_and_eager_modes()
def testComputeOutputShape(self):
@@ -389,17 +379,16 @@ class DropoutTest(test.TestCase):
self.assertAlmostEqual(0., np_output.min())
self.assertAllClose(np_output[:, 0, :], np_output[:, 1, :])
- @test_util.run_in_graph_and_eager_modes()
def testFunctionalDropout(self):
- inputs = array_ops.ones((5, 5))
- dropped = core_layers.dropout(inputs, 0.5, training=True, seed=1)
- if context.in_graph_mode():
- self.evaluate(variables.global_variables_initializer())
- np_output = self.evaluate(dropped)
- self.assertAlmostEqual(0., np_output.min())
- dropped = core_layers.dropout(inputs, 0.5, training=False, seed=1)
- np_output = self.evaluate(dropped)
- self.assertAllClose(np.ones((5, 5)), np_output)
+ with self.test_session():
+ inputs = array_ops.ones((5, 5))
+ dropped = core_layers.dropout(inputs, 0.5, training=True, seed=1)
+ variables.global_variables_initializer().run()
+ np_output = self.evaluate(dropped)
+ self.assertAlmostEqual(0., np_output.min())
+ dropped = core_layers.dropout(inputs, 0.5, training=False, seed=1)
+ np_output = self.evaluate(dropped)
+ self.assertAllClose(np.ones((5, 5)), np_output)
def testDynamicRate(self):
with self.test_session() as sess:
diff --git a/tensorflow/python/layers/maxout.py b/tensorflow/python/layers/maxout.py
index 1ea36dbf6a..fa6c8cee97 100644
--- a/tensorflow/python/layers/maxout.py
+++ b/tensorflow/python/layers/maxout.py
@@ -20,6 +20,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import gen_array_ops
@@ -50,6 +51,10 @@ def maxout(inputs, num_units, axis=-1, name=None):
Raises:
ValueError: if num_units is not multiple of number of features.
"""
+ if context.in_eager_mode():
+ raise ValueError(
+ 'Functional layers are currently not compatible with eager execution. '
+ 'Use tf.contrib.layers.MaxOut instead.')
return MaxOut(num_units=num_units, axis=axis, name=name)(inputs)
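For context, the reduction MaxOut applies is a reshape of the requested axis into num_units groups followed by an elementwise max over each group, which is why num_units must divide the number of features (the ValueError documented above). A NumPy reference sketch of that semantics on the last axis, assuming the usual group-then-max ordering; maxout_reference is a hypothetical helper, not a TensorFlow API:

    import numpy as np

    def maxout_reference(x, num_units):
      # Split the last axis into num_units groups and collapse each group
      # with an elementwise max, mirroring MaxOut(num_units) on that axis.
      features = x.shape[-1]
      if features % num_units != 0:
        raise ValueError('num_units must divide the number of features')
      return x.reshape(x.shape[:-1] + (num_units, features // num_units)).max(axis=-1)

    x = np.array([[1., 5., 2., 2., 0., 7.]])
    print(maxout_reference(x, 3))  # [[5. 2. 7.]]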
diff --git a/tensorflow/python/layers/normalization.py b/tensorflow/python/layers/normalization.py
index 74246189b5..899be08020 100644
--- a/tensorflow/python/layers/normalization.py
+++ b/tensorflow/python/layers/normalization.py
@@ -717,7 +717,14 @@ def batch_normalization(inputs,
Returns:
Output tensor.
+
+ Raises:
+ ValueError: if eager execution is enabled.
"""
+ if context.in_eager_mode():
+ raise ValueError(
+ 'Functional layers are currently not compatible with eager execution. '
+ 'Use tf.layers.BatchNormalization instead.')
layer = BatchNormalization(
axis=axis,
momentum=momentum,
@@ -749,4 +756,3 @@ def batch_normalization(inputs,
BatchNorm = BatchNormalization
batch_norm = batch_normalization
-
diff --git a/tensorflow/python/layers/pooling.py b/tensorflow/python/layers/pooling.py
index 6245ec5054..ec02ab032d 100644
--- a/tensorflow/python/layers/pooling.py
+++ b/tensorflow/python/layers/pooling.py
@@ -20,6 +20,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.python.eager import context
from tensorflow.python.framework import tensor_shape
from tensorflow.python.layers import base
from tensorflow.python.layers import utils
@@ -144,7 +145,14 @@ def average_pooling1d(inputs, pool_size, strides,
Returns:
The output tensor, of rank 3.
+
+ Raises:
+ ValueError: if eager execution is enabled.
"""
+ if context.in_eager_mode():
+ raise ValueError(
+ 'Functional layers are currently not compatible with eager execution. '
+ 'Use tf.layers.AveragePooling1D instead.')
layer = AveragePooling1D(pool_size=pool_size,
strides=strides,
padding=padding,
@@ -206,7 +214,14 @@ def max_pooling1d(inputs, pool_size, strides,
Returns:
The output tensor, of rank 3.
+
+ Raises:
+ ValueError: if eager execution is enabled.
"""
+ if context.in_eager_mode():
+ raise ValueError(
+ 'Functional layers are currently not compatible with eager execution. '
+ 'Use tf.layers.MaxPooling1D instead.')
layer = MaxPooling1D(pool_size=pool_size,
strides=strides,
padding=padding,
@@ -344,7 +359,14 @@ def average_pooling2d(inputs,
Returns:
Output tensor.
+
+ Raises:
+ ValueError: if eager execution is enabled.
"""
+ if context.in_eager_mode():
+ raise ValueError(
+ 'Functional layers are currently not compatible with eager execution. '
+ 'Use tf.layers.AveragePooling2D instead.')
layer = AveragePooling2D(pool_size=pool_size, strides=strides,
padding=padding, data_format=data_format,
name=name)
@@ -409,7 +431,14 @@ def max_pooling2d(inputs,
Returns:
Output tensor.
+
+ Raises:
+ ValueError: if eager execution is enabled.
"""
+ if context.in_eager_mode():
+ raise ValueError(
+ 'Functional layers are currently not compatible with eager execution. '
+ 'Use tf.layers.MaxPooling2D instead.')
layer = MaxPooling2D(pool_size=pool_size, strides=strides,
padding=padding, data_format=data_format,
name=name)
@@ -560,7 +589,14 @@ def average_pooling3d(inputs,
Returns:
Output tensor.
+
+ Raises:
+ ValueError: if eager execution is enabled.
"""
+ if context.in_eager_mode():
+ raise ValueError(
+ 'Functional layers are currently not compatible with eager execution. '
+ 'Use tf.layers.AveragePooling3D instead.')
layer = AveragePooling3D(pool_size=pool_size, strides=strides,
padding=padding, data_format=data_format,
name=name)
@@ -629,7 +665,14 @@ def max_pooling3d(inputs,
Returns:
Output tensor.
+
+ Raises:
+ ValueError: if eager execution is enabled.
"""
+ if context.in_eager_mode():
+ raise ValueError(
+ 'Functional layers are currently not compatible with eager execution. '
+ 'Use tf.layers.MaxPooling3D instead.')
layer = MaxPooling3D(pool_size=pool_size, strides=strides,
padding=padding, data_format=data_format,
name=name)
diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py
index 8c5c639b68..08be8574f3 100644
--- a/tensorflow/python/ops/variable_scope.py
+++ b/tensorflow/python/ops/variable_scope.py
@@ -724,9 +724,6 @@ class _VariableStore(object):
if name in self._vars:
# Here we handle the case when returning an existing variable.
if reuse is False:
- if context.in_eager_mode():
- raise ValueError(
- "Trying to recreate existing variable: %s" % self._vars[name])
tb = self._vars[name].op.traceback[::-1]
# Throw away internal tf entries and only take a few lines.
tb = [x for x in tb if "tensorflow/python" not in x[0]][:3]
@@ -798,7 +795,10 @@ class _VariableStore(object):
dtype=variable_dtype,
validate_shape=validate_shape,
constraint=constraint)
- self._vars[name] = v
+ if context.in_graph_mode():
+ # In eager mode we do not want to keep default references to Variable
+ # objects as this will prevent their memory from being released.
+ self._vars[name] = v
logging.vlog(1, "Created variable %s with shape %s and init %s", v.name,
format(shape), initializer)
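The guard added in this last hunk exists because the default variable store is effectively process-global: any variable it references can never be garbage collected, which becomes a leak once eager execution creates variables on the fly. A framework-free illustration of that retention behavior (the _Store and _FakeVariable names are hypothetical stand-ins, not TensorFlow code):

    import gc
    import weakref

    class _Store(object):          # stands in for the default _VariableStore
      def __init__(self):
        self._vars = {}

    class _FakeVariable(object):   # stands in for a variable and its buffer
      pass

    _default_store = _Store()

    v = _FakeVariable()
    ref = weakref.ref(v)

    _default_store._vars['dense/kernel'] = v   # what unconditionally storing would do
    del v
    gc.collect()
    print(ref() is not None)   # True: the store alone keeps the object alive

    del _default_store._vars['dense/kernel']
    gc.collect()
    print(ref() is None)       # True: memory can be released only once the store lets go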