diff options
-rw-r--r-- | tensorflow/python/framework/function.py | 6 | ||||
-rw-r--r-- | tensorflow/python/framework/function_test.py | 9 |
2 files changed, 15 insertions, 0 deletions
diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py index 14d72d8a3d..82dd2a3356 100644 --- a/tensorflow/python/framework/function.py +++ b/tensorflow/python/framework/function.py @@ -934,6 +934,12 @@ def _parse_kwargs_as_attrs(func_name, **kwargs): s=("function_%s" % func_name).encode()) # pylint: enable=protected-access + kwargs_keys = list(kwargs.keys()) + for key in kwargs_keys: + if key.startswith("experimental_"): + attrs[key] = attr_value_pb2.AttrValue(s=compat.as_bytes(kwargs[key])) + del kwargs[key] + if kwargs: raise ValueError("Unknown keyword arguments: %s" % kwargs.keys()) return attrs diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py index 65ca801cbe..83d256fab6 100644 --- a/tensorflow/python/framework/function_test.py +++ b/tensorflow/python/framework/function_test.py @@ -1227,6 +1227,15 @@ class FunctionsFromProtos(test.TestCase): ValueError, "FunctionDefLibrary contains cyclic gradient functions!"): function._from_library(library) + def testExperimentalAttrs(self): + + @function.Defun(dtypes.int32, experimental_tag="tag_value") + def FunctionWithAttr(i): + return array_ops.identity(i) + self.assertTrue("experimental_tag" in FunctionWithAttr.definition.attr) + self.assertEqual( + FunctionWithAttr.definition.attr["experimental_tag"].s, b"tag_value") + @test_util.with_c_api class FunctionOverloadTest(test.TestCase): |