aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Patrick Nguyen <drpng@google.com>2018-03-29 10:41:36 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-29 10:46:31 -0700
commit9fbb5b3b8fef1caa2ee2ca4a0f8dde900d1f2aa5 (patch)
tree713114dd8d14baa2f13bfb3210c600038c75fe19
parentab1766893951335cd6d3e9b5b51d67d46af889a1 (diff)
Allow experimental string attrs for functions.
PiperOrigin-RevId: 190951605
-rw-r--r--tensorflow/python/framework/function.py6
-rw-r--r--tensorflow/python/framework/function_test.py9
2 files changed, 15 insertions, 0 deletions
diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py
index 14d72d8a3d..82dd2a3356 100644
--- a/tensorflow/python/framework/function.py
+++ b/tensorflow/python/framework/function.py
@@ -934,6 +934,12 @@ def _parse_kwargs_as_attrs(func_name, **kwargs):
s=("function_%s" % func_name).encode())
# pylint: enable=protected-access
+ kwargs_keys = list(kwargs.keys())
+ for key in kwargs_keys:
+ if key.startswith("experimental_"):
+ attrs[key] = attr_value_pb2.AttrValue(s=compat.as_bytes(kwargs[key]))
+ del kwargs[key]
+
if kwargs:
raise ValueError("Unknown keyword arguments: %s" % kwargs.keys())
return attrs
diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py
index 65ca801cbe..83d256fab6 100644
--- a/tensorflow/python/framework/function_test.py
+++ b/tensorflow/python/framework/function_test.py
@@ -1227,6 +1227,15 @@ class FunctionsFromProtos(test.TestCase):
ValueError, "FunctionDefLibrary contains cyclic gradient functions!"):
function._from_library(library)
+ def testExperimentalAttrs(self):
+
+ @function.Defun(dtypes.int32, experimental_tag="tag_value")
+ def FunctionWithAttr(i):
+ return array_ops.identity(i)
+ self.assertTrue("experimental_tag" in FunctionWithAttr.definition.attr)
+ self.assertEqual(
+ FunctionWithAttr.definition.attr["experimental_tag"].s, b"tag_value")
+
@test_util.with_c_api
class FunctionOverloadTest(test.TestCase):