aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/framework
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-24 05:22:37 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-24 05:27:10 -0700
commit32251dd7793e56130693b33a0c29318b04df8080 (patch)
treed4ebbf4ec2fad29e095944cfede03c44a68cb915 /tensorflow/python/framework
parent379ca4afe9e31f550cd04451af04150b6bbecf78 (diff)
Add support for non-string attributes
PiperOrigin-RevId: 214251264
Diffstat (limited to 'tensorflow/python/framework')
-rw-r--r--tensorflow/python/framework/function.py17
-rw-r--r--tensorflow/python/framework/function_test.py27
2 files changed, 40 insertions, 4 deletions
diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py
index 68b3170dfe..f287289bd0 100644
--- a/tensorflow/python/framework/function.py
+++ b/tensorflow/python/framework/function.py
@@ -1096,6 +1096,21 @@ def _from_library(lib):
return initialized.values()
+def _get_experimental_kwarg_as_attr(attr_name, value):
+ """Creates an AttrValue for a python object."""
+ if isinstance(value, bool):
+ return attr_value_pb2.AttrValue(b=value)
+ elif isinstance(value, int):
+ return attr_value_pb2.AttrValue(i=value)
+ elif isinstance(value, float):
+ return attr_value_pb2.AttrValue(f=value)
+ elif isinstance(value, str):
+ return attr_value_pb2.AttrValue(s=compat.as_bytes(value))
+ else:
+ raise ValueError("Unsupported attribute type for %s with type %s" %
+ (attr_name, type(value)))
+
+
def _parse_kwargs_as_attrs(func_name, **kwargs):
"""Parses **kwargs into a node's attributes."""
attrs = {}
@@ -1122,7 +1137,7 @@ def _parse_kwargs_as_attrs(func_name, **kwargs):
kwargs_keys = list(kwargs.keys())
for key in kwargs_keys:
if key.startswith("experimental_"):
- attrs[key] = attr_value_pb2.AttrValue(s=compat.as_bytes(kwargs[key]))
+ attrs[key] = _get_experimental_kwarg_as_attr(key, kwargs[key])
del kwargs[key]
if kwargs:
diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py
index 903768a039..f740e5cfaa 100644
--- a/tensorflow/python/framework/function_test.py
+++ b/tensorflow/python/framework/function_test.py
@@ -1331,12 +1331,33 @@ class FunctionsFromProtos(test.TestCase):
def testExperimentalAttrs(self):
@function.Defun(dtypes.int32, experimental_tag="tag_value")
- def FunctionWithAttr(i):
+ def FunctionWithStrAttr(i):
return array_ops.identity(i)
- self.assertTrue("experimental_tag" in FunctionWithAttr.definition.attr)
- self.assertEqual(FunctionWithAttr.definition.attr["experimental_tag"].s,
+ @function.Defun(dtypes.int32, experimental_tag=123)
+ def FunctionWithIntAttr(i):
+ return array_ops.identity(i)
+
+ @function.Defun(dtypes.int32, experimental_tag=123.0)
+ def FunctionWithFloatAttr(i):
+ return array_ops.identity(i)
+
+ @function.Defun(dtypes.int32, experimental_tag=True)
+ def FunctionWithBoolAttr(i):
+ return array_ops.identity(i)
+
+ self.assertTrue("experimental_tag" in FunctionWithStrAttr.definition.attr)
+ self.assertEqual(FunctionWithStrAttr.definition.attr["experimental_tag"].s,
b"tag_value")
+ self.assertTrue("experimental_tag" in FunctionWithIntAttr.definition.attr)
+ self.assertEqual(FunctionWithIntAttr.definition.attr["experimental_tag"].i,
+ 123)
+ self.assertTrue("experimental_tag" in FunctionWithFloatAttr.definition.attr)
+ self.assertEqual(
+ FunctionWithFloatAttr.definition.attr["experimental_tag"].f, 123.0)
+ self.assertTrue("experimental_tag" in FunctionWithBoolAttr.definition.attr)
+ self.assertEqual(FunctionWithBoolAttr.definition.attr["experimental_tag"].b,
+ True)
@test_util.with_c_shapes