aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/framework/function_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/framework/function_test.py')
-rw-r--r--tensorflow/python/framework/function_test.py27
1 files changed, 8 insertions, 19 deletions
diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py
index 87f567db0e..16d4903d79 100644
--- a/tensorflow/python/framework/function_test.py
+++ b/tensorflow/python/framework/function_test.py
@@ -1639,29 +1639,18 @@ class FunctionInlineControlTest(test.TestCase):
self.assertEqual(MetadataHasCell(run_metadata), noinline)
-@function.Defun(*[dtypes.float32] * 3)
-def Linear(w, b, x):
- return nn_ops.relu(math_ops.matmul(x, w) + b)
-
-
-@function.Defun(*[dtypes.float32] * 5)
-def Linear2(w1, b1, w2, b2, x):
- return Linear(w2, b2, Linear(w1, b1, x))
-
-
-@function.Defun(*[dtypes.float32] * 3)
-def LinearWithCApi(w, b, x):
- return nn_ops.relu(math_ops.matmul(x, w) + b)
-
+class ModuleFunctionTest(test.TestCase):
-@function.Defun(*[dtypes.float32] * 5)
-def Linear2WithCApi(w1, b1, w2, b2, x):
- return LinearWithCApi(w2, b2, LinearWithCApi(w1, b1, x))
+ def testBasic(self):
+ @function.Defun(*[dtypes.float32] * 3)
+ def LinearWithCApi(w, b, x):
+ return nn_ops.relu(math_ops.matmul(x, w) + b)
-class ModuleFunctionTest(test.TestCase):
+ @function.Defun(*[dtypes.float32] * 5)
+ def Linear2WithCApi(w1, b1, w2, b2, x):
+ return LinearWithCApi(w2, b2, LinearWithCApi(w1, b1, x))
- def testBasic(self):
with ops.Graph().as_default():
a, b, c, d, e = [
constant_op.constant([[_]], dtype=dtypes.float32) for _ in range(5)