aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/opt/python/training/weight_decay_optimizers_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/opt/python/training/weight_decay_optimizers_test.py')
-rw-r--r--tensorflow/contrib/opt/python/training/weight_decay_optimizers_test.py8
1 files changed, 3 insertions, 5 deletions
diff --git a/tensorflow/contrib/opt/python/training/weight_decay_optimizers_test.py b/tensorflow/contrib/opt/python/training/weight_decay_optimizers_test.py
index 74d1cdbbda..76d8a5697a 100644
--- a/tensorflow/contrib/opt/python/training/weight_decay_optimizers_test.py
+++ b/tensorflow/contrib/opt/python/training/weight_decay_optimizers_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
import numpy as np
+from tensorflow.contrib.opt.python.training import weight_decay_optimizers
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -29,7 +30,6 @@ from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import adam
-from tensorflow.contrib.opt.python.training import weight_decay_optimizers
WEIGHT_DECAY = 0.01
@@ -91,7 +91,6 @@ class WeightDecayOptimizerTest(test.TestCase):
opt = optimizer()
update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
-
if not context.executing_eagerly():
with ops.Graph().as_default():
# Shouldn't return non-slot variables from other graphs.
@@ -171,9 +170,9 @@ class ExtendWithWeightDecayTest(WeightDecayOptimizerTest):
@staticmethod
def get_optimizer():
- AdamW = weight_decay_optimizers.extend_with_decoupled_weight_decay(
+ adamw = weight_decay_optimizers.extend_with_decoupled_weight_decay(
adam.AdamOptimizer)
- return AdamW(WEIGHT_DECAY)
+ return adamw(WEIGHT_DECAY)
def testBasic(self):
self.doTest(self.get_optimizer, adamw_update_numpy, "Adam", "m",
@@ -185,6 +184,5 @@ class ExtendWithWeightDecayTest(WeightDecayOptimizerTest):
use_resource=True)
-
if __name__ == "__main__":
test.main()