diff options
author | 2016-09-15 17:20:22 -0800 | |
---|---|---|
committer | 2016-09-15 18:31:54 -0700 | |
commit | e29d94947a7ef28b4917b22c019aacbe7dacc6db (patch) | |
tree | 81e25bd22ca82f8a004bc169f8ec33a99325cb99 /tensorflow/python/framework/ops.py | |
parent | 93b8a053c01c1b388394f27cb63bbf460b82f062 (diff) |
Add an experimental tf.Graph._attr_scope() that can be used to tag nodes with additional attributes.
Change: 133334381
Diffstat (limited to 'tensorflow/python/framework/ops.py')
-rw-r--r-- | tensorflow/python/framework/ops.py | 71 |
1 files changed, 71 insertions, 0 deletions
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index caa9695c8a..3a066931ca 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -2036,6 +2036,8 @@ class Graph(object): self._collections = {} # The graph-level random seed self._seed = None + # A dictionary of attributes that should be applied to all ops. + self._attr_scope_map = {} # A map from op type to the kernel label that should be used. self._op_to_kernel_label_map = {} # A map from op type to an alternative op type that should be used when @@ -2349,6 +2351,12 @@ class Graph(object): node_def = _NodeDef(op_type, name, device=None, attrs=attrs) + # Apply any additional attributes requested. Do not overwrite any existing + # attributes. + for key, value in self._attr_scope_map.items(): + if key not in node_def.attr: + node_def.attr[key].CopyFrom(value) + # Apply a kernel label if one has been specified for this op_type. try: kernel_label = self._op_to_kernel_label_map[op_type] @@ -3347,6 +3355,69 @@ class Graph(object): # pylint: disable=g-doc-return-or-yield @contextlib.contextmanager + def _attr_scope(self, attr_map): + """EXPERIMENTAL: A context manager for setting attributes on operators. + + This context manager can be used to add additional + attributes to operators within the scope of the context. + + For example: + + with ops.Graph().as_default() as g: + f_1 = Foo() # No extra attributes + with g._attr_scope({"_a": tf.attr_value_pb2.AttrValue(b=False)}): + f_2 = Foo() # Additional attribute _a=False + with g._attr_scope({"_a": tf.attr_value_pb2.AttrValue(b=True)}): + f_3 = Foo() # Additional attribute _a=False + with g._attr_scope({"_a": None}): + f_4 = Foo() # No additional attributes. + + Args: + attr_map: A dictionary mapping attr name strings to + AttrValue protocol buffers or None. + + Returns: + A context manager that sets the kernel label to be used for one or more + ops created in that context. + + Raises: + TypeError: If attr_map is not a dictionary mapping + strings to AttrValue protobufs. + """ + if not isinstance(attr_map, dict): + raise TypeError("attr_map must be a dictionary mapping " + "strings to AttrValue protocol buffers") + # The saved_attrs dictionary stores any currently-set labels that + # will be overridden by this context manager. + saved_attrs = {} + # Install the given attribute + for name, attr in attr_map.items(): + if not (isinstance(name, six.string_types) + and isinstance(attr, (type(None), attr_value_pb2.AttrValue))): + raise TypeError("attr_map must be a dictionary mapping " + "strings to AttrValue protocol buffers") + try: + saved_attrs[name] = self._attr_scope_map[name] + except KeyError: + pass + if attr is None: + del self._attr_scope_map[name] + else: + self._attr_scope_map[name] = attr + try: + yield # The code within the context runs here. + finally: + # Remove the attributes set for this context, and restore any saved + # attributes. + for name, attr in attr_map.items(): + try: + self._attr_scope_map[name] = saved_attrs[name] + except KeyError: + del self._attr_scope_map[name] + # pylint: enable=g-doc-return-or-yield + + # pylint: disable=g-doc-return-or-yield + @contextlib.contextmanager def _kernel_label_map(self, op_to_kernel_label_map): """EXPERIMENTAL: A context manager for setting kernel labels. |