aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/framework/function_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/framework/function_test.py')
-rw-r--r--tensorflow/python/framework/function_test.py27
1 files changed, 24 insertions, 3 deletions
diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py
index 903768a039..f740e5cfaa 100644
--- a/tensorflow/python/framework/function_test.py
+++ b/tensorflow/python/framework/function_test.py
@@ -1331,12 +1331,33 @@ class FunctionsFromProtos(test.TestCase):
def testExperimentalAttrs(self):
@function.Defun(dtypes.int32, experimental_tag="tag_value")
- def FunctionWithAttr(i):
+ def FunctionWithStrAttr(i):
return array_ops.identity(i)
- self.assertTrue("experimental_tag" in FunctionWithAttr.definition.attr)
- self.assertEqual(FunctionWithAttr.definition.attr["experimental_tag"].s,
+ @function.Defun(dtypes.int32, experimental_tag=123)
+ def FunctionWithIntAttr(i):
+ return array_ops.identity(i)
+
+ @function.Defun(dtypes.int32, experimental_tag=123.0)
+ def FunctionWithFloatAttr(i):
+ return array_ops.identity(i)
+
+ @function.Defun(dtypes.int32, experimental_tag=True)
+ def FunctionWithBoolAttr(i):
+ return array_ops.identity(i)
+
+ self.assertTrue("experimental_tag" in FunctionWithStrAttr.definition.attr)
+ self.assertEqual(FunctionWithStrAttr.definition.attr["experimental_tag"].s,
b"tag_value")
+ self.assertTrue("experimental_tag" in FunctionWithIntAttr.definition.attr)
+ self.assertEqual(FunctionWithIntAttr.definition.attr["experimental_tag"].i,
+ 123)
+ self.assertTrue("experimental_tag" in FunctionWithFloatAttr.definition.attr)
+ self.assertEqual(
+ FunctionWithFloatAttr.definition.attr["experimental_tag"].f, 123.0)
+ self.assertTrue("experimental_tag" in FunctionWithBoolAttr.definition.attr)
+ self.assertEqual(FunctionWithBoolAttr.definition.attr["experimental_tag"].b,
+ True)
@test_util.with_c_shapes