diff options
Diffstat (limited to 'tensorflow/python/eager/def_function_test.py')
-rw-r--r-- | tensorflow/python/eager/def_function_test.py | 32 |
1 files changed, 27 insertions, 5 deletions
diff --git a/tensorflow/python/eager/def_function_test.py b/tensorflow/python/eager/def_function_test.py index 804436c4bb..39bad726d0 100644 --- a/tensorflow/python/eager/def_function_test.py +++ b/tensorflow/python/eager/def_function_test.py @@ -29,7 +29,7 @@ class DefFunctionTest(test.TestCase): def testNoVariables(self): - @def_function.def_function + @def_function.function def fn(x): return 2 * x @@ -37,7 +37,7 @@ class DefFunctionTest(test.TestCase): def testFailIfVariablesAreCreatedMoreThanOnce(self): - @def_function.def_function + @def_function.function def fn(x): return variables.Variable(1.0) + x @@ -47,7 +47,7 @@ class DefFunctionTest(test.TestCase): def testFailIfVariablesAreCreatedMoreThanOnceNoWeakRef(self): state = [] - @def_function.def_function + @def_function.function def fn(x): state.append(variables.Variable(1.0)) return state[-1] + x @@ -59,7 +59,7 @@ class DefFunctionTest(test.TestCase): state = [] - @def_function.def_function + @def_function.function def fn(x): if not state: state.append(variables.Variable(2.0)) @@ -72,7 +72,7 @@ class DefFunctionTest(test.TestCase): state = [] - @def_function.def_function + @def_function.function def fn(x): if not state: state.append(variables.Variable(2.0 * x)) @@ -81,6 +81,28 @@ class DefFunctionTest(test.TestCase): self.assertAllEqual(fn(constant_op.constant(1.0)), 2.0) self.assertAllEqual(fn(constant_op.constant(3.0)), 6.0) + def testMethod(self): + + class MyModel(object): + + def __init__(self): + self.var = None + + @def_function.function + def apply(self, x): + if self.var is None: + self.var = variables.Variable(2.0) + return self.var * x + + m0 = MyModel() + self.assertAllEqual(m0.apply(3.0), 6.0) + # Calling twice to exercise that we do not recreate variables. + m0.var.assign(3.0) + self.assertAllEqual(m0.apply(3.0), 9.0) + + m1 = MyModel() + self.assertAllEqual(m1.apply(3.0), 6.0) + if __name__ == '__main__': ops.enable_eager_execution() |