diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-12-04 14:39:45 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-12-04 14:42:54 -0800 |
commit | 6f09a74e31e8953e8ebf870e53e1fdb8ce073fff (patch) | |
tree | e1888fe7a8056d312bca833a341b39b385b53832 /tensorflow/contrib/gan | |
parent | 2427923aeb15d63c0df296081b7400ddb8a308ee (diff) |
Fix TFGAN's `clip_weights_test.py` bugs.
PiperOrigin-RevId: 177870577
Diffstat (limited to 'tensorflow/contrib/gan')
-rw-r--r-- | tensorflow/contrib/gan/python/features/python/clip_weights_test.py | 15 |
1 files changed, 10 insertions, 5 deletions
diff --git a/tensorflow/contrib/gan/python/features/python/clip_weights_test.py b/tensorflow/contrib/gan/python/features/python/clip_weights_test.py index 030e37ec67..2b7bb5f14e 100644 --- a/tensorflow/contrib/gan/python/features/python/clip_weights_test.py +++ b/tensorflow/contrib/gan/python/features/python/clip_weights_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for tfgan.python.features.clip_weights.""" +"""Tests for features.clip_weights.""" from __future__ import absolute_import from __future__ import division @@ -31,17 +31,18 @@ class ClipWeightsTest(test.TestCase): """Tests for `discriminator_weight_clip`.""" def setUp(self): + super(ClipWeightsTest, self).setUp() self.variables = [variables.Variable(2.0)] self.tuple = collections.namedtuple( 'VarTuple', ['discriminator_variables'])(self.variables) def _test_weight_clipping_helper(self, use_tuple): - loss = self.variables[0] * 2.0 + loss = self.variables[0] opt = training.GradientDescentOptimizer(1.0) if use_tuple: - opt_clip = clip_weights.weight_clip(opt, self.variables, 0.1) + opt_clip = clip_weights.clip_variables(opt, self.variables, 0.1) else: - opt_clip = clip_weights.discriminator_weight_clip(opt, self.tuple, 0.1) + opt_clip = clip_weights.clip_discriminator_weights(opt, self.tuple, 0.1) train_op1 = opt.minimize(loss, var_list=self.variables) train_op2 = opt_clip.minimize(loss, var_list=self.variables) @@ -72,10 +73,14 @@ class ClipWeightsTest(test.TestCase): clip_weights.clip_discriminator_weights(opt, self.tuple, weight_clip=-1) else: with self.assertRaisesRegexp(ValueError, 'must be positive'): - clip_weights.clip_weights(opt, self.variables, weight_clip=-1) + clip_weights.clip_variables(opt, self.variables, weight_clip=-1) def test_incorrect_weight_clip_value_argsonly(self): self._test_incorrect_weight_clip_value_helper(False) def test_incorrect_weight_clip_value_tuple(self): self._test_incorrect_weight_clip_value_helper(True) + + +if __name__ == '__main__': + test.main() |