diff options
Diffstat (limited to 'tensorflow/python/framework/function_test.py')
-rw-r--r-- | tensorflow/python/framework/function_test.py | 27 |
1 files changed, 24 insertions, 3 deletions
diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py index 903768a039..f740e5cfaa 100644 --- a/tensorflow/python/framework/function_test.py +++ b/tensorflow/python/framework/function_test.py @@ -1331,12 +1331,33 @@ class FunctionsFromProtos(test.TestCase): def testExperimentalAttrs(self): @function.Defun(dtypes.int32, experimental_tag="tag_value") - def FunctionWithAttr(i): + def FunctionWithStrAttr(i): return array_ops.identity(i) - self.assertTrue("experimental_tag" in FunctionWithAttr.definition.attr) - self.assertEqual(FunctionWithAttr.definition.attr["experimental_tag"].s, + @function.Defun(dtypes.int32, experimental_tag=123) + def FunctionWithIntAttr(i): + return array_ops.identity(i) + + @function.Defun(dtypes.int32, experimental_tag=123.0) + def FunctionWithFloatAttr(i): + return array_ops.identity(i) + + @function.Defun(dtypes.int32, experimental_tag=True) + def FunctionWithBoolAttr(i): + return array_ops.identity(i) + + self.assertTrue("experimental_tag" in FunctionWithStrAttr.definition.attr) + self.assertEqual(FunctionWithStrAttr.definition.attr["experimental_tag"].s, b"tag_value") + self.assertTrue("experimental_tag" in FunctionWithIntAttr.definition.attr) + self.assertEqual(FunctionWithIntAttr.definition.attr["experimental_tag"].i, + 123) + self.assertTrue("experimental_tag" in FunctionWithFloatAttr.definition.attr) + self.assertEqual( + FunctionWithFloatAttr.definition.attr["experimental_tag"].f, 123.0) + self.assertTrue("experimental_tag" in FunctionWithBoolAttr.definition.attr) + self.assertEqual(FunctionWithBoolAttr.definition.attr["experimental_tag"].b, + True) @test_util.with_c_shapes |