diff options
author | 2018-09-24 05:22:37 -0700 | |
---|---|---|
committer | 2018-09-24 05:27:10 -0700 | |
commit | 32251dd7793e56130693b33a0c29318b04df8080 (patch) | |
tree | d4ebbf4ec2fad29e095944cfede03c44a68cb915 /tensorflow/python/framework | |
parent | 379ca4afe9e31f550cd04451af04150b6bbecf78 (diff) |
Add support for non-string attributes
PiperOrigin-RevId: 214251264
Diffstat (limited to 'tensorflow/python/framework')
-rw-r--r-- | tensorflow/python/framework/function.py | 17 | ||||
-rw-r--r-- | tensorflow/python/framework/function_test.py | 27 |
2 files changed, 40 insertions, 4 deletions
diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py index 68b3170dfe..f287289bd0 100644 --- a/tensorflow/python/framework/function.py +++ b/tensorflow/python/framework/function.py @@ -1096,6 +1096,21 @@ def _from_library(lib): return initialized.values() +def _get_experimental_kwarg_as_attr(attr_name, value): + """Creates an AttrValue for a python object.""" + if isinstance(value, bool): + return attr_value_pb2.AttrValue(b=value) + elif isinstance(value, int): + return attr_value_pb2.AttrValue(i=value) + elif isinstance(value, float): + return attr_value_pb2.AttrValue(f=value) + elif isinstance(value, str): + return attr_value_pb2.AttrValue(s=compat.as_bytes(value)) + else: + raise ValueError("Unsupported attribute type for %s with type %s" % + (attr_name, type(value))) + + def _parse_kwargs_as_attrs(func_name, **kwargs): """Parses **kwargs into a node's attributes.""" attrs = {} @@ -1122,7 +1137,7 @@ def _parse_kwargs_as_attrs(func_name, **kwargs): kwargs_keys = list(kwargs.keys()) for key in kwargs_keys: if key.startswith("experimental_"): - attrs[key] = attr_value_pb2.AttrValue(s=compat.as_bytes(kwargs[key])) + attrs[key] = _get_experimental_kwarg_as_attr(key, kwargs[key]) del kwargs[key] if kwargs: diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py index 903768a039..f740e5cfaa 100644 --- a/tensorflow/python/framework/function_test.py +++ b/tensorflow/python/framework/function_test.py @@ -1331,12 +1331,33 @@ class FunctionsFromProtos(test.TestCase): def testExperimentalAttrs(self): @function.Defun(dtypes.int32, experimental_tag="tag_value") - def FunctionWithAttr(i): + def FunctionWithStrAttr(i): return array_ops.identity(i) - self.assertTrue("experimental_tag" in FunctionWithAttr.definition.attr) - self.assertEqual(FunctionWithAttr.definition.attr["experimental_tag"].s, + @function.Defun(dtypes.int32, experimental_tag=123) + def FunctionWithIntAttr(i): + return array_ops.identity(i) + + @function.Defun(dtypes.int32, experimental_tag=123.0) + def FunctionWithFloatAttr(i): + return array_ops.identity(i) + + @function.Defun(dtypes.int32, experimental_tag=True) + def FunctionWithBoolAttr(i): + return array_ops.identity(i) + + self.assertTrue("experimental_tag" in FunctionWithStrAttr.definition.attr) + self.assertEqual(FunctionWithStrAttr.definition.attr["experimental_tag"].s, b"tag_value") + self.assertTrue("experimental_tag" in FunctionWithIntAttr.definition.attr) + self.assertEqual(FunctionWithIntAttr.definition.attr["experimental_tag"].i, + 123) + self.assertTrue("experimental_tag" in FunctionWithFloatAttr.definition.attr) + self.assertEqual( + FunctionWithFloatAttr.definition.attr["experimental_tag"].f, 123.0) + self.assertTrue("experimental_tag" in FunctionWithBoolAttr.definition.attr) + self.assertEqual(FunctionWithBoolAttr.definition.attr["experimental_tag"].b, + True) @test_util.with_c_shapes |