diff options
Diffstat (limited to 'tensorflow/python/ops/gradients_test.py')
-rw-r--r-- | tensorflow/python/ops/gradients_test.py | 31 |
1 files changed, 17 insertions, 14 deletions
diff --git a/tensorflow/python/ops/gradients_test.py b/tensorflow/python/ops/gradients_test.py index fa9910b351..3759d8a543 100644 --- a/tensorflow/python/ops/gradients_test.py +++ b/tensorflow/python/ops/gradients_test.py @@ -26,9 +26,10 @@ import numpy as np from tensorflow.python.client import session from tensorflow.python.eager import backprop from tensorflow.python.eager import context +from tensorflow.python.eager import function from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes -from tensorflow.python.framework import function +from tensorflow.python.framework import function as framework_function from tensorflow.python.framework import ops from tensorflow.python.framework import test_ops from tensorflow.python.framework import test_util @@ -369,8 +370,8 @@ class FunctionGradientsTest(test_util.TensorFlowTestCase): @classmethod def _GetFunc(cls, **kwargs): - return function.Defun(dtypes.float32, dtypes.float32, ** - kwargs)(cls.XSquarePlusB) + return framework_function.Defun(dtypes.float32, dtypes.float32, ** + kwargs)(cls.XSquarePlusB) def _GetFuncGradients(self, f, x_value, b_value): x = constant_op.constant(x_value, name="x") @@ -408,8 +409,9 @@ class FunctionGradientsTest(test_util.TensorFlowTestCase): def testFunctionGradientsWithGradFunc(self): g = ops.Graph() with g.as_default(): - grad_func = function.Defun(dtypes.float32, dtypes.float32, - dtypes.float32)(self.XSquarePlusBGradient) + grad_func = framework_function.Defun(dtypes.float32, dtypes.float32, + dtypes.float32)( + self.XSquarePlusBGradient) f = self._GetFunc(grad_func=grad_func) # Get gradients (should add SymbolicGradient node for function, which # uses the grad_func above, which multiplies all gradients by 2). @@ -430,8 +432,9 @@ class FunctionGradientsTest(test_util.TensorFlowTestCase): def testFunctionGradientWithGradFuncAndRegistration(self): g = ops.Graph() with g.as_default(): - grad_func = function.Defun(dtypes.float32, dtypes.float32, - dtypes.float32)(self.XSquarePlusBGradient) + grad_func = framework_function.Defun(dtypes.float32, dtypes.float32, + dtypes.float32)( + self.XSquarePlusBGradient) with self.assertRaisesRegexp(ValueError, "Gradient defined twice"): f = self._GetFunc( grad_func=grad_func, python_grad_func=self._PythonGradient) @@ -441,7 +444,7 @@ class FunctionGradientsTest(test_util.TensorFlowTestCase): with ops.Graph().as_default(): x = constant_op.constant(1.0, name="x") - @function.Defun() + @function.defun() def Foo(): y = math_ops.multiply(x, 2.0, name="y") g = gradients_impl.gradients(y, x) @@ -456,7 +459,7 @@ class FunctionGradientsTest(test_util.TensorFlowTestCase): x = constant_op.constant(1.0, name="x") y = math_ops.multiply(x, 2.0, name="y") - @function.Defun() + @framework_function.Defun() def Foo(): g = gradients_impl.gradients(y, x) return g[0] @@ -469,7 +472,7 @@ class FunctionGradientsTest(test_util.TensorFlowTestCase): with ops.Graph().as_default(): var = resource_variable_ops.ResourceVariable(1.0, name="var") - @function.Defun() + @function.defun() def Foo(): y = math_ops.multiply(var, 2.0, name="y") g = gradients_impl.gradients(y, var) @@ -486,11 +489,11 @@ class FunctionGradientsTest(test_util.TensorFlowTestCase): x2 = constant_op.constant(2.0, name="x2") x3 = math_ops.multiply(x1, x2, name="x3") - @function.Defun() + @function.defun() def Outer(): outer1 = array_ops.identity(x1, name="outer1") - @function.Defun() + @function.defun() def Inner(): inner1 = array_ops.identity(outer1, name="inner1") inner2 = array_ops.identity(x2, name="inner2") @@ -511,11 +514,11 @@ class FunctionGradientsTest(test_util.TensorFlowTestCase): with ops.Graph().as_default(): x = constant_op.constant(1.0, name="x") - @function.Defun() + @function.defun() def Outer(): y = math_ops.multiply(x, 2.0, name="y") - @function.Defun() + @function.defun() def Inner(): z = math_ops.multiply(y, 3.0, name="z") g = gradients_impl.gradients(z, y) |