aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/python/framework/function.py6
-rw-r--r--tensorflow/python/framework/function_test.py9
2 files changed, 15 insertions, 0 deletions
diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py
index 14d72d8a3d..82dd2a3356 100644
--- a/tensorflow/python/framework/function.py
+++ b/tensorflow/python/framework/function.py
@@ -934,6 +934,12 @@ def _parse_kwargs_as_attrs(func_name, **kwargs):
s=("function_%s" % func_name).encode())
# pylint: enable=protected-access
+ kwargs_keys = list(kwargs.keys())
+ for key in kwargs_keys:
+ if key.startswith("experimental_"):
+ attrs[key] = attr_value_pb2.AttrValue(s=compat.as_bytes(kwargs[key]))
+ del kwargs[key]
+
if kwargs:
raise ValueError("Unknown keyword arguments: %s" % kwargs.keys())
return attrs
diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py
index 65ca801cbe..83d256fab6 100644
--- a/tensorflow/python/framework/function_test.py
+++ b/tensorflow/python/framework/function_test.py
@@ -1227,6 +1227,15 @@ class FunctionsFromProtos(test.TestCase):
ValueError, "FunctionDefLibrary contains cyclic gradient functions!"):
function._from_library(library)
+ def testExperimentalAttrs(self):
+
+ @function.Defun(dtypes.int32, experimental_tag="tag_value")
+ def FunctionWithAttr(i):
+ return array_ops.identity(i)
+ self.assertTrue("experimental_tag" in FunctionWithAttr.definition.attr)
+ self.assertEqual(
+ FunctionWithAttr.definition.attr["experimental_tag"].s, b"tag_value")
+
@test_util.with_c_api
class FunctionOverloadTest(test.TestCase):