From 9989788be25c846d087ac70b76cf78759a209a3e Mon Sep 17 00:00:00 2001 From: Gunhan Gulsoy Date: Tue, 9 Oct 2018 13:31:58 -0700 Subject: Small cleanup in function_test. PiperOrigin-RevId: 216412380 --- tensorflow/python/framework/function_test.py | 27 ++++++++------------------- 1 file changed, 8 insertions(+), 19 deletions(-) diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py index 87f567db0e..16d4903d79 100644 --- a/tensorflow/python/framework/function_test.py +++ b/tensorflow/python/framework/function_test.py @@ -1639,29 +1639,18 @@ class FunctionInlineControlTest(test.TestCase): self.assertEqual(MetadataHasCell(run_metadata), noinline) -@function.Defun(*[dtypes.float32] * 3) -def Linear(w, b, x): - return nn_ops.relu(math_ops.matmul(x, w) + b) - - -@function.Defun(*[dtypes.float32] * 5) -def Linear2(w1, b1, w2, b2, x): - return Linear(w2, b2, Linear(w1, b1, x)) - - -@function.Defun(*[dtypes.float32] * 3) -def LinearWithCApi(w, b, x): - return nn_ops.relu(math_ops.matmul(x, w) + b) - +class ModuleFunctionTest(test.TestCase): -@function.Defun(*[dtypes.float32] * 5) -def Linear2WithCApi(w1, b1, w2, b2, x): - return LinearWithCApi(w2, b2, LinearWithCApi(w1, b1, x)) + def testBasic(self): + @function.Defun(*[dtypes.float32] * 3) + def LinearWithCApi(w, b, x): + return nn_ops.relu(math_ops.matmul(x, w) + b) -class ModuleFunctionTest(test.TestCase): + @function.Defun(*[dtypes.float32] * 5) + def Linear2WithCApi(w1, b1, w2, b2, x): + return LinearWithCApi(w2, b2, LinearWithCApi(w1, b1, x)) - def testBasic(self): with ops.Graph().as_default(): a, b, c, d, e = [ constant_op.constant([[_]], dtype=dtypes.float32) for _ in range(5) -- cgit v1.2.3