aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/gan
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-12-04 14:39:45 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-04 14:42:54 -0800
commit6f09a74e31e8953e8ebf870e53e1fdb8ce073fff (patch)
treee1888fe7a8056d312bca833a341b39b385b53832 /tensorflow/contrib/gan
parent2427923aeb15d63c0df296081b7400ddb8a308ee (diff)
Fix TFGAN's `clip_weights_test.py` bugs.
PiperOrigin-RevId: 177870577
Diffstat (limited to 'tensorflow/contrib/gan')
-rw-r--r--tensorflow/contrib/gan/python/features/python/clip_weights_test.py15
1 files changed, 10 insertions, 5 deletions
diff --git a/tensorflow/contrib/gan/python/features/python/clip_weights_test.py b/tensorflow/contrib/gan/python/features/python/clip_weights_test.py
index 030e37ec67..2b7bb5f14e 100644
--- a/tensorflow/contrib/gan/python/features/python/clip_weights_test.py
+++ b/tensorflow/contrib/gan/python/features/python/clip_weights_test.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for tfgan.python.features.clip_weights."""
+"""Tests for features.clip_weights."""
from __future__ import absolute_import
from __future__ import division
@@ -31,17 +31,18 @@ class ClipWeightsTest(test.TestCase):
"""Tests for `discriminator_weight_clip`."""
def setUp(self):
+ super(ClipWeightsTest, self).setUp()
self.variables = [variables.Variable(2.0)]
self.tuple = collections.namedtuple(
'VarTuple', ['discriminator_variables'])(self.variables)
def _test_weight_clipping_helper(self, use_tuple):
- loss = self.variables[0] * 2.0
+ loss = self.variables[0]
opt = training.GradientDescentOptimizer(1.0)
if use_tuple:
- opt_clip = clip_weights.weight_clip(opt, self.variables, 0.1)
+ opt_clip = clip_weights.clip_variables(opt, self.variables, 0.1)
else:
- opt_clip = clip_weights.discriminator_weight_clip(opt, self.tuple, 0.1)
+ opt_clip = clip_weights.clip_discriminator_weights(opt, self.tuple, 0.1)
train_op1 = opt.minimize(loss, var_list=self.variables)
train_op2 = opt_clip.minimize(loss, var_list=self.variables)
@@ -72,10 +73,14 @@ class ClipWeightsTest(test.TestCase):
clip_weights.clip_discriminator_weights(opt, self.tuple, weight_clip=-1)
else:
with self.assertRaisesRegexp(ValueError, 'must be positive'):
- clip_weights.clip_weights(opt, self.variables, weight_clip=-1)
+ clip_weights.clip_variables(opt, self.variables, weight_clip=-1)
def test_incorrect_weight_clip_value_argsonly(self):
self._test_incorrect_weight_clip_value_helper(False)
def test_incorrect_weight_clip_value_tuple(self):
self._test_incorrect_weight_clip_value_helper(True)
+
+
+if __name__ == '__main__':
+ test.main()